Skip to content

Commit 43ca4f0

Browse files
committed
Add multi-GPU parallelization for dataset generation
* Added process_qa_chunk() function to handle parallel processing on individual GPUs
* Modified make_dataset() to distribute QA pairs across available GPUs using multiprocessing
* Each GPU process creates its own embedding model and backend directory
* Added per-GPU progress bars for real-time monitoring
1 parent 76ff9c2 commit 43ca4f0

File tree

2 files changed

+170
-42
lines changed

2 files changed

+170
-42
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8484
- Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807))
8585
- Automatic num_params in LLM + update `GRetriever` default llm ([#9938](https://github.com/pyg-team/pytorch_geometric/pull/9938))
8686
- Updated calls to NumPy's deprecated `np.in1d` to `np.isin` ([#10283](https://github.com/pyg-team/pytorch_geometric/pull/10283))
87+
- Added multi-GPU parallelization using multiprocessing for accelerated RAG dataset generation in `txt2kg_rag.py`, distributing QA pair processing across GPUs ([#10474](https://github.com/pyg-team/pytorch_geometric/pull/10474))
8788

8889
### Deprecated
8990

examples/llm/txt2kg_rag.py

Lines changed: 169 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import gc
33
import json
4+
import multiprocessing as mp
45
import os
56
import random
67
import re
@@ -424,6 +425,112 @@ def update_data_lists(args, data_lists):
424425
return data_lists
425426

426427

428+
def _build_query_loader(device, graph_data, triples, model_name,
                        dataset_path, backend_path):
    """Build the encoder, stores, PCST filter and RAG query loader."""
    sent_trans_batch_size = 256
    # Honor the encoder name handed to the worker instead of silently
    # falling back to a module-level default.
    model = SentenceTransformer(model_name=model_name).to(device).eval()

    # Load graph and feature stores from a path that is unique per GPU.
    fs, gs = create_remote_backend_from_graph_data(
        graph_data=graph_data, path=backend_path,
        graph_db=NeighborSamplingRAGGraphStore,
        feature_db=KNNRAGFeatureStore).load()

    subgraph_filter = make_pcst_filter(
        triples,
        model,
        topk=5,  # nodes
        topk_e=5,  # edges
        cost_e=.5,  # edge cost
        num_clusters=10)  # num clusters

    # The document retriever is read from the dataset directory.
    model_kwargs = {
        "output_device": device,
        "batch_size": int(sent_trans_batch_size / 4),
    }
    doc_retriever_path = os.path.join(dataset_path, "document_retriever.pt")
    vector_retriever = DocumentRetriever.load(doc_retriever_path,
                                              model=model.encode,
                                              model_kwargs=model_kwargs)

    fanout = 100  # number of neighbors for each seed node selected by KNN
    num_hops = 2  # number of hops for neighbor sampling
    query_loader_config = {
        "k_nodes": 1024,  # k for Graph KNN
        "num_neighbors": [fanout] * num_hops,
        "encoder_model": model,
    }

    return RAGQueryLoader(graph_data=(fs, gs),
                          subgraph_filter=subgraph_filter,
                          vector_retriever=vector_retriever,
                          config=query_loader_config)


def process_qa_chunk(chunk_data):
    """Process a chunk of QA pairs on a specific GPU.

    Takes a single tuple argument so it can be mapped over a
    ``multiprocessing`` pool.

    Args:
        chunk_data: Tuple ``(qa_chunk, gpu_id, fs_path, gs_path, graph_data,
            triples, model_name, dataset_path)`` where

            * ``qa_chunk`` is a sequence of dicts with the keys
              ``"question"``, ``"answer"`` and ``"is_impossible"``;
            * ``gpu_id`` is the CUDA device index this worker binds to;
            * ``fs_path`` / ``gs_path`` are accepted for compatibility with
              the caller but are not used: the worker builds its own backend
              under ``backend_gpu_{gpu_id}``;
            * ``graph_data`` and ``triples`` are forwarded to the backend
              factory and the PCST filter respectively;
            * ``model_name`` is the sentence-encoder model to load;
            * ``dataset_path`` is the directory that contains
              ``document_retriever.pt``.

    Returns:
        tuple: ``(chunk_data_list, chunk_triple_sizes, max_answer_len)`` --
        the retrieved subgraphs (each labelled with its answer), the number
        of triples in each subgraph, and the length of the longest answer
        seen in this chunk.
    """
    (qa_chunk, gpu_id, _fs_path, _gs_path, graph_data, triples, model_name,
     dataset_path) = chunk_data

    import shutil

    from tqdm import tqdm

    # Set the GPU for this process
    torch.cuda.set_device(gpu_id)
    device = torch.device(f"cuda:{gpu_id}")
    # flush=True so worker logs show up immediately
    print(
        f"Process {mp.current_process().pid} using GPU {gpu_id} "
        f"to process {len(qa_chunk)} QA pairs", flush=True)

    backend_path = f"backend_gpu_{gpu_id}"
    chunk_data_list = []
    chunk_triple_sizes = []
    max_answer_len = 0
    query_loader = None

    try:
        query_loader = _build_query_loader(device, graph_data, triples,
                                           model_name, dataset_path,
                                           backend_path)

        # One progress bar per GPU; `position` offsets the bars so that
        # they do not overlap. The context manager closes the bar even if
        # a query raises.
        with tqdm(qa_chunk, desc=f"GPU {gpu_id}", position=gpu_id,
                  leave=True, total=len(qa_chunk)) as gpu_pbar:
            for idx, data_point in enumerate(gpu_pbar, start=1):
                if data_point["is_impossible"]:
                    continue
                q = data_point["question"]
                a = data_point["answer"]
                max_answer_len = max(len(a), max_answer_len)

                # Show the 1-based index of the current item
                gpu_pbar.set_postfix({'QA': idx})

                subgraph = query_loader.query(q)
                subgraph.label = a
                chunk_data_list.append(subgraph)
                chunk_triple_sizes.append(len(subgraph.triples))

        print(
            f"GPU {gpu_id}: Completed processing "
            f"{len(chunk_data_list)} QA pairs!", flush=True)
    finally:
        # Always release this worker's resources, including on failure.
        # The loader is the only remaining owner of the encoder, filter and
        # retriever, so dropping it lets the collector reclaim them before
        # the CUDA cache is emptied.
        query_loader = None
        gc.collect()
        torch.cuda.empty_cache()

        # Clean up GPU-specific backend directory
        if os.path.exists(backend_path):
            shutil.rmtree(backend_path)

    return chunk_data_list, chunk_triple_sizes, max_answer_len
532+
533+
427534
def make_dataset(args):
428535
qa_pairs, context_docs = get_data(args)
429536
print("Number of Docs in our VectorDB =", len(context_docs))
@@ -447,6 +554,13 @@ def make_dataset(args):
447554
triples = index_kg(args, context_docs)
448555

449556
print("Number of triples in our GraphDB =", len(triples))
557+
558+
# Determine number of GPUs to use
559+
num_gpus = (args.num_gpus
560+
if args.num_gpus is not None else torch.cuda.device_count())
561+
num_gpus = min(num_gpus, torch.cuda.device_count())
562+
print(f"Using {num_gpus} GPUs for dataset generation")
563+
450564
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
451565

452566
# creating the embedding model
@@ -496,52 +610,65 @@ def make_dataset(args):
496610
model_kwargs=model_kwargs)
497611
vector_retriever.save(doc_retriever_path)
498612

499-
subgraph_filter = make_pcst_filter(
500-
triples,
501-
model,
502-
topk=5, # nodes
503-
topk_e=5, # edges
504-
cost_e=.5, # edge cost
505-
num_clusters=10) # num clusters
506-
507-
# number of neighbors for each seed node selected by KNN
508-
fanout = 100
509-
# number of hops for neighborsampling
510-
num_hops = 2
511-
512-
query_loader_config = {
513-
"k_nodes": 1024, # k for Graph KNN
514-
"num_neighbors": [fanout] * num_hops, # number of sampled neighbors
515-
"encoder_model": model,
516-
}
517-
518-
# GraphDB retrieval done with KNN+NeighborSampling+PCST
519-
# PCST = Prize Collecting Steiner Tree
520-
# VectorDB retrieval just vanilla vector RAG
613+
# Distribute QA pairs across GPUs for parallel subgraph retrieval
614+
# Each process loads its own embedding model, graph/feature stores,
615+
# and query loader
521616
print("Now to retrieve context for each query from "
522-
"our Vector and Graph DBs...")
523-
524-
query_loader = RAGQueryLoader(graph_data=(fs, gs),
525-
subgraph_filter=subgraph_filter,
526-
vector_retriever=vector_retriever,
527-
config=query_loader_config)
528-
529-
# pre-process the dataset
617+
"our Vector and Graph DBs using multiprocessing...")
618+
619+
# Filter out impossible QA pairs first
620+
valid_qa_pairs = [dp for dp in qa_pairs if not dp["is_impossible"]]
621+
622+
# Split QA pairs into chunks for each GPU
623+
chunk_size = len(valid_qa_pairs) // num_gpus
624+
qa_chunks = []
625+
for i in range(num_gpus):
626+
start_idx = i * chunk_size
627+
end_idx = start_idx + chunk_size if i < num_gpus - 1 else len(
628+
valid_qa_pairs)
629+
qa_chunks.append(valid_qa_pairs[start_idx:end_idx])
630+
631+
# Prepare data for multiprocessing
632+
chunk_data_list = []
633+
for i, chunk in enumerate(qa_chunks):
634+
chunk_data_list.append((
635+
chunk, # QA pairs chunk
636+
i % torch.cuda.device_count(), # GPU ID
637+
"backend/fs", # fs path
638+
"backend/gs", # gs path
639+
graph_data, # graph data
640+
triples, # triples
641+
ENCODER_MODEL_NAME_DEFAULT, # model name
642+
args.dataset # dataset path
643+
))
644+
645+
# Use multiprocessing to process chunks in parallel
646+
mp.set_start_method('spawn', force=True)
647+
648+
print(f"\nProcessing {len(valid_qa_pairs)} QA pairs "
649+
f"across {num_gpus} GPUs...")
650+
print("=" * 60)
651+
652+
with mp.Pool(processes=num_gpus) as pool:
653+
results = list(pool.imap(process_qa_chunk, chunk_data_list))
654+
655+
# Clean up main backend directory if it exists
656+
import shutil
657+
if os.path.exists("backend"):
658+
shutil.rmtree("backend")
659+
660+
# Combine results from all processes
530661
total_data_list = []
531662
extracted_triple_sizes = []
532663
global max_chars_in_train_answer
533-
for data_point in tqdm(qa_pairs, desc="Building un-split dataset"):
534-
if data_point["is_impossible"]:
535-
continue
536-
QA_pair = (data_point["question"], data_point["answer"])
537-
q = QA_pair[0]
538-
max_chars_in_train_answer = max(len(QA_pair[1]),
539-
max_chars_in_train_answer)
540-
# (TODO) make this batch queries for retrieving w/ CuVS+CuGraph
541-
subgraph = query_loader.query(q)
542-
subgraph.label = QA_pair[1]
543-
total_data_list.append(subgraph)
544-
extracted_triple_sizes.append(len(subgraph.triples))
664+
max_chars_in_train_answer = 0
665+
666+
for chunk_data_list, chunk_triple_sizes, chunk_max_answer_len in results:
667+
total_data_list.extend(chunk_data_list)
668+
extracted_triple_sizes.extend(chunk_triple_sizes)
669+
max_chars_in_train_answer = max(max_chars_in_train_answer,
670+
chunk_max_answer_len)
671+
545672
random.shuffle(total_data_list)
546673

547674
# stats

0 commit comments

Comments
 (0)