Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 168 additions & 12 deletions examples/llm/txt2kg_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from glob import glob
from itertools import chain
from pathlib import Path
import multiprocessing as mp

import yaml

Expand Down Expand Up @@ -424,6 +425,114 @@ def update_data_lists(args, data_lists):
return data_lists


def process_qa_chunk(chunk_data):
    """Process one chunk of QA pairs on a dedicated GPU.

    Worker function intended to run under ``multiprocessing.Pool`` with the
    ``spawn`` start method.  Each worker pins itself to one GPU, builds its
    own encoder, graph/feature-store backend and query loader, retrieves a
    subgraph for every answerable QA pair in its chunk, and cleans up its
    GPU memory and on-disk backend before returning.

    Args:
        chunk_data (tuple): Packed worker arguments
            ``(qa_chunk, gpu_id, fs_path, gs_path, graph_data, triples,
            model_name, dataset_path)``.  ``fs_path`` and ``gs_path`` are
            currently unused — each worker creates its own backend under
            ``backend_gpu_<gpu_id>`` instead.

    Returns:
        tuple: ``(chunk_data_list, chunk_triple_sizes, max_answer_len)`` —
        the retrieved subgraphs (each with ``label`` set to its answer),
        the triple count of each subgraph, and the longest answer length
        (in characters) observed in this chunk.
    """
    (qa_chunk, gpu_id, fs_path, gs_path, graph_data, triples, model_name,
     dataset_path) = chunk_data

    # Pin this worker process to its assigned GPU.
    torch.cuda.set_device(gpu_id)
    device = torch.device(f"cuda:{gpu_id}")
    print(
        f"Process {mp.current_process().pid} using GPU {gpu_id} to process {len(qa_chunk)} QA pairs"
    )
    sys.stdout.flush()  # Force flush so per-worker logs appear immediately

    # Create the embedding model for this process.
    # Bug fix: honor the `model_name` passed in `chunk_data` instead of
    # always falling back to ENCODER_MODEL_NAME_DEFAULT (the caller
    # currently passes the default, so behavior is unchanged there).
    sent_trans_batch_size = 256
    model = SentenceTransformer(model_name=model_name).to(device).eval()

    # Load graph and feature stores with a unique path per GPU so that
    # concurrent workers do not clobber each other's on-disk backends.
    from torch_geometric.utils.rag.backend_utils import (
        create_remote_backend_from_graph_data, )
    backend_path = f"backend_gpu_{gpu_id}"
    fs, gs = create_remote_backend_from_graph_data(
        graph_data=graph_data, path=backend_path,
        graph_db=NeighborSamplingRAGGraphStore,
        feature_db=KNNRAGFeatureStore).load()

    # Create the PCST subgraph filter for this process.
    subgraph_filter = make_pcst_filter(
        triples,
        model,
        topk=5,  # nodes
        topk_e=5,  # edges
        cost_e=.5,  # edge cost
        num_clusters=10)  # num clusters

    # Load the document retriever saved by the main process.
    model_kwargs = {
        "output_device": device,
        "batch_size": int(sent_trans_batch_size / 4),
    }
    doc_retriever_path = os.path.join(dataset_path, "document_retriever.pt")
    vector_retriever = DocumentRetriever.load(doc_retriever_path,
                                              model=model.encode,
                                              model_kwargs=model_kwargs)

    # Query loader config
    fanout = 100
    num_hops = 2
    query_loader_config = {
        "k_nodes": 1024,
        "num_neighbors": [fanout] * num_hops,
        "encoder_model": model,
    }

    # Create the query loader for this process.
    query_loader = RAGQueryLoader(graph_data=(fs, gs),
                                  subgraph_filter=subgraph_filter,
                                  vector_retriever=vector_retriever,
                                  config=query_loader_config)

    # Process the chunk.
    chunk_data_list = []
    chunk_triple_sizes = []
    max_answer_len = 0

    from tqdm import tqdm

    # `position=gpu_id` keeps each worker's progress bar on its own
    # terminal row so concurrent bars do not overwrite one another.
    gpu_pbar = tqdm(qa_chunk, desc=f"GPU {gpu_id}", position=gpu_id,
                    leave=True, total=len(qa_chunk))

    for idx, data_point in enumerate(gpu_pbar):
        # Defensive: the caller already filters out impossible pairs, but
        # skip them here too in case raw QA pairs are passed directly.
        if data_point["is_impossible"]:
            continue
        q = data_point["question"]
        a = data_point["answer"]
        max_answer_len = max(len(a), max_answer_len)

        # Update progress bar with the current item index.
        gpu_pbar.set_postfix({'QA': idx + 1})

        subgraph = query_loader.query(q)
        subgraph.label = a
        chunk_data_list.append(subgraph)
        chunk_triple_sizes.append(len(subgraph.triples))

    gpu_pbar.close()

    # Log completion
    print(
        f"GPU {gpu_id}: Completed processing {len(chunk_data_list)} QA pairs!")
    sys.stdout.flush()

    # Release GPU memory held by this worker before it exits.
    del model
    del query_loader
    gc.collect()
    torch.cuda.empty_cache()

    # Remove the per-GPU backend directory created above.
    import shutil
    if os.path.exists(backend_path):
        shutil.rmtree(backend_path)

    return chunk_data_list, chunk_triple_sizes, max_answer_len


def make_dataset(args):
qa_pairs, context_docs = get_data(args)
print("Number of Docs in our VectorDB =", len(context_docs))
Expand All @@ -447,6 +556,13 @@ def make_dataset(args):
triples = index_kg(args, context_docs)

print("Number of triples in our GraphDB =", len(triples))

# Determine number of GPUs to use
num_gpus = args.num_gpus if args.num_gpus is not None else torch.cuda.device_count(
)
num_gpus = min(num_gpus, torch.cuda.device_count())
print(f"Using {num_gpus} GPUs for dataset generation")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# creating the embedding model
Expand Down Expand Up @@ -527,21 +643,61 @@ def make_dataset(args):
config=query_loader_config)

# pre-process the dataset
print("Now to retrieve context for each query from "
"our Vector and Graph DBs using multiprocessing...")

# Filter out impossible QA pairs first
valid_qa_pairs = [dp for dp in qa_pairs if not dp["is_impossible"]]

# Split QA pairs into chunks for each GPU
chunk_size = len(valid_qa_pairs) // num_gpus
qa_chunks = []
for i in range(num_gpus):
start_idx = i * chunk_size
end_idx = start_idx + chunk_size if i < num_gpus - 1 else len(
valid_qa_pairs)
qa_chunks.append(valid_qa_pairs[start_idx:end_idx])

# Prepare data for multiprocessing
chunk_data_list = []
for i, chunk in enumerate(qa_chunks):
chunk_data_list.append((
chunk, # QA pairs chunk
i % torch.cuda.device_count(), # GPU ID
"backend/fs", # fs path
"backend/gs", # gs path
graph_data, # graph data
triples, # triples
ENCODER_MODEL_NAME_DEFAULT, # model name
args.dataset # dataset path
))

# Use multiprocessing to process chunks in parallel
mp.set_start_method('spawn', force=True)

print(
f"\nProcessing {len(valid_qa_pairs)} QA pairs across {num_gpus} GPUs..."
)
print("=" * 60)

with mp.Pool(processes=num_gpus) as pool:
results = list(pool.imap(process_qa_chunk, chunk_data_list))

# Clean up main backend directory if it exists
import shutil
if os.path.exists("backend"):
shutil.rmtree("backend")

# Combine results from all processes
total_data_list = []
extracted_triple_sizes = []
global max_chars_in_train_answer
for data_point in tqdm(qa_pairs, desc="Building un-split dataset"):
if data_point["is_impossible"]:
continue
QA_pair = (data_point["question"], data_point["answer"])
q = QA_pair[0]
max_chars_in_train_answer = max(len(QA_pair[1]),
max_chars_in_train_answer)
# (TODO) make this batch queries for retrieving w/ CuVS+CuGraph
subgraph = query_loader.query(q)
subgraph.label = QA_pair[1]
total_data_list.append(subgraph)
extracted_triple_sizes.append(len(subgraph.triples))

for chunk_data, chunk_sizes, max_ans_len in results:
total_data_list.extend(chunk_data)
extracted_triple_sizes.extend(chunk_sizes)
max_chars_in_train_answer = max(max_chars_in_train_answer, max_ans_len)

random.shuffle(total_data_list)

# stats
Expand Down