 import argparse
 import gc
 import json
+import multiprocessing as mp
 import os
 import random
 import re
@@ -424,6 +425,115 @@ def update_data_lists(args, data_lists):
     return data_lists


+def process_qa_chunk(chunk_data):
+    """Process a chunk of QA pairs on a specific GPU."""
+    (qa_chunk, gpu_id, fs_path, gs_path, graph_data, triples, model_name,
+     dataset_path) = chunk_data
+
+    # Set the GPU for this process
+    torch.cuda.set_device(gpu_id)
+    device = torch.device(f"cuda:{gpu_id}")
+    print(f"Process {mp.current_process().pid} using GPU {gpu_id} "
+          f"to process {len(qa_chunk)} QA pairs")
+    sys.stdout.flush()  # Force flush to see logs immediately
+
+    # Create embedding model for this process
+    sent_trans_batch_size = 256
+    model = SentenceTransformer(
+        model_name=model_name).to(device).eval()
+
+    # Load graph and feature stores with unique path for each GPU
+    from torch_geometric.llm.utils.backend_utils import (
+        create_remote_backend_from_graph_data,
+    )
+    backend_path = f"backend_gpu_{gpu_id}"
+    fs, gs = create_remote_backend_from_graph_data(
+        graph_data=graph_data, path=backend_path,
+        graph_db=NeighborSamplingRAGGraphStore,
+        feature_db=KNNRAGFeatureStore).load()
+
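+    # GraphDB retrieval done with KNN+NeighborSampling+PCST
+    # (PCST = Prize Collecting Steiner Tree);
+    # VectorDB retrieval is vanilla vector RAG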
+    # Create subgraph filter for this process
+    subgraph_filter = make_pcst_filter(
+        triples,
+        model,
+        topk=5,  # nodes
+        topk_e=5,  # edges
+        cost_e=.5,  # edge cost
+        num_clusters=10)  # num clusters
+
+    # Load document retriever
+    model_kwargs = {
+        "output_device": device,
+        "batch_size": int(sent_trans_batch_size / 4),
+    }
+    doc_retriever_path = os.path.join(dataset_path, "document_retriever.pt")
+    vector_retriever = DocumentRetriever.load(doc_retriever_path,
+                                              model=model.encode,
+                                              model_kwargs=model_kwargs)
+
+    # Query loader config
+    # number of neighbors for each seed node selected by KNN
+    fanout = 100
+    # number of hops for neighbor sampling
+    num_hops = 2
+    query_loader_config = {
+        "k_nodes": 1024,  # k for Graph KNN
+        "num_neighbors": [fanout] * num_hops,  # number of sampled neighbors
+        "encoder_model": model,
+    }
+
+    # Create query loader for this process
+    query_loader = RAGQueryLoader(graph_data=(fs, gs),
+                                  subgraph_filter=subgraph_filter,
+                                  vector_retriever=vector_retriever,
+                                  config=query_loader_config)
+
+    # Process the chunk
+    chunk_data_list = []
+    chunk_triple_sizes = []
+    max_answer_len = 0
+
+    # Add tqdm progress bar for this GPU
+    from tqdm import tqdm
+
+    # Create a progress bar for this GPU with position offset to avoid overlap
+    gpu_pbar = tqdm(qa_chunk, desc=f"GPU {gpu_id}", position=gpu_id,
+                    leave=True, total=len(qa_chunk))
+
+    for idx, data_point in enumerate(gpu_pbar):
+        # Impossible pairs are already filtered by the caller;
+        # keep this check as a safeguard.
+        if data_point["is_impossible"]:
+            continue
+        q = data_point["question"]
+        a = data_point["answer"]
+        max_answer_len = max(len(a), max_answer_len)
+
+        # Update progress bar description with current item
+        gpu_pbar.set_postfix({'QA': idx + 1})
+
+        subgraph = query_loader.query(q)
+        subgraph.label = a
+        chunk_data_list.append(subgraph)
+        chunk_triple_sizes.append(len(subgraph.triples))
+    gpu_pbar.close()
+
+    # Log completion
+    print(
+        f"GPU {gpu_id}: Completed processing {len(chunk_data_list)} QA pairs!")
+    sys.stdout.flush()
+
+    # Clean up
+    del model
+    del query_loader
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Clean up GPU-specific backend directory
+    import shutil
+    if os.path.exists(backend_path):
+        shutil.rmtree(backend_path)
+
+    return chunk_data_list, chunk_triple_sizes, max_answer_len
+
+
 def make_dataset(args):
     qa_pairs, context_docs = get_data(args)
     print("Number of Docs in our VectorDB =", len(context_docs))
@@ -447,6 +557,13 @@ def make_dataset(args):
     triples = index_kg(args, context_docs)

     print("Number of triples in our GraphDB =", len(triples))
+
+    # Determine number of GPUs to use
+    num_gpus = (args.num_gpus
+                if args.num_gpus is not None else torch.cuda.device_count())
+    num_gpus = min(num_gpus, torch.cuda.device_count())
+    print(f"Using {num_gpus} GPUs for dataset generation")
+
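+    # NOTE: this parallel path assumes at least one visible CUDA device;
+    # with zero GPUs the chunking below would divide by zero.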
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     # creating the embedding model
@@ -496,52 +613,65 @@ def make_dataset(args):
                                               model_kwargs=model_kwargs)
     vector_retriever.save(doc_retriever_path)

-    subgraph_filter = make_pcst_filter(
-        triples,
-        model,
-        topk=5,  # nodes
-        topk_e=5,  # edges
-        cost_e=.5,  # edge cost
-        num_clusters=10)  # num clusters
-
-    # number of neighbors for each seed node selected by KNN
-    fanout = 100
-    # number of hops for neighborsampling
-    num_hops = 2
-
-    query_loader_config = {
-        "k_nodes": 1024,  # k for Graph KNN
-        "num_neighbors": [fanout] * num_hops,  # number of sampled neighbors
-        "encoder_model": model,
-    }
-
-    # GraphDB retrieval done with KNN+NeighborSampling+PCST
-    # PCST = Prize Collecting Steiner Tree
-    # VectorDB retrieval just vanilla vector RAG
+    # Distribute QA pairs across GPUs for parallel subgraph retrieval
+    # Each process loads its own embedding model, graph/feature stores,
+    # and query loader
     print("Now to retrieve context for each query from "
-          "our Vector and Graph DBs...")
-
-    query_loader = RAGQueryLoader(graph_data=(fs, gs),
-                                  subgraph_filter=subgraph_filter,
-                                  vector_retriever=vector_retriever,
-                                  config=query_loader_config)
-
-    # pre-process the dataset
+          "our Vector and Graph DBs using multiprocessing...")
+
+    # Filter out impossible QA pairs first
+    valid_qa_pairs = [dp for dp in qa_pairs if not dp["is_impossible"]]
+
+    # Split QA pairs into chunks for each GPU
+    # (the last chunk absorbs any remainder)
+    chunk_size = len(valid_qa_pairs) // num_gpus
+    qa_chunks = []
+    for i in range(num_gpus):
+        start_idx = i * chunk_size
+        end_idx = start_idx + chunk_size if i < num_gpus - 1 else len(
+            valid_qa_pairs)
+        qa_chunks.append(valid_qa_pairs[start_idx:end_idx])
+
+    # Prepare per-worker arguments for multiprocessing
+    chunk_args = []
+    for i, chunk in enumerate(qa_chunks):
+        chunk_args.append((
+            chunk,  # QA pairs chunk
+            i % torch.cuda.device_count(),  # GPU ID
+            "backend/fs",  # fs path
+            "backend/gs",  # gs path
+            graph_data,  # graph data
+            triples,  # triples
+            ENCODER_MODEL_NAME_DEFAULT,  # model name
+            args.dataset  # dataset path
+        ))
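+    # Only picklable objects cross the process boundary; each worker
+    # rebuilds its own model and graph/feature stores from these args.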
+
+    # Use multiprocessing to process chunks in parallel; 'spawn' is
+    # required so each worker initializes its own CUDA context.
+    mp.set_start_method('spawn', force=True)
+
+    print(f"\nProcessing {len(valid_qa_pairs)} QA pairs "
+          f"across {num_gpus} GPUs...")
+    print("=" * 60)
+
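+    # pool.imap yields results in input order, so they align with qa_chunks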
+    with mp.Pool(processes=num_gpus) as pool:
+        results = list(pool.imap(process_qa_chunk, chunk_args))
+
+    # Clean up main backend directory if it exists
+    import shutil
+    if os.path.exists("backend"):
+        shutil.rmtree("backend")
+
+    # Combine results from all processes
     total_data_list = []
     extracted_triple_sizes = []
     global max_chars_in_train_answer
-    for data_point in tqdm(qa_pairs, desc="Building un-split dataset"):
-        if data_point["is_impossible"]:
-            continue
-        QA_pair = (data_point["question"], data_point["answer"])
-        q = QA_pair[0]
-        max_chars_in_train_answer = max(len(QA_pair[1]),
-                                        max_chars_in_train_answer)
-        # (TODO) make this batch queries for retrieving w/ CuVS+CuGraph
-        subgraph = query_loader.query(q)
-        subgraph.label = QA_pair[1]
-        total_data_list.append(subgraph)
-        extracted_triple_sizes.append(len(subgraph.triples))
+    max_chars_in_train_answer = 0
+
+    for chunk_subgraphs, chunk_triple_sizes, chunk_max_answer_len in results:
+        total_data_list.extend(chunk_subgraphs)
+        extracted_triple_sizes.extend(chunk_triple_sizes)
+        max_chars_in_train_answer = max(max_chars_in_train_answer,
+                                        chunk_max_answer_len)
+
     random.shuffle(total_data_list)

     # stats