def _retrieve_chunk_subgraphs(qa_chunk, gpu_id, graph_data, triples,
                              model_name, dataset_path, backend_path):
    """Run retrieval for every answerable QA pair in ``qa_chunk``.

    All heavyweight objects (encoder, feature/graph stores, retriever, query
    loader) are locals of this helper, so they become unreachable as soon as
    it returns and the caller can actually release their GPU memory.

    Args:
        qa_chunk (list): QA dicts with ``"question"``, ``"answer"`` and
            ``"is_impossible"`` keys.
        gpu_id (int): Index of the CUDA device this worker runs on.
        graph_data: Graph data handed to
            ``create_remote_backend_from_graph_data``.
        triples: Knowledge-graph triples handed to ``make_pcst_filter``.
        model_name (str): Name of the sentence-transformer encoder to load.
        dataset_path (str): Directory containing ``document_retriever.pt``.
        backend_path (str): Per-GPU path used for the remote backend.

    Returns:
        tuple: ``(subgraphs, triple_sizes, max_answer_len)``.
    """
    from tqdm import tqdm

    from torch_geometric.utils.rag.backend_utils import (
        create_remote_backend_from_graph_data, )

    device = torch.device(f"cuda:{gpu_id}")

    # Each worker owns a private copy of the embedding model.
    sent_trans_batch_size = 256
    model = SentenceTransformer(model_name=model_name).to(device).eval()

    # A GPU-specific backend path keeps workers from clobbering each other.
    fs, gs = create_remote_backend_from_graph_data(
        graph_data=graph_data, path=backend_path,
        graph_db=NeighborSamplingRAGGraphStore,
        feature_db=KNNRAGFeatureStore).load()

    subgraph_filter = make_pcst_filter(
        triples,
        model,
        topk=5,  # nodes
        topk_e=5,  # edges
        cost_e=.5,  # edge cost
        num_clusters=10)  # num clusters

    model_kwargs = {
        "output_device": device,
        "batch_size": sent_trans_batch_size // 4,
    }
    doc_retriever_path = os.path.join(dataset_path, "document_retriever.pt")
    vector_retriever = DocumentRetriever.load(doc_retriever_path,
                                              model=model.encode,
                                              model_kwargs=model_kwargs)

    fanout = 100
    num_hops = 2
    query_loader_config = {
        "k_nodes": 1024,
        "num_neighbors": [fanout] * num_hops,
        "encoder_model": model,
    }
    query_loader = RAGQueryLoader(graph_data=(fs, gs),
                                  subgraph_filter=subgraph_filter,
                                  vector_retriever=vector_retriever,
                                  config=query_loader_config)

    subgraphs = []
    triple_sizes = []
    max_answer_len = 0

    # ``position=gpu_id`` gives every worker its own terminal row; the
    # context manager closes the bar even if a query raises.
    with tqdm(qa_chunk, desc=f"GPU {gpu_id}", position=gpu_id, leave=True,
              total=len(qa_chunk)) as progress:
        for idx, data_point in enumerate(progress, start=1):
            if data_point["is_impossible"]:
                continue
            question = data_point["question"]
            answer = data_point["answer"]
            max_answer_len = max(max_answer_len, len(answer))

            progress.set_postfix({'QA': idx})

            subgraph = query_loader.query(question)
            subgraph.label = answer
            subgraphs.append(subgraph)
            triple_sizes.append(len(subgraph.triples))

    return subgraphs, triple_sizes, max_answer_len


