1+ #!/usr/bin/env python3
2+ """
3+ Concurrency Safety Test for PR #2498
4+ Uses w8a16 scheme (weight-only, no calibration needed).
5+ This test runs multiple PTQ jobs with different max_workers settings and compares outputs for consistency. It also includes a stress test with 16 workers to check for stability under heavy concurrency.
6+ By David Zheng (dqzheng1996@gmail.com)
7+ """
8+
9+ import os
10+ import sys
11+ import hashlib
12+ import shutil
13+ import glob
14+ from pathlib import Path
15+
16+ # ============================================================================
17+ # CONFIGURATION
18+ # ============================================================================
19+ MODEL_PATH = os .environ .get ("MODEL_PATH" , "TinyLlama/TinyLlama-1.1B-Chat-v1.0" )
20+ SCHEME = "w8a16" # Confirmed working with model_free_ptq
21+ MAX_WORKERS_BASELINE = 1
22+ MAX_WORKERS_STRESS = 8
23+ WORKDIR = Path ("/tmp/ptq_concurrency_test" )
24+ REPO_ROOT = Path ("/mnt/task_runtime/llm-compressor" )
25+
26+ # ============================================================================
27+ # IMPORT LLMCOMPRESSOR
28+ # ============================================================================
29+ sys .path .insert (0 , str (REPO_ROOT / "src" ))
30+
31+ print ("Importing llmcompressor..." )
32+ from llmcompressor .entrypoints import model_free_ptq
33+ print ("Successfully imported model_free_ptq" )
34+
35+ # ============================================================================
36+ # RUN PTQ JOB
37+ # ============================================================================
38+ def run_ptq_job (output_dir , max_workers ):
39+ """Run PTQ job using the model_free_ptq API."""
40+
41+ output_dir = Path (output_dir )
42+ output_dir .mkdir (parents = True , exist_ok = True )
43+
44+ print (f"\n Running PTQ: max_workers={ max_workers } , output={ output_dir } " )
45+
46+ try :
47+ model_free_ptq (
48+ model_stub = MODEL_PATH ,
49+ save_directory = str (output_dir ),
50+ scheme = SCHEME ,
51+ max_workers = max_workers ,
52+ )
53+
54+ print ("✅ Completed successfully" )
55+ return True
56+
57+ except Exception as e :
58+ print (f" Failed with error: { e } " )
59+ import traceback
60+ traceback .print_exc ()
61+ return False
62+
63+ # ============================================================================
64+ # HASH COMPARISON
65+ # ============================================================================
66+ def compute_file_hash (filepath ):
67+ sha256 = hashlib .sha256 ()
68+ with open (filepath , "rb" ) as f :
69+ for chunk in iter (lambda : f .read (8192 ), b"" ):
70+ sha256 .update (chunk )
71+ return sha256 .hexdigest ()
72+
73+ def get_output_files (output_dir ):
74+ return sorted (glob .glob (str (Path (output_dir ) / "*.safetensors" )))
75+
76+ def compare_outputs (dir1 , dir2 ):
77+ files1 = get_output_files (dir1 )
78+ files2 = get_output_files (dir2 )
79+
80+ if len (files1 ) != len (files2 ):
81+ print (f" File count: { len (files1 )} vs { len (files2 )} " )
82+ return False
83+
84+ print (f"\n Comparing { len (files1 )} files..." )
85+ for f1 , f2 in zip (files1 , files2 ):
86+ h1 = compute_file_hash (f1 )
87+ h2 = compute_file_hash (f2 )
88+ if h1 != h2 :
89+ print (f"❌ MISMATCH: { Path (f1 ).name } " )
90+ return False
91+ print (f"✅ { Path (f1 ).name } : { h1 [:16 ]} ..." )
92+
93+ return True
94+
95+ # ============================================================================
96+ # MAIN
97+ # ============================================================================
98+ def main ():
99+ print ("=" * 60 )
100+ print (" PR #2498 Concurrency Safety Validation" )
101+ print ("=" * 60 )
102+ print (f"Model: { MODEL_PATH } " )
103+ print (f"Scheme: { SCHEME } (weight-only, no calibration)" )
104+ print (f"Workdir: { WORKDIR } " )
105+
106+ # Check CUDA
107+ try :
108+ import torch
109+ if torch .cuda .is_available ():
110+ print (f" CUDA: { torch .cuda .device_count ()} GPUs" )
111+ else :
112+ print ("CUDA not available" )
113+ except :
114+ print (" torch not available" )
115+
116+ # Cleanup
117+ if WORKDIR .exists ():
118+ shutil .rmtree (WORKDIR )
119+ WORKDIR .mkdir (parents = True , exist_ok = True )
120+
121+ # Run tests
122+ out_w1 = WORKDIR / "out_w1"
123+ out_w8 = WORKDIR / "out_w8"
124+ out_stress = WORKDIR / "out_stress"
125+
126+ print ("\n " + "=" * 60 )
127+ print (" EXPERIMENT 1: Determinism (1 vs 8 workers)" )
128+ print ("=" * 60 )
129+
130+ w1_ok = run_ptq_job (out_w1 , MAX_WORKERS_BASELINE )
131+ w8_ok = run_ptq_job (out_w8 , MAX_WORKERS_STRESS )
132+
133+ if w1_ok and w8_ok :
134+ compare_ok = compare_outputs (out_w1 , out_w8 )
135+ print (f"\n { ' EXP1 PASSED' if compare_ok else ' EXP1 FAILED' } " )
136+ else :
137+ compare_ok = False
138+ print ("\n EXP1 SKIPPED" )
139+
140+ print ("\n " + "=" * 60 )
141+ print ("📋 EXPERIMENT 2: Stress Test (16 workers)" )
142+ print ("=" * 60 )
143+
144+ stress_ok = run_ptq_job (out_stress , 16 )
145+ stress_files = get_output_files (out_stress ) if stress_ok else []
146+
147+ if stress_ok and len (stress_files ) > 0 :
148+ print (f"\n EXP2 PASSED: { len (stress_files )} files" )
149+ else :
150+ print ("\n EXP2 FAILED" )
151+
152+ # Summary
153+ print ("\n " + "=" * 60 )
154+ print ("SUMMARY" )
155+ print ("=" * 60 )
156+ print (f"Worker 1: { '✅' if w1_ok else '❌' } " )
157+ print (f"Worker 8: { '✅' if w8_ok else '❌' } " )
158+ print (f"Hash Match: { '✅' if compare_ok else '❌' } " )
159+ print (f"Stress 16: { '✅' if stress_ok else '❌' } " )
160+
161+ if compare_ok and stress_ok :
162+ print ("\n 🎉 ALL TESTS PASSED!" )
163+ print ("-" * 60 )
164+ return 0
165+ else :
166+ print ("\n ⚠️ Some tests failed" )
167+ return 1
168+
169+ if __name__ == "__main__" :
170+ sys .exit (main ())
0 commit comments