Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807))
- Automatic num_params in LLM + update `GRetriever` default llm ([#9938](https://github.com/pyg-team/pytorch_geometric/pull/9938))
- Updated calls to NumPy's deprecated `np.in1d` to `np.isin` ([#10283](https://github.com/pyg-team/pytorch_geometric/pull/10283))
- Added multi-GPU parallelization using multiprocessing for accelerated RAG dataset generation in `txt2kg_rag.py`, distributing QA pair processing across GPUs ([#10474](https://github.com/pyg-team/pytorch_geometric/pull/10474))

### Deprecated

Expand Down
211 changes: 169 additions & 42 deletions examples/llm/txt2kg_rag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import gc
import json
import multiprocessing as mp
import os
import random
import re
Expand Down Expand Up @@ -424,6 +425,112 @@ def update_data_lists(args, data_lists):
return data_lists


def process_qa_chunk(chunk_data):
"""Process a chunk of QA pairs on a specific GPU."""
(qa_chunk, gpu_id, fs_path, gs_path, graph_data, triples, model_name,
dataset_path) = chunk_data

# Set the GPU for this process
torch.cuda.set_device(gpu_id)
device = torch.device(f"cuda:{gpu_id}")
print(f"Process {mp.current_process().pid} using GPU {gpu_id} "
f"to process {len(qa_chunk)} QA pairs")
sys.stdout.flush() # Force flush to see logs immediately

# Create embedding model for this process
sent_trans_batch_size = 256
model = SentenceTransformer(
model_name=ENCODER_MODEL_NAME_DEFAULT).to(device).eval()

# Load graph and feature stores with unique path for each GPU
backend_path = f"backend_gpu_{gpu_id}"
fs, gs = create_remote_backend_from_graph_data(
graph_data=graph_data, path=backend_path,
graph_db=NeighborSamplingRAGGraphStore,
feature_db=KNNRAGFeatureStore).load()

# Create subgraph filter for this process
subgraph_filter = make_pcst_filter(
triples,
model,
topk=5, # nodes
topk_e=5, # edges
cost_e=.5, # edge cost
num_clusters=10) # num clusters

# Load document retriever
model_kwargs = {
"output_device": device,
"batch_size": int(sent_trans_batch_size / 4),
}
doc_retriever_path = os.path.join(dataset_path, "document_retriever.pt")
vector_retriever = DocumentRetriever.load(doc_retriever_path,
model=model.encode,
model_kwargs=model_kwargs)

# Query loader config
fanout = 100
num_hops = 2
query_loader_config = {
"k_nodes": 1024,
"num_neighbors": [fanout] * num_hops,
"encoder_model": model,
}

# Create query loader for this process
query_loader = RAGQueryLoader(graph_data=(fs, gs),
subgraph_filter=subgraph_filter,
vector_retriever=vector_retriever,
config=query_loader_config)

# Process the chunk
chunk_data_list = []
chunk_triple_sizes = []
max_answer_len = 0

# Add tqdm progress bar for this GPU
from tqdm import tqdm

# Create a progress bar for this GPU with position offset to avoid overlap
gpu_pbar = tqdm(qa_chunk, desc=f"GPU {gpu_id}", position=gpu_id,
leave=True, total=len(qa_chunk))

for idx, data_point in enumerate(gpu_pbar):
if data_point["is_impossible"]:
continue
q = data_point["question"]
a = data_point["answer"]
max_answer_len = max(len(a), max_answer_len)

# Update progress bar description with current item
gpu_pbar.set_postfix({'QA': idx + 1})

subgraph = query_loader.query(q)
subgraph.label = a
chunk_data_list.append(subgraph)
chunk_triple_sizes.append(len(subgraph.triples))

gpu_pbar.close()

# Log completion
print(
f"GPU {gpu_id}: Completed processing {len(chunk_data_list)} QA pairs!")
sys.stdout.flush()

# Clean up
del model
del query_loader
gc.collect()
torch.cuda.empty_cache()

# Clean up GPU-specific backend directory
import shutil
if os.path.exists(backend_path):
shutil.rmtree(backend_path)

return chunk_data_list, chunk_triple_sizes, max_answer_len


def make_dataset(args):
qa_pairs, context_docs = get_data(args)
print("Number of Docs in our VectorDB =", len(context_docs))
Expand All @@ -447,6 +554,13 @@ def make_dataset(args):
triples = index_kg(args, context_docs)

print("Number of triples in our GraphDB =", len(triples))

# Determine number of GPUs to use
num_gpus = (args.num_gpus
if args.num_gpus is not None else torch.cuda.device_count())
num_gpus = min(num_gpus, torch.cuda.device_count())
print(f"Using {num_gpus} GPUs for dataset generation")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# creating the embedding model
Expand Down Expand Up @@ -496,52 +610,65 @@ def make_dataset(args):
model_kwargs=model_kwargs)
vector_retriever.save(doc_retriever_path)

subgraph_filter = make_pcst_filter(
triples,
model,
topk=5, # nodes
topk_e=5, # edges
cost_e=.5, # edge cost
num_clusters=10) # num clusters

# number of neighbors for each seed node selected by KNN
fanout = 100
# number of hops for neighborsampling
num_hops = 2

query_loader_config = {
"k_nodes": 1024, # k for Graph KNN
"num_neighbors": [fanout] * num_hops, # number of sampled neighbors
"encoder_model": model,
}

# GraphDB retrieval done with KNN+NeighborSampling+PCST
# PCST = Prize Collecting Steiner Tree
# VectorDB retrieval just vanilla vector RAG
# Distribute QA pairs across GPUs for parallel subgraph retrieval
# Each process loads its own embedding model, graph/feature stores,
# and query loader
print("Now to retrieve context for each query from "
"our Vector and Graph DBs...")

query_loader = RAGQueryLoader(graph_data=(fs, gs),
subgraph_filter=subgraph_filter,
vector_retriever=vector_retriever,
config=query_loader_config)

# pre-process the dataset
"our Vector and Graph DBs using multiprocessing...")

# Filter out impossible QA pairs first
valid_qa_pairs = [dp for dp in qa_pairs if not dp["is_impossible"]]

# Split QA pairs into chunks for each GPU
chunk_size = len(valid_qa_pairs) // num_gpus
qa_chunks = []
for i in range(num_gpus):
start_idx = i * chunk_size
end_idx = start_idx + chunk_size if i < num_gpus - 1 else len(
valid_qa_pairs)
qa_chunks.append(valid_qa_pairs[start_idx:end_idx])

# Prepare data for multiprocessing
chunk_data_list = []
for i, chunk in enumerate(qa_chunks):
chunk_data_list.append((
chunk, # QA pairs chunk
i % torch.cuda.device_count(), # GPU ID
"backend/fs", # fs path
"backend/gs", # gs path
graph_data, # graph data
triples, # triples
ENCODER_MODEL_NAME_DEFAULT, # model name
args.dataset # dataset path
))

# Use multiprocessing to process chunks in parallel
mp.set_start_method('spawn', force=True)

print(f"\nProcessing {len(valid_qa_pairs)} QA pairs "
f"across {num_gpus} GPUs...")
print("=" * 60)

with mp.Pool(processes=num_gpus) as pool:
results = list(pool.imap(process_qa_chunk, chunk_data_list))

# Clean up main backend directory if it exists
import shutil
if os.path.exists("backend"):
shutil.rmtree("backend")

# Combine results from all processes
total_data_list = []
extracted_triple_sizes = []
global max_chars_in_train_answer
for data_point in tqdm(qa_pairs, desc="Building un-split dataset"):
if data_point["is_impossible"]:
continue
QA_pair = (data_point["question"], data_point["answer"])
q = QA_pair[0]
max_chars_in_train_answer = max(len(QA_pair[1]),
max_chars_in_train_answer)
# (TODO) make this batch queries for retrieving w/ CuVS+CuGraph
subgraph = query_loader.query(q)
subgraph.label = QA_pair[1]
total_data_list.append(subgraph)
extracted_triple_sizes.append(len(subgraph.triples))
max_chars_in_train_answer = 0

for chunk_data_list, chunk_triple_sizes, chunk_max_answer_len in results:
total_data_list.extend(chunk_data_list)
extracted_triple_sizes.extend(chunk_triple_sizes)
max_chars_in_train_answer = max(max_chars_in_train_answer,
chunk_max_answer_len)

random.shuffle(total_data_list)

# stats
Expand Down
Loading