def process_qa_chunk(chunk_data):
    """Process a chunk of QA pairs on a specific GPU.

    Worker entry point for a ``multiprocessing`` pool. It pins the process to
    one GPU, builds a private retrieval pipeline, queries it once per
    answerable QA pair and returns the labelled subgraphs.

    Args:
        chunk_data (tuple): ``(qa_chunk, gpu_id, fs_path, gs_path,
            graph_data, triples, model_name, dataset_path)``. ``fs_path`` and
            ``gs_path`` are accepted for interface compatibility but unused.

    Returns:
        tuple: ``(chunk_data_list, chunk_triple_sizes, max_answer_len)`` where
        ``chunk_data_list`` holds one subgraph per answerable QA pair,
        ``chunk_triple_sizes`` the number of triples of each subgraph, and
        ``max_answer_len`` the longest answer length (in characters) seen.
    """
    import shutil

    (qa_chunk, gpu_id, _fs_path, _gs_path, graph_data, triples, model_name,
     dataset_path) = chunk_data

    # Set the GPU for this process
    torch.cuda.set_device(gpu_id)
    print(
        f"Process {mp.current_process().pid} using GPU {gpu_id} "
        f"to process {len(qa_chunk)} QA pairs", flush=True)

    backend_path = f"backend_gpu_{gpu_id}"
    try:
        # NOTE(review): the results are pickled back to the parent process;
        # if the subgraphs hold CUDA tensors it may be safer to move them to
        # CPU before returning -- confirm against RAGQueryLoader.query.
        chunk_data_list, chunk_triple_sizes, max_answer_len = (
            _retrieve_chunk_subgraphs(qa_chunk, gpu_id, graph_data, triples,
                                      model_name, dataset_path, backend_path))
        print(
            f"GPU {gpu_id}: Completed processing {len(chunk_data_list)} "
            f"QA pairs!", flush=True)
        return chunk_data_list, chunk_triple_sizes, max_answer_len
    finally:
        # Runs on success and on failure: the helper's locals (model, stores,
        # query loader) are already out of scope here, so collecting garbage
        # and emptying the CUDA cache really frees them.
        gc.collect()
        torch.cuda.empty_cache()
        # Best-effort removal of the GPU-specific backend directory.
        shutil.rmtree(backend_path, ignore_errors=True)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # creating the embedding model @@ -527,21 +643,61 @@ def make_dataset(args): config=query_loader_config) # pre-process the dataset + print("Now to retrieve context for each query from " + "our Vector and Graph DBs using multiprocessing...") + + # Filter out impossible QA pairs first + valid_qa_pairs = [dp for dp in qa_pairs if not dp["is_impossible"]] + + # Split QA pairs into chunks for each GPU + chunk_size = len(valid_qa_pairs) // num_gpus + qa_chunks = [] + for i in range(num_gpus): + start_idx = i * chunk_size + end_idx = start_idx + chunk_size if i < num_gpus - 1 else len( + valid_qa_pairs) + qa_chunks.append(valid_qa_pairs[start_idx:end_idx]) + + # Prepare data for multiprocessing + chunk_data_list = [] + for i, chunk in enumerate(qa_chunks): + chunk_data_list.append(( + chunk, # QA pairs chunk + i % torch.cuda.device_count(), # GPU ID + "backend/fs", # fs path + "backend/gs", # gs path + graph_data, # graph data + triples, # triples + ENCODER_MODEL_NAME_DEFAULT, # model name + args.dataset # dataset path + )) + + # Use multiprocessing to process chunks in parallel + mp.set_start_method('spawn', force=True) + + print( + f"\nProcessing {len(valid_qa_pairs)} QA pairs across {num_gpus} GPUs..." 
+ ) + print("=" * 60) + + with mp.Pool(processes=num_gpus) as pool: + results = list(pool.imap(process_qa_chunk, chunk_data_list)) + + # Clean up main backend directory if it exists + import shutil + if os.path.exists("backend"): + shutil.rmtree("backend") + + # Combine results from all processes total_data_list = [] extracted_triple_sizes = [] global max_chars_in_train_answer - for data_point in tqdm(qa_pairs, desc="Building un-split dataset"): - if data_point["is_impossible"]: - continue - QA_pair = (data_point["question"], data_point["answer"]) - q = QA_pair[0] - max_chars_in_train_answer = max(len(QA_pair[1]), - max_chars_in_train_answer) - # (TODO) make this batch queries for retrieving w/ CuVS+CuGraph - subgraph = query_loader.query(q) - subgraph.label = QA_pair[1] - total_data_list.append(subgraph) - extracted_triple_sizes.append(len(subgraph.triples)) + + for chunk_data, chunk_sizes, max_ans_len in results: + total_data_list.extend(chunk_data) + extracted_triple_sizes.extend(chunk_sizes) + max_chars_in_train_answer = max(max_chars_in_train_answer, max_ans_len) + random.shuffle(total_data_list) # stats