66 python scripts/calc_dna.py
77 python scripts/calc_dna.py --model Qwen/Qwen2.5-0.5B-Instruct
88 python scripts/calc_dna.py --model distilgpt2 --gpu 0 --samples 50
9+ python scripts/calc_dna.py --llm-list ./configs/llm_list.txt --gpus 0,1
910"""
1011
1112import argparse
1819if str (SRC ) not in sys .path :
1920 sys .path .insert (0 , str (SRC ))
2021
21- from reptrace import DNAExtractionConfig , calc_dna
22+ from reptrace import DNAExtractionConfig , calc_dna , calc_dna_parallel
2223
2324
2425def main ():
@@ -30,7 +31,13 @@ def main():
3031 "--model" ,
3132 type = str ,
3233 default = "distilgpt2" ,
33- help = "Model name or Hugging Face model ID"
34+ help = "Model name or Hugging Face model ID (ignored if --llm-list is provided)"
35+ )
36+ parser .add_argument (
37+ "--llm-list" ,
38+ type = Path ,
39+ default = None ,
40+ help = "Path to file containing model names (one per line) for batch processing"
3441 )
3542 parser .add_argument (
3643 "--dataset" ,
@@ -42,14 +49,25 @@ def main():
4249 "--gpu" ,
4350 type = int ,
4451 default = None ,
45- help = "GPU ID to use (None for CPU)"
52+ help = "GPU ID to use for single model (None for CPU)"
53+ )
54+ parser .add_argument (
55+ "--gpus" ,
56+ type = str ,
57+ default = None ,
58+ help = "Comma-separated GPU IDs for batch mode (e.g., '0,1,2')"
4659 )
4760 parser .add_argument (
4861 "--samples" ,
4962 type = int ,
5063 default = 100 ,
5164 help = "Number of probe samples to use for DNA extraction"
5265 )
66+ parser .add_argument (
67+ "--continue-on-error" ,
68+ action = "store_true" ,
69+ help = "Continue processing remaining models if one fails (batch mode only)"
70+ )
5371 parser .add_argument (
5472 "--no-save" ,
5573 action = "store_true" ,
@@ -62,6 +80,11 @@ def main():
6280 data_root = str (ROOT / "data" )
6381 output_dir = ROOT / "out"
6482
83+ # Parse GPU IDs for batch mode
84+ gpu_ids = None
85+ if args .gpus :
86+ gpu_ids = [int (g .strip ()) for g in args .gpus .split ("," ) if g .strip ()]
87+
6588 # Create configuration using the public API
6689 config = DNAExtractionConfig (
6790 model_name = args .model ,
@@ -75,7 +98,30 @@ def main():
7598 trust_remote_code = True ,
7699 )
77100
78- # Extract DNA
101+ # Batch mode: process multiple models from file
102+ if args .llm_list :
103+ print (f"Batch processing models from: { args .llm_list } " )
104+ print (f"Using { args .samples } probe samples per model" )
105+ if gpu_ids :
106+ print (f"GPUs: { gpu_ids } " )
107+
108+ results = calc_dna_parallel (
109+ config = config ,
110+ llm_list = args .llm_list ,
111+ gpu_ids = gpu_ids ,
112+ continue_on_error = args .continue_on_error ,
113+ )
114+
115+ # Display batch results
116+ print (f"\n { '=' * 60 } " )
117+ print (f"Processed { len (results )} model(s):" )
118+ for result in results :
119+ print (f" - { result .model_name } : shape={ result .vector .shape } , time={ result .elapsed_seconds :.2f} s" )
120+ if result .output_path :
121+ print (f" Saved to: { result .output_path } " )
122+ return 0
123+
124+ # Single model mode
79125 print (f"Extracting DNA from: { args .model } " )
80126 print (f"Using { args .samples } probe samples" )
81127 result = calc_dna (config )
@@ -94,3 +140,4 @@ def main():
94140
95141if __name__ == "__main__" :
96142 sys .exit (main ())
143+
0 commit comments