44against multiple ONNX Runtime GenAI models with different execution providers. 
55
66Usage: 
python compute_kl_divergence.py --hf_model "F:\shared\Llama-3.1-8B-Instruct"
88    --ep cuda --path "G:\models\cuda_model" 
99    --ep directml --path "G:\models\directml_model" 
1010    --output "comparison_results.json" 
3636
3737
def debug_print(message):
    """
    Print a debug message only if the module-level DEBUG flag is enabled.

    Args:
        message (str): Debug message to print.
    """
    # NOTE(review): the diff render left both the old and new docstrings in place;
    # only the new one is kept here. DEBUG is assumed to be a module-level bool
    # defined earlier in the file (not visible in this chunk) — confirm.
    if DEBUG:
        print(f"[DEBUG] {message}")
4348
4449def  run_command (cmd , description = "" , capture_output = True ):
45-     """Run a command and handle errors""" 
50+     """ 
51+     Execute a subprocess command with error handling. 
52+ 
53+     Args: 
54+         cmd (list[str]): Command and arguments to execute. 
55+         description (str, optional): Description of the command for logging. Defaults to "". 
56+         capture_output (bool, optional): Whether to capture stdout/stderr or show in real-time. 
57+                                         Defaults to True. 
58+ 
59+     Returns: 
60+         bool: True if command succeeded, False otherwise. 
61+     """ 
4662    debug_print (f"[INFO] { description }  " )
4763    debug_print (f"Running: { ' ' .join (cmd )}  " )
4864
4965    try :
5066        if  capture_output :
51-             result  =  subprocess .run (cmd , check = True , capture_output = True , text = True , shell = True )
67+             result  =  subprocess .run (cmd , check = True , capture_output = True , text = True , shell = False )
5268            if  result .stdout  and  DEBUG :
5369                print (f"[OUT] { result .stdout }  " )
5470        else :
5571            # Real-time output - shows prints as they happen 
56-             result  =  subprocess .run (cmd , check = True , shell = True )
72+             result  =  subprocess .run (cmd , check = True , shell = False )
5773        return  True 
5874    except  subprocess .CalledProcessError  as  e :
5975        print (f"[ERROR] Command failed: { e }  " )
@@ -70,7 +86,12 @@ def get_python_executable():
7086
7187
7288def  uninstall_onnxruntime_packages ():
73-     """Uninstall all ONNX Runtime packages""" 
89+     """ 
90+     Uninstall all ONNX Runtime and ONNX Runtime GenAI packages. 
91+ 
92+     This ensures a clean environment before installing provider-specific packages 
93+     to avoid version conflicts. 
94+     """ 
7495    packages_to_remove  =  [
7596        "onnxruntime" ,
7697        "onnxruntime-genai" ,
@@ -88,7 +109,15 @@ def uninstall_onnxruntime_packages():
88109
89110
90111def  install_package (package_name ):
91-     """Install a specific package""" 
112+     """ 
113+     Install a specific Python package using pip. 
114+ 
115+     Args: 
116+         package_name (str): Name of the package to install. 
117+ 
118+     Returns: 
119+         bool: True if installation succeeded, False otherwise. 
120+     """ 
92121    debug_print (f"Installing package: { package_name }  " )
93122    python_exe  =  get_python_executable ()
94123    debug_print (f"Python executable: { python_exe }  " )
@@ -101,19 +130,33 @@ def install_package(package_name):
101130
102131
103132def  extract_hf_logits_subprocess (model_path , device = "cuda" ):
104-     """Extract logits from Hugging Face model using subprocess""" 
133+     """ 
134+     Extract logits from a Hugging Face transformer model using a subprocess. 
135+ 
136+     Runs extract_logits_hf.py in a separate process to avoid package conflicts. 
137+     Uses temporary file for data transfer between processes. 
138+ 
139+     Args: 
140+         model_path (str): Path to the Hugging Face model directory. 
141+         device (str, optional): Device for inference ('cuda' or 'cpu'). Defaults to "cuda". 
142+ 
143+     """ 
105144    print ("[INFO] Extracting logits from Hugging Face baseline model..." )
106145    debug_print (f"Model path: { model_path }  , Device: { device }  " )
107146
108147    # Create temporary output file 
109-     output_file  =  f"temp_logits_hf_{ int (time .time ())}  .pkl" 
148+     import  tempfile 
149+ 
150+     script_dir  =  os .path .dirname (os .path .abspath (__file__ ))
151+     with  tempfile .NamedTemporaryFile (prefix = "temp_logits_hf_" , suffix = ".pkl" , delete = False ) as  tmp :
152+         output_file  =  tmp .name 
110153    debug_print (f"Temporary output file: { output_file }  " )
111154
112155    try :
113156        python_exe  =  get_python_executable ()
114157        cmd  =  [
115158            python_exe ,
116-             "extract_logits_hf.py" ,
159+             os . path . join ( script_dir ,  "extract_logits_hf.py" ) ,
117160            "--model_path" ,
118161            model_path ,
119162            "--output_file" ,
@@ -158,19 +201,33 @@ def extract_hf_logits_subprocess(model_path, device="cuda"):
158201
159202
160203def  extract_onnx_logits_subprocess (model_path , provider ):
161-     """Extract logits from ONNX Runtime GenAI model using subprocess""" 
204+     """ 
205+     Extract logits from an ONNX Runtime GenAI model using a subprocess. 
206+ 
207+     Runs extract_logits.py in a separate process with the appropriate ONNX Runtime 
208+     package for the specified execution provider. Uses temporary file for data transfer. 
209+ 
210+     Args: 
211+         model_path (str): Path to the ONNX Runtime GenAI model directory. 
212+         provider (str): Execution provider ('cuda', 'directml', or 'cpu'). 
213+ 
214+     """ 
162215    print (f"[INFO] Extracting logits from { provider .upper ()}   model..." )
163216    debug_print (f"Model path: { model_path }  , Provider: { provider }  " )
164217
165218    # Create temporary output file 
166-     output_file  =  f"temp_logits_{ provider }  _{ int (time .time ())}  .pkl" 
219+     import  tempfile 
220+ 
221+     script_dir  =  os .path .dirname (os .path .abspath (__file__ ))
222+     with  tempfile .NamedTemporaryFile (prefix = "temp_logits_" , suffix = ".pkl" , delete = False ) as  tmp :
223+         output_file  =  tmp .name 
167224    debug_print (f"Temporary output file: { output_file }  " )
168225
169226    try :
170227        python_exe  =  get_python_executable ()
171228        cmd  =  [
172229            python_exe ,
173-             "extract_logits.py" ,
230+             os . path . join ( script_dir ,  "extract_logits.py" ) ,
174231            "--model_path" ,
175232            model_path ,
176233            "--output_file" ,
@@ -220,8 +277,20 @@ def extract_onnx_logits_subprocess(model_path, provider):
220277
221278def  compute_kl_divergence_from_logits (log_probs_ref , log_probs_tar ):
222279    """ 
223-     Compute KL divergence between two log probability distributions. 
224-     Same logic as in compute_kl_divergence.py 
280+     Compute Kullback-Leibler divergence between two log probability distributions. 
281+ 
282+     KL divergence measures how one probability distribution diverges from a reference 
283+     distribution. Lower values indicate more similar distributions. 
284+ 
285+     Args: 
286+         log_probs_ref (np.ndarray): Reference log probabilities with shape (seq_len, vocab_size). 
287+         log_probs_tar (np.ndarray): Target log probabilities with shape (seq_len, vocab_size). 
288+ 
289+     Returns: 
290+         float: Average KL divergence across all positions. 
291+ 
292+     Note: 
293+         Formula: KL(P||Q) = sum(P(x) * |log(P(x)) - log(Q(x))|) averaged over sequence length 
225294    """ 
226295    debug_print (
227296        f"Computing KL divergence - log_probs shapes: ref={ log_probs_ref .shape }  , tar={ log_probs_tar .shape }  " 
@@ -239,7 +308,13 @@ def compute_kl_divergence_from_logits(log_probs_ref, log_probs_tar):
239308
240309def  to_serializable (obj ):
241310    """ 
242-     Recursively convert numpy types and torch types to native Python types for JSON serialization. 
311+     Recursively convert numpy and torch types to native Python types for JSON serialization. 
312+ 
313+     Args: 
314+         obj: Object to convert (dict, list, tuple, np.ndarray, torch.Tensor, etc.). 
315+ 
316+     Returns: 
317+         Converted object with native Python types (int, float, list, dict, tuple). 
243318    """ 
244319    if  isinstance (obj , dict ):
245320        return  {k : to_serializable (v ) for  k , v  in  obj .items ()}
@@ -261,8 +336,23 @@ def to_serializable(obj):
261336
262337def  compute_unified_comparison (model_logits_list , output_file ):
263338    """ 
264-     Compute KL divergence comparison between all models in a unified way 
265-     model_logits_list: List of tuples (model_name, model_data) 
339+     Compute pairwise KL divergence between all models and save results to JSON. 
340+ 
341+     This function performs an all-vs-all comparison of the provided models by computing 
342+     KL divergence for each chunk and averaging across all chunks. Results are saved 
343+     in a structured JSON format. 
344+ 
345+     Args: 
346+         model_logits_list (list): List of tuples (model_name, model_data) where: 
347+             - model_name (str): Identifier for the model (e.g., "hf_baseline", "cuda_1") 
348+             - model_data (dict): Dictionary containing: 
349+                 - 'logits': List of numpy arrays (one per chunk) 
350+                 - 'total_chunks': Number of chunks 
351+                 - 'seq_len': Sequence length 
352+                 - 'model_path': Path to model 
353+                 - 'chunk_info': Chunk position info 
354+         output_file (str): Path to save the JSON results file. 
355+ 
266356    """ 
267357    print ("\n [INFO] Computing unified KL divergence comparison..." )
268358    debug_print (f"Number of models to compare: { len (model_logits_list )}  " )
@@ -325,7 +415,7 @@ def compute_unified_comparison(model_logits_list, output_file):
325415        # Find minimum sequence length for this chunk 
326416        min_seq_len  =  min (getattr (logits , "shape" , [None , 0 ])[1 ] for  _ , logits  in  chunk_logits )
327417        # Assume all have same vocab size 
328-         vocab_size  =  getattr (chunk_logits [ 0 ][ 1 ] , "shape" , [None , None , 0 ])[2 ]
418+         vocab_size  =  min ( getattr (logits , "shape" , [None , None , 0 ])[2 ]  for   _ ,  logits   in   chunk_logits ) 
329419        debug_print (f"  Min seq len: { min_seq_len }  , Vocab size: { vocab_size }  " )
330420
331421        # Trim all logits to matching dimensions 
@@ -396,7 +486,16 @@ def compute_unified_comparison(model_logits_list, output_file):
396486
397487
398488def  validate_inputs (hf_model , ep_path_pairs ):
399-     """Validate that all input paths exist and EPs are supported""" 
489+     """ 
490+     Validate that all model paths exist and execution providers are supported. 
491+ 
492+     Args: 
493+         hf_model (str or None): Path to Hugging Face model (optional). 
494+         ep_path_pairs (list): List of (execution_provider, model_path) tuples. 
495+ 
496+     Returns: 
497+         bool: True if all inputs are valid, False otherwise. 
498+     """ 
400499    # Check HF model path (only if provided) 
401500    if  hf_model  and  not  os .path .exists (hf_model ):
402501        print (f"[ERROR] Hugging Face model path does not exist: { hf_model }  " )
@@ -556,7 +655,7 @@ def main():
556655        ]
557656
558657        print ("\n [INFO] Running KL_divergence_metrics_same_ep.py..." )
559-         result  =  subprocess .run (cmd , shell = True )
658+         result  =  subprocess .run (cmd , shell = False )
560659
561660        if  result .returncode  ==  0 :
562661            print ("\n [SUCCESS] KL divergence computation completed successfully" )
0 commit comments