Commit d8a9435

Add multi-GPU parallelization for dataset generation
* Added process_qa_chunk() function to handle parallel processing on individual GPUs
* Modified make_dataset() to distribute QA pairs across available GPUs using multiprocessing
* Each GPU process creates its own embedding model and backend directory
* Added per-GPU progress bars for real-time monitoring
1 parent 76ff9c2 commit d8a9435
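The pattern described in the commit message can be sketched in isolation as follows. This is a simplified illustration rather than the commit's actual code: `process_chunk`, the placeholder question list, and the merge step are stand-ins, while the real workers in `txt2kg_rag.py` additionally build their own embedding model, backend stores, and query loader (see the diff below).

```python
# Minimal sketch of the multi-GPU distribution pattern described above,
# not the code from this commit. `process_chunk` and the placeholder
# workload are illustrative stand-ins.
import multiprocessing as mp

import torch


def process_chunk(chunk_data):
    """Worker: pin this process to one GPU and process its chunk of items."""
    items, gpu_id = chunk_data
    if torch.cuda.is_available():
        torch.cuda.set_device(gpu_id)  # CUDA work in this process targets gpu_id
        device = torch.device(f"cuda:{gpu_id}")
    else:
        device = torch.device("cpu")  # fallback so the sketch runs anywhere
    return [f"{item} handled on {device}" for item in items]


def main():
    items = [f"question {i}" for i in range(100)]  # placeholder QA workload
    num_workers = max(torch.cuda.device_count(), 1)

    # Split the workload into one contiguous chunk per worker/GPU.
    chunk_size = len(items) // num_workers
    chunks = []
    for i in range(num_workers):
        end = (i + 1) * chunk_size if i < num_workers - 1 else len(items)
        chunks.append((items[i * chunk_size:end], i))

    # CUDA-using worker processes must be started with the 'spawn' method.
    mp.set_start_method("spawn", force=True)
    with mp.Pool(processes=num_workers) as pool:
        results = list(pool.imap(process_chunk, chunks))

    # Merge the per-chunk results, mirroring how make_dataset() combines them.
    merged = [item for chunk in results for item in chunk]
    print(f"Processed {len(merged)} items across {num_workers} worker(s)")


if __name__ == "__main__":
    main()
```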

2 files changed: +173 -42 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807))
 - Automatic num_params in LLM + update `GRetriever` default llm ([#9938](https://github.com/pyg-team/pytorch_geometric/pull/9938))
 - Updated calls to NumPy's deprecated `np.in1d` to `np.isin` ([#10283](https://github.com/pyg-team/pytorch_geometric/pull/10283))
+- Added multi-GPU parallelization using multiprocessing for accelerated RAG dataset generation in `txt2kg_rag.py`, distributing QA pair processing across GPUs ([#10474](https://github.com/pyg-team/pytorch_geometric/pull/10474))
 
 ### Deprecated

examples/llm/txt2kg_rag.py

Lines changed: 172 additions & 42 deletions
@@ -1,6 +1,7 @@
 import argparse
 import gc
 import json
+import multiprocessing as mp
 import os
 import random
 import re
@@ -424,6 +425,115 @@ def update_data_lists(args, data_lists):
     return data_lists
 
 
+def process_qa_chunk(chunk_data):
+    """Process a chunk of QA pairs on a specific GPU."""
+    (qa_chunk, gpu_id, fs_path, gs_path, graph_data, triples, model_name,
+     dataset_path) = chunk_data
+
+    # Set the GPU for this process
+    torch.cuda.set_device(gpu_id)
+    device = torch.device(f"cuda:{gpu_id}")
+    print(f"Process {mp.current_process().pid} using GPU {gpu_id} "
+          f"to process {len(qa_chunk)} QA pairs")
+    sys.stdout.flush()  # Force flush to see logs immediately
+
+    # Create embedding model for this process
+    sent_trans_batch_size = 256
+    model = SentenceTransformer(
+        model_name=ENCODER_MODEL_NAME_DEFAULT).to(device).eval()
+
+    # Load graph and feature stores with unique path for each GPU
+    from torch_geometric.llm.utils.backend_utils import (
+        create_remote_backend_from_graph_data,
+    )
+    backend_path = f"backend_gpu_{gpu_id}"
+    fs, gs = create_remote_backend_from_graph_data(
+        graph_data=graph_data, path=backend_path,
+        graph_db=NeighborSamplingRAGGraphStore,
+        feature_db=KNNRAGFeatureStore).load()
+
+    # Create subgraph filter for this process
+    subgraph_filter = make_pcst_filter(
+        triples,
+        model,
+        topk=5,  # nodes
+        topk_e=5,  # edges
+        cost_e=.5,  # edge cost
+        num_clusters=10)  # num clusters
+
+    # Load document retriever
+    model_kwargs = {
+        "output_device": device,
+        "batch_size": int(sent_trans_batch_size / 4),
+    }
+    doc_retriever_path = os.path.join(dataset_path, "document_retriever.pt")
+    vector_retriever = DocumentRetriever.load(doc_retriever_path,
+                                              model=model.encode,
+                                              model_kwargs=model_kwargs)
+
+    # Query loader config
+    fanout = 100
+    num_hops = 2
+    query_loader_config = {
+        "k_nodes": 1024,
+        "num_neighbors": [fanout] * num_hops,
+        "encoder_model": model,
+    }
+
+    # Create query loader for this process
+    query_loader = RAGQueryLoader(graph_data=(fs, gs),
+                                  subgraph_filter=subgraph_filter,
+                                  vector_retriever=vector_retriever,
+                                  config=query_loader_config)
+
+    # Process the chunk
+    chunk_data_list = []
+    chunk_triple_sizes = []
+    max_answer_len = 0
+
+    # Add tqdm progress bar for this GPU
+    from tqdm import tqdm
+
+    # Create a progress bar for this GPU with position offset to avoid overlap
+    gpu_pbar = tqdm(qa_chunk, desc=f"GPU {gpu_id}", position=gpu_id,
+                    leave=True, total=len(qa_chunk))
+
+    for idx, data_point in enumerate(gpu_pbar):
+        if data_point["is_impossible"]:
+            continue
+        q = data_point["question"]
+        a = data_point["answer"]
+        max_answer_len = max(len(a), max_answer_len)
+
+        # Update progress bar description with current item
+        gpu_pbar.set_postfix({'QA': idx + 1})
+
+        subgraph = query_loader.query(q)
+        subgraph.label = a
+        chunk_data_list.append(subgraph)
+        chunk_triple_sizes.append(len(subgraph.triples))
+
+    gpu_pbar.close()
+
+    # Log completion
+    print(
+        f"GPU {gpu_id}: Completed processing {len(chunk_data_list)} QA pairs!")
+    sys.stdout.flush()
+
+    # Clean up
+    del model
+    del query_loader
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Clean up GPU-specific backend directory
+    import shutil
+    if os.path.exists(backend_path):
+        shutil.rmtree(backend_path)
+
+    return chunk_data_list, chunk_triple_sizes, max_answer_len
+
+
 def make_dataset(args):
     qa_pairs, context_docs = get_data(args)
     print("Number of Docs in our VectorDB =", len(context_docs))
@@ -447,6 +557,13 @@ def make_dataset(args):
     triples = index_kg(args, context_docs)
 
     print("Number of triples in our GraphDB =", len(triples))
+
+    # Determine number of GPUs to use
+    num_gpus = (args.num_gpus
+                if args.num_gpus is not None else torch.cuda.device_count())
+    num_gpus = min(num_gpus, torch.cuda.device_count())
+    print(f"Using {num_gpus} GPUs for dataset generation")
+
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # creating the embedding model
@@ -496,52 +613,65 @@ def make_dataset(args):
                                          model_kwargs=model_kwargs)
     vector_retriever.save(doc_retriever_path)
 
-    subgraph_filter = make_pcst_filter(
-        triples,
-        model,
-        topk=5,  # nodes
-        topk_e=5,  # edges
-        cost_e=.5,  # edge cost
-        num_clusters=10)  # num clusters
-
-    # number of neighbors for each seed node selected by KNN
-    fanout = 100
-    # number of hops for neighborsampling
-    num_hops = 2
-
-    query_loader_config = {
-        "k_nodes": 1024,  # k for Graph KNN
-        "num_neighbors": [fanout] * num_hops,  # number of sampled neighbors
-        "encoder_model": model,
-    }
-
-    # GraphDB retrieval done with KNN+NeighborSampling+PCST
-    # PCST = Prize Collecting Steiner Tree
-    # VectorDB retrieval just vanilla vector RAG
+    # Distribute QA pairs across GPUs for parallel subgraph retrieval
+    # Each process loads its own embedding model, graph/feature stores,
+    # and query loader
     print("Now to retrieve context for each query from "
-          "our Vector and Graph DBs...")
-
-    query_loader = RAGQueryLoader(graph_data=(fs, gs),
-                                  subgraph_filter=subgraph_filter,
-                                  vector_retriever=vector_retriever,
-                                  config=query_loader_config)
-
-    # pre-process the dataset
+          "our Vector and Graph DBs using multiprocessing...")
+
+    # Filter out impossible QA pairs first
+    valid_qa_pairs = [dp for dp in qa_pairs if not dp["is_impossible"]]
+
+    # Split QA pairs into chunks for each GPU
+    chunk_size = len(valid_qa_pairs) // num_gpus
+    qa_chunks = []
+    for i in range(num_gpus):
+        start_idx = i * chunk_size
+        end_idx = start_idx + chunk_size if i < num_gpus - 1 else len(
+            valid_qa_pairs)
+        qa_chunks.append(valid_qa_pairs[start_idx:end_idx])
+
+    # Prepare data for multiprocessing
+    chunk_data_list = []
+    for i, chunk in enumerate(qa_chunks):
+        chunk_data_list.append((
+            chunk,  # QA pairs chunk
+            i % torch.cuda.device_count(),  # GPU ID
+            "backend/fs",  # fs path
+            "backend/gs",  # gs path
+            graph_data,  # graph data
+            triples,  # triples
+            ENCODER_MODEL_NAME_DEFAULT,  # model name
+            args.dataset  # dataset path
+        ))
+
+    # Use multiprocessing to process chunks in parallel
+    mp.set_start_method('spawn', force=True)
+
+    print(f"\nProcessing {len(valid_qa_pairs)} QA pairs "
+          f"across {num_gpus} GPUs...")
+    print("=" * 60)
+
+    with mp.Pool(processes=num_gpus) as pool:
+        results = list(pool.imap(process_qa_chunk, chunk_data_list))
+
+    # Clean up main backend directory if it exists
+    import shutil
+    if os.path.exists("backend"):
+        shutil.rmtree("backend")
+
+    # Combine results from all processes
     total_data_list = []
     extracted_triple_sizes = []
     global max_chars_in_train_answer
-    for data_point in tqdm(qa_pairs, desc="Building un-split dataset"):
-        if data_point["is_impossible"]:
-            continue
-        QA_pair = (data_point["question"], data_point["answer"])
-        q = QA_pair[0]
-        max_chars_in_train_answer = max(len(QA_pair[1]),
-                                        max_chars_in_train_answer)
-        # (TODO) make this batch queries for retrieving w/ CuVS+CuGraph
-        subgraph = query_loader.query(q)
-        subgraph.label = QA_pair[1]
-        total_data_list.append(subgraph)
-        extracted_triple_sizes.append(len(subgraph.triples))
+    max_chars_in_train_answer = 0
+
+    for chunk_data_list, chunk_triple_sizes, chunk_max_answer_len in results:
+        total_data_list.extend(chunk_data_list)
+        extracted_triple_sizes.extend(chunk_triple_sizes)
+        max_chars_in_train_answer = max(max_chars_in_train_answer,
+                                        chunk_max_answer_len)
+
     random.shuffle(total_data_list)
 
     # stats
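The per-GPU progress bars added in process_qa_chunk rely on tqdm's `position` argument, which offsets each worker's bar to its own terminal row so concurrent bars do not overwrite each other. A minimal standalone sketch of that pattern, with a `time.sleep` placeholder standing in for the real `query_loader.query(q)` call:

```python
# Standalone sketch of the per-GPU progress-bar pattern used in
# process_qa_chunk; the worker body is a placeholder rather than
# real subgraph retrieval.
import multiprocessing as mp
import time

from tqdm import tqdm


def worker(args):
    items, gpu_id = args
    # position=gpu_id offsets each bar vertically so bars drawn by
    # different processes land on separate terminal rows.
    pbar = tqdm(items, desc=f"GPU {gpu_id}", position=gpu_id, leave=True,
                total=len(items))
    for idx, _ in enumerate(pbar):
        pbar.set_postfix({"QA": idx + 1})  # progress within this chunk
        time.sleep(0.01)                   # stand-in for query_loader.query(q)
    pbar.close()
    return len(items)


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    chunks = [(list(range(50)), gpu_id) for gpu_id in range(2)]
    with mp.Pool(processes=len(chunks)) as pool:
        counts = list(pool.imap(worker, chunks))
    print(f"Done, per-chunk counts: {counts}")
```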
