11import argparse
22import gc
33import json
4+ import multiprocessing as mp
45import os
56import random
67import re
@@ -424,6 +425,112 @@ def update_data_lists(args, data_lists):
424425 return data_lists
425426
426427
def _build_query_loader(graph_data, triples, model_name, dataset_path, device,
                        backend_path):
    """Build the retrieval pipeline used by one worker process.

    Creates the sentence encoder on ``device``, loads the graph/feature
    stores from ``graph_data`` under ``backend_path``, builds the PCST
    subgraph filter, loads the saved document retriever from
    ``<dataset_path>/document_retriever.pt`` and wires everything into a
    ``RAGQueryLoader``.

    Keeping construction in its own function means the intermediate objects
    (encoder, stores, filter, retriever) are only reachable through the
    returned loader, so dropping the loader is enough to release them.
    """
    sent_trans_batch_size = 256
    # Fall back to the module default if the caller did not supply a name.
    model = SentenceTransformer(
        model_name=model_name or ENCODER_MODEL_NAME_DEFAULT).to(device).eval()

    # Load graph and feature stores with a unique path for each GPU
    fs, gs = create_remote_backend_from_graph_data(
        graph_data=graph_data, path=backend_path,
        graph_db=NeighborSamplingRAGGraphStore,
        feature_db=KNNRAGFeatureStore).load()

    subgraph_filter = make_pcst_filter(
        triples,
        model,
        topk=5,  # nodes
        topk_e=5,  # edges
        cost_e=.5,  # edge cost
        num_clusters=10)  # num clusters

    model_kwargs = {
        "output_device": device,
        "batch_size": sent_trans_batch_size // 4,
    }
    doc_retriever_path = os.path.join(dataset_path, "document_retriever.pt")
    vector_retriever = DocumentRetriever.load(doc_retriever_path,
                                              model=model.encode,
                                              model_kwargs=model_kwargs)

    fanout = 100  # sampled neighbors per hop
    num_hops = 2
    query_loader_config = {
        "k_nodes": 1024,
        "num_neighbors": [fanout] * num_hops,
        "encoder_model": model,
    }

    return RAGQueryLoader(graph_data=(fs, gs),
                          subgraph_filter=subgraph_filter,
                          vector_retriever=vector_retriever,
                          config=query_loader_config)


def process_qa_chunk(chunk_data):
    """Process a chunk of QA pairs on a specific GPU.

    Intended to be called in a worker process (e.g. via ``mp.Pool``).

    Args:
        chunk_data: An 8-tuple
            ``(qa_chunk, gpu_id, fs_path, gs_path, graph_data, triples,
            model_name, dataset_path)`` where ``qa_chunk`` is a sequence of
            dicts with ``"question"``, ``"answer"`` and ``"is_impossible"``
            keys and ``gpu_id`` is the CUDA device index to use.
            ``fs_path`` and ``gs_path`` are accepted for interface
            compatibility but are not used; the stores are rebuilt from
            ``graph_data`` under a per-GPU directory.

    Returns:
        A tuple ``(chunk_data_list, chunk_triple_sizes, max_answer_len)``:
        the retrieved subgraphs (each with ``.label`` set to the answer),
        the number of triples in each subgraph, and the length of the
        longest answer seen among the processed (non-impossible) pairs.
    """
    # Local imports keep the worker self-contained under the 'spawn' start
    # method and avoid relying on what the parent module imported.
    import shutil

    from tqdm import tqdm

    (qa_chunk, gpu_id, _fs_path, _gs_path, graph_data, triples, model_name,
     dataset_path) = chunk_data

    # Set the GPU for this process
    torch.cuda.set_device(gpu_id)
    device = torch.device(f"cuda:{gpu_id}")
    # Flush so logs from concurrent workers show up immediately.
    print(
        f"Process {mp.current_process().pid} using GPU {gpu_id} "
        f"to process {len(qa_chunk)} QA pairs", flush=True)

    # Unique backend directory per GPU so workers do not clobber each other.
    backend_path = f"backend_gpu_{gpu_id}"

    chunk_data_list = []
    chunk_triple_sizes = []
    max_answer_len = 0
    query_loader = None

    try:
        query_loader = _build_query_loader(graph_data, triples, model_name,
                                           dataset_path, device, backend_path)

        # One progress bar per GPU; `position` offsets the bars so that
        # concurrent workers do not overwrite each other's output.
        with tqdm(qa_chunk, desc=f"GPU {gpu_id} ", position=gpu_id,
                  leave=True, total=len(qa_chunk)) as gpu_pbar:
            for idx, data_point in enumerate(gpu_pbar):
                if data_point["is_impossible"]:
                    continue
                q = data_point["question"]
                a = data_point["answer"]
                max_answer_len = max(len(a), max_answer_len)

                # Show the index of the item currently being processed
                gpu_pbar.set_postfix({'QA': idx + 1})

                subgraph = query_loader.query(q)
                subgraph.label = a
                chunk_data_list.append(subgraph)
                chunk_triple_sizes.append(len(subgraph.triples))

        print(
            f"GPU {gpu_id} : Completed processing {len(chunk_data_list)} "
            f"QA pairs!", flush=True)
    finally:
        # Always release GPU memory and remove the per-GPU backend
        # directory, even when retrieval fails part-way through.
        query_loader = None
        gc.collect()
        torch.cuda.empty_cache()
        shutil.rmtree(backend_path, ignore_errors=True)

    return chunk_data_list, chunk_triple_sizes, max_answer_len
532+
533+
427534def make_dataset (args ):
428535 qa_pairs , context_docs = get_data (args )
429536 print ("Number of Docs in our VectorDB =" , len (context_docs ))
@@ -447,6 +554,13 @@ def make_dataset(args):
447554 triples = index_kg (args , context_docs )
448555
449556 print ("Number of triples in our GraphDB =" , len (triples ))
557+
558+ # Determine number of GPUs to use
559+ num_gpus = (args .num_gpus
560+ if args .num_gpus is not None else torch .cuda .device_count ())
561+ num_gpus = min (num_gpus , torch .cuda .device_count ())
562+ print (f"Using { num_gpus } GPUs for dataset generation" )
563+
450564 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
451565
452566 # creating the embedding model
@@ -496,52 +610,65 @@ def make_dataset(args):
496610 model_kwargs = model_kwargs )
497611 vector_retriever .save (doc_retriever_path )
498612
499- subgraph_filter = make_pcst_filter (
500- triples ,
501- model ,
502- topk = 5 , # nodes
503- topk_e = 5 , # edges
504- cost_e = .5 , # edge cost
505- num_clusters = 10 ) # num clusters
506-
507- # number of neighbors for each seed node selected by KNN
508- fanout = 100
509- # number of hops for neighborsampling
510- num_hops = 2
511-
512- query_loader_config = {
513- "k_nodes" : 1024 , # k for Graph KNN
514- "num_neighbors" : [fanout ] * num_hops , # number of sampled neighbors
515- "encoder_model" : model ,
516- }
517-
518- # GraphDB retrieval done with KNN+NeighborSampling+PCST
519- # PCST = Prize Collecting Steiner Tree
520- # VectorDB retrieval just vanilla vector RAG
613+ # Distribute QA pairs across GPUs for parallel subgraph retrieval
614+ # Each process loads its own embedding model, graph/feature stores,
615+ # and query loader
521616 print ("Now to retrieve context for each query from "
522- "our Vector and Graph DBs..." )
523-
524- query_loader = RAGQueryLoader (graph_data = (fs , gs ),
525- subgraph_filter = subgraph_filter ,
526- vector_retriever = vector_retriever ,
527- config = query_loader_config )
528-
529- # pre-process the dataset
617+ "our Vector and Graph DBs using multiprocessing..." )
618+
619+ # Filter out impossible QA pairs first
620+ valid_qa_pairs = [dp for dp in qa_pairs if not dp ["is_impossible" ]]
621+
622+ # Split QA pairs into chunks for each GPU
623+ chunk_size = len (valid_qa_pairs ) // num_gpus
624+ qa_chunks = []
625+ for i in range (num_gpus ):
626+ start_idx = i * chunk_size
627+ end_idx = start_idx + chunk_size if i < num_gpus - 1 else len (
628+ valid_qa_pairs )
629+ qa_chunks .append (valid_qa_pairs [start_idx :end_idx ])
630+
631+ # Prepare data for multiprocessing
632+ chunk_data_list = []
633+ for i , chunk in enumerate (qa_chunks ):
634+ chunk_data_list .append ((
635+ chunk , # QA pairs chunk
636+ i % torch .cuda .device_count (), # GPU ID
637+ "backend/fs" , # fs path
638+ "backend/gs" , # gs path
639+ graph_data , # graph data
640+ triples , # triples
641+ ENCODER_MODEL_NAME_DEFAULT , # model name
642+ args .dataset # dataset path
643+ ))
644+
645+ # Use multiprocessing to process chunks in parallel
646+ mp .set_start_method ('spawn' , force = True )
647+
648+ print (f"\n Processing { len (valid_qa_pairs )} QA pairs "
649+ f"across { num_gpus } GPUs..." )
650+ print ("=" * 60 )
651+
652+ with mp .Pool (processes = num_gpus ) as pool :
653+ results = list (pool .imap (process_qa_chunk , chunk_data_list ))
654+
655+ # Clean up main backend directory if it exists
656+ import shutil
657+ if os .path .exists ("backend" ):
658+ shutil .rmtree ("backend" )
659+
660+ # Combine results from all processes
530661 total_data_list = []
531662 extracted_triple_sizes = []
532663 global max_chars_in_train_answer
533- for data_point in tqdm (qa_pairs , desc = "Building un-split dataset" ):
534- if data_point ["is_impossible" ]:
535- continue
536- QA_pair = (data_point ["question" ], data_point ["answer" ])
537- q = QA_pair [0 ]
538- max_chars_in_train_answer = max (len (QA_pair [1 ]),
539- max_chars_in_train_answer )
540- # (TODO) make this batch queries for retrieving w/ CuVS+CuGraph
541- subgraph = query_loader .query (q )
542- subgraph .label = QA_pair [1 ]
543- total_data_list .append (subgraph )
544- extracted_triple_sizes .append (len (subgraph .triples ))
664+ max_chars_in_train_answer = 0
665+
666+ for chunk_data_list , chunk_triple_sizes , chunk_max_answer_len in results :
667+ total_data_list .extend (chunk_data_list )
668+ extracted_triple_sizes .extend (chunk_triple_sizes )
669+ max_chars_in_train_answer = max (max_chars_in_train_answer ,
670+ chunk_max_answer_len )
671+
545672 random .shuffle (total_data_list )
546673
547674 # stats
0 commit comments