3434from fms_mo .utils .import_utils import available_packages
3535
3636
37- SCRIPT = " fms_mo/run_quant.py"
37+ SCRIPT = os . path . join ( os . path . dirname ( __file__ ), "../.." , " fms_mo/run_quant.py")
3838MODEL_NAME = "Maykeye/TinyLLama-v0"
3939BASE_KWARGS = {
40+ "accelerate_launch_args" :{
41+ "num_processes" : 1
42+ },
4043 "model_name_or_path" : MODEL_NAME ,
41- "output_dir" : "tmp" ,
4244}
4345BASE_GPTQ_KWARGS = {
4446 ** BASE_KWARGS ,
4547 ** {
4648 "quant_method" : "gptq" ,
4749 "bits" : 4 ,
48- "group_size" : 128 ,
50+ "group_size" : 64 ,
4951 "training_data_path" : WIKITEXT_TOKENIZED_DATA_JSON ,
52+ "device" : "cuda"
5053 },
5154}
5255BASE_FP8_KWARGS = {
5558 "quant_method" : "fp8" ,
5659 },
5760}
58-
61+ BASE_DQ_KWARGS = {
62+ ** BASE_KWARGS ,
63+ ** {
64+ "quant_method" : "dq" ,
65+ "nbits_w" : 8 ,
66+ "nbits_a" : 8 ,
67+ "nbits_kvcache" : 32 ,
68+ "qa_mode" : "fp8_e4m3_scale" ,
69+ "qw_mode" : "fp8_e4m3_scale" ,
70+ "qmodel_calibration_new" : 0 ,
71+ "training_data_path" : WIKITEXT_TOKENIZED_DATA_JSON ,
72+ },
73+ }
5974
6075def setup_env (tempdir ):
61- os .environ ["TRAINING_SCRIPT " ] = SCRIPT
76+ os .environ ["OPTIMIZER_SCRIPT " ] = SCRIPT
6277 os .environ ["PYTHONPATH" ] = "./:$PYTHONPATH"
6378 os .environ ["TERMINATION_LOG_FILE" ] = tempdir + "/termination-log"
79+ os .environ ["SET_NUM_PROCESSES_TO_NUM_GPUS" ] = "False"
6480
6581
6682def cleanup_env ():
6783 os .environ .pop ("OPTIMIZER_SCRIPT" , None )
6884 os .environ .pop ("PYTHONPATH" , None )
6985 os .environ .pop ("TERMINATION_LOG_FILE" , None )
7086
71- ### Tests for model dtype edge cases
87+
7288@pytest .mark .skipif (not available_packages ["auto_gptq" ], reason = "Only runs if auto-gptq package is installed" )
7389def test_successful_gptq ():
7490 """Check if we can gptq models"""
91+ with tempfile .TemporaryDirectory () as tempdir :
92+ setup_env (tempdir )
93+ GPTQ_KWARGS = {** BASE_GPTQ_KWARGS , ** {"output_dir" : tempdir }}
94+ serialized_args = serialize_args (GPTQ_KWARGS )
95+ os .environ ["FMS_MO_CONFIG_JSON_ENV_VAR" ] = serialized_args
96+
97+ assert main () == 0
98+
99+ _validate_termination_files_when_quantization_succeeds (tempdir )
100+ _validate_quantization_output (tempdir , "gptq" )
101+
102+
103+ @pytest .mark .skipif (not available_packages ["llmcompressor" ], reason = "Only runs if llm-compressor package is installed" )
104+ def test_successful_fp8 ():
105+ """Check if we can fp8 quantize models"""
106+ with tempfile .TemporaryDirectory () as tempdir :
107+ setup_env (tempdir )
108+ FP8_KWARGS = {** BASE_FP8_KWARGS , ** {"output_dir" : tempdir }}
109+ serialized_args = serialize_args (FP8_KWARGS )
110+ os .environ ["FMS_MO_CONFIG_JSON_ENV_VAR" ] = serialized_args
111+
112+ assert main () == 0
113+
114+ _validate_termination_files_when_quantization_succeeds (tempdir )
115+ _validate_quantization_output (tempdir , "fp8" )
116+
117+
118+ def test_successful_dq ():
119+ """Check if we can dq models"""
120+ with tempfile .TemporaryDirectory () as tempdir :
121+ setup_env (tempdir )
122+ DQ_KWARGS = {** BASE_DQ_KWARGS , ** {"output_dir" : tempdir }}
123+ serialized_args = serialize_args (DQ_KWARGS )
124+ os .environ ["FMS_MO_CONFIG_JSON_ENV_VAR" ] = serialized_args
125+
126+ assert main () == 0
127+
128+ _validate_termination_files_when_quantization_succeeds (tempdir )
129+ _validate_quantization_output (tempdir , "dq" )
130+
131+
132+ def test_bad_script_path ():
133+ """Check for appropriate error for an invalid optimization script location"""
75134 with tempfile .TemporaryDirectory () as tempdir :
76135 setup_env (tempdir )
77136 QUANT_KWARGS = {** BASE_KWARGS , ** {"output_dir" : tempdir }}
78137 serialized_args = serialize_args (QUANT_KWARGS )
79138 os .environ ["FMS_MO_CONFIG_JSON_ENV_VAR" ] = serialized_args
139+ os .environ ["OPTIMIZER_SCRIPT" ] = "/not/here"
140+
141+ with pytest .raises (SystemExit ) as pytest_wrapped_e :
142+ main ()
143+ assert pytest_wrapped_e .type == SystemExit
144+ assert pytest_wrapped_e .value .code == INTERNAL_ERROR_EXIT_CODE
145+ assert os .stat (tempdir + "/termination-log" ).st_size > 0
146+
147+
148+ def test_blank_config_json_env_var ():
149+ with tempfile .TemporaryDirectory () as tempdir :
150+ setup_env (tempdir )
151+ os .environ ["FMS_MO_CONFIG_JSON_ENV_VAR" ] = ""
152+ with pytest .raises (SystemExit ) as pytest_wrapped_e :
153+ main ()
154+ assert pytest_wrapped_e .type == SystemExit
155+ assert pytest_wrapped_e .value .code == USER_ERROR_EXIT_CODE
156+ assert os .stat (tempdir + "/termination-log" ).st_size > 0
157+
158+ def test_blank_config_json_path ():
159+ with tempfile .TemporaryDirectory () as tempdir :
160+ setup_env (tempdir )
161+ os .environ ["FMS_MO_CONFIG_JSON_PATH" ] = ""
162+ with pytest .raises (SystemExit ) as pytest_wrapped_e :
163+ main ()
164+ assert pytest_wrapped_e .type == SystemExit
165+ assert pytest_wrapped_e .value .code == USER_ERROR_EXIT_CODE
166+ assert os .stat (tempdir + "/termination-log" ).st_size > 0
167+
168+ def test_faulty_file_path ():
169+ with tempfile .TemporaryDirectory () as tempdir :
170+ setup_env (tempdir )
171+ faulty_path = os .path .join (tempdir , "non_existent_file.pkl" )
172+ QUANT_KWARGS = {
173+ ** BASE_KWARGS ,
174+ ** {"training_data_path" : faulty_path , "output_dir" : tempdir },
175+ }
176+ serialized_args = serialize_args (QUANT_KWARGS )
177+ os .environ ["FMS_MO_CONFIG_JSON_ENV_VAR" ] = serialized_args
178+ with pytest .raises (SystemExit ) as pytest_wrapped_e :
179+ main ()
180+ assert pytest_wrapped_e .type == SystemExit
181+ assert pytest_wrapped_e .value .code == USER_ERROR_EXIT_CODE
182+ assert os .stat (tempdir + "/termination-log" ).st_size > 0
183+
184+
185+ def test_bad_base_model_path ():
186+ with tempfile .TemporaryDirectory () as tempdir :
187+ setup_env (tempdir )
188+ DQ_KWARGS = {
189+ ** BASE_DQ_KWARGS ,
190+ ** {"model_name_or_path" : "/wrong/path" , "output_dir" : tempdir },
191+ }
192+ serialized_args = serialize_args (DQ_KWARGS )
193+ os .environ ["FMS_MO_CONFIG_JSON_ENV_VAR" ] = serialized_args
194+ with pytest .raises (SystemExit ) as pytest_wrapped_e :
195+ main ()
196+ assert pytest_wrapped_e .type == SystemExit
197+ assert pytest_wrapped_e .value .code == USER_ERROR_EXIT_CODE
198+ assert os .stat (tempdir + "/termination-log" ).st_size > 0
199+
200+
201+ def test_config_parsing_error ():
202+ with tempfile .TemporaryDirectory () as tempdir :
203+ setup_env (tempdir )
204+ DQ_KWARGS = {** BASE_DQ_KWARGS , ** {"nbits_w" : "eight" , "output_dir" : tempdir }} # Intentional type error
205+ serialized_args = serialize_args (DQ_KWARGS )
206+ os .environ ["FMS_MO_CONFIG_JSON_ENV_VAR" ] = serialized_args
207+ with pytest .raises (SystemExit ) as pytest_wrapped_e :
208+ main ()
209+ assert pytest_wrapped_e .type == SystemExit
210+ assert pytest_wrapped_e .value .code == USER_ERROR_EXIT_CODE
211+ assert os .stat (tempdir + "/termination-log" ).st_size > 0
212+
213+
214+ def _validate_termination_files_when_quantization_succeeds (base_dir ):
215+ # Check termination log and .complete files exist
216+ assert os .path .exists (os .path .join (base_dir , "/termination-log" )) is False
217+ assert os .path .exists (os .path .join (base_dir , ".complete" )) is True
218+ # assert os.path.exists(os.path.join(base_dir, training_logs_filename)) is True
219+
220+
221+ def _validate_quantization_output (base_dir , quant_method ):
222+ # Check tokenizer files exist
223+ assert os .path .exists (os .path .join (base_dir , "tokenizer.json" )) is True
224+ assert os .path .exists (os .path .join (base_dir , "special_tokens_map.json" )) is True
225+ assert os .path .exists (os .path .join (base_dir , "tokenizer_config.json" )) is True
226+ assert os .path .exists (os .path .join (base_dir , "tokenizer.model" )) is True
227+
228+ # Check quantized model files exist
229+ if quant_method == "gptq" :
230+ assert len (glob .glob (os .path .join (base_dir , "gptq_model-*.safetensors" ))) > 0
231+ assert os .path .exists (os .path .join (base_dir , "quantize_config.json" )) is True
232+ assert os .path .exists (os .path .join (base_dir , "config.json" )) is True
233+
234+ elif quant_method == "fp8" :
235+ assert len (glob .glob (os .path .join (base_dir , "model*.safetensors" ))) > 0
236+ assert os .path .exists (os .path .join (base_dir , "generation_config.json" )) is True
237+ assert os .path .exists (os .path .join (base_dir , "config.json" )) is True
238+ assert os .path .exists (os .path .join (base_dir , "recipe.yaml" )) is True
239+
240+ elif quant_method == "dq" :
241+ assert len (glob .glob (os .path .join (base_dir , "model*.safetensors" ))) > 0
242+ assert os .path .exists (os .path .join (base_dir , "generation_config.json" )) is True
243+ assert os .path .exists (os .path .join (base_dir , "config.json" )) is True
244+
80245
81- assert main () == 0
246+ def test_cleanup ():
247+ # Runs to unset env variables that could disrupt other tests
248+ cleanup_env ()
249+ assert True
0 commit comments