Skip to content

Commit de548ab

Browse files
Unit tests for launch script
Signed-off-by: Thara Palanivel <[email protected]>
1 parent fc78796 commit de548ab

File tree

3 files changed

+183
-7
lines changed

3 files changed

+183
-7
lines changed

fms_mo/run_quant.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
151151
quantize_config=quantize_config,
152152
torch_dtype=model_args.torch_dtype,
153153
)
154+
if model_args.device:
155+
model = model.to(model_args.device)
154156

155157
logger.info(f"Loading data from {data_args.training_data_path}")
156158
tokenizer = AutoTokenizer.from_pretrained(

fms_mo/training_args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ class ModelArguments:
6060
)
6161
},
6262
)
63+
device: str = field(
64+
default=None,
65+
metadata={
66+
"help": ("`torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).")
67+
}
68+
)
6369

6470

6571
@dataclass

tests/build/test_launch_script.py

Lines changed: 175 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,22 @@
3434
from fms_mo.utils.import_utils import available_packages
3535

3636

37-
SCRIPT = "fms_mo/run_quant.py"
37+
SCRIPT = os.path.join(os.path.dirname(__file__), "../..", "fms_mo/run_quant.py")
3838
MODEL_NAME = "Maykeye/TinyLLama-v0"
3939
BASE_KWARGS = {
40+
"accelerate_launch_args":{
41+
"num_processes": 1
42+
},
4043
"model_name_or_path": MODEL_NAME,
41-
"output_dir": "tmp",
4244
}
4345
BASE_GPTQ_KWARGS = {
4446
**BASE_KWARGS,
4547
**{
4648
"quant_method": "gptq",
4749
"bits": 4,
48-
"group_size": 128,
50+
"group_size": 64,
4951
"training_data_path": WIKITEXT_TOKENIZED_DATA_JSON,
52+
"device": "cuda"
5053
},
5154
}
5255
BASE_FP8_KWARGS = {
@@ -55,27 +58,192 @@
5558
"quant_method": "fp8",
5659
},
5760
}
58-
61+
BASE_DQ_KWARGS = {
62+
**BASE_KWARGS,
63+
**{
64+
"quant_method": "dq",
65+
"nbits_w": 8,
66+
"nbits_a": 8,
67+
"nbits_kvcache": 32,
68+
"qa_mode": "fp8_e4m3_scale",
69+
"qw_mode": "fp8_e4m3_scale",
70+
"qmodel_calibration_new": 0,
71+
"training_data_path": WIKITEXT_TOKENIZED_DATA_JSON,
72+
},
73+
}
5974

6075
def setup_env(tempdir):
61-
os.environ["TRAINING_SCRIPT"] = SCRIPT
76+
os.environ["OPTIMIZER_SCRIPT"] = SCRIPT
6277
os.environ["PYTHONPATH"] = "./:$PYTHONPATH"
6378
os.environ["TERMINATION_LOG_FILE"] = tempdir + "/termination-log"
79+
os.environ["SET_NUM_PROCESSES_TO_NUM_GPUS"] = "False"
6480

6581

6682
def cleanup_env():
6783
os.environ.pop("OPTIMIZER_SCRIPT", None)
6884
os.environ.pop("PYTHONPATH", None)
6985
os.environ.pop("TERMINATION_LOG_FILE", None)
7086

71-
### Tests for model dtype edge cases
87+
7288
@pytest.mark.skipif(not available_packages["auto_gptq"], reason="Only runs if auto-gptq package is installed")
7389
def test_successful_gptq():
7490
"""Check if we can gptq models"""
91+
with tempfile.TemporaryDirectory() as tempdir:
92+
setup_env(tempdir)
93+
GPTQ_KWARGS = {**BASE_GPTQ_KWARGS, **{"output_dir": tempdir}}
94+
serialized_args = serialize_args(GPTQ_KWARGS)
95+
os.environ["FMS_MO_CONFIG_JSON_ENV_VAR"] = serialized_args
96+
97+
assert main() == 0
98+
99+
_validate_termination_files_when_quantization_succeeds(tempdir)
100+
_validate_quantization_output(tempdir, "gptq")
101+
102+
103+
@pytest.mark.skipif(not available_packages["llmcompressor"], reason="Only runs if llm-compressor package is installed")
104+
def test_successful_fp8():
105+
"""Check if we can fp8 quantize models"""
106+
with tempfile.TemporaryDirectory() as tempdir:
107+
setup_env(tempdir)
108+
FP8_KWARGS = {**BASE_FP8_KWARGS, **{"output_dir": tempdir}}
109+
serialized_args = serialize_args(FP8_KWARGS)
110+
os.environ["FMS_MO_CONFIG_JSON_ENV_VAR"] = serialized_args
111+
112+
assert main() == 0
113+
114+
_validate_termination_files_when_quantization_succeeds(tempdir)
115+
_validate_quantization_output(tempdir, "fp8")
116+
117+
118+
def test_successful_dq():
119+
"""Check if we can dq models"""
120+
with tempfile.TemporaryDirectory() as tempdir:
121+
setup_env(tempdir)
122+
DQ_KWARGS = {**BASE_DQ_KWARGS, **{"output_dir": tempdir}}
123+
serialized_args = serialize_args(DQ_KWARGS)
124+
os.environ["FMS_MO_CONFIG_JSON_ENV_VAR"] = serialized_args
125+
126+
assert main() == 0
127+
128+
_validate_termination_files_when_quantization_succeeds(tempdir)
129+
_validate_quantization_output(tempdir, "dq")
130+
131+
132+
def test_bad_script_path():
133+
"""Check for appropriate error for an invalid optimization script location"""
75134
with tempfile.TemporaryDirectory() as tempdir:
76135
setup_env(tempdir)
77136
QUANT_KWARGS = {**BASE_KWARGS, **{"output_dir": tempdir}}
78137
serialized_args = serialize_args(QUANT_KWARGS)
79138
os.environ["FMS_MO_CONFIG_JSON_ENV_VAR"] = serialized_args
139+
os.environ["OPTIMIZER_SCRIPT"] = "/not/here"
140+
141+
with pytest.raises(SystemExit) as pytest_wrapped_e:
142+
main()
143+
assert pytest_wrapped_e.type == SystemExit
144+
assert pytest_wrapped_e.value.code == INTERNAL_ERROR_EXIT_CODE
145+
assert os.stat(tempdir + "/termination-log").st_size > 0
146+
147+
148+
def test_blank_config_json_env_var():
149+
with tempfile.TemporaryDirectory() as tempdir:
150+
setup_env(tempdir)
151+
os.environ["FMS_MO_CONFIG_JSON_ENV_VAR"] = ""
152+
with pytest.raises(SystemExit) as pytest_wrapped_e:
153+
main()
154+
assert pytest_wrapped_e.type == SystemExit
155+
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
156+
assert os.stat(tempdir + "/termination-log").st_size > 0
157+
158+
def test_blank_config_json_path():
159+
with tempfile.TemporaryDirectory() as tempdir:
160+
setup_env(tempdir)
161+
os.environ["FMS_MO_CONFIG_JSON_PATH"] = ""
162+
with pytest.raises(SystemExit) as pytest_wrapped_e:
163+
main()
164+
assert pytest_wrapped_e.type == SystemExit
165+
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
166+
assert os.stat(tempdir + "/termination-log").st_size > 0
167+
168+
def test_faulty_file_path():
169+
with tempfile.TemporaryDirectory() as tempdir:
170+
setup_env(tempdir)
171+
faulty_path = os.path.join(tempdir, "non_existent_file.pkl")
172+
QUANT_KWARGS = {
173+
**BASE_KWARGS,
174+
**{"training_data_path": faulty_path, "output_dir": tempdir},
175+
}
176+
serialized_args = serialize_args(QUANT_KWARGS)
177+
os.environ["FMS_MO_CONFIG_JSON_ENV_VAR"] = serialized_args
178+
with pytest.raises(SystemExit) as pytest_wrapped_e:
179+
main()
180+
assert pytest_wrapped_e.type == SystemExit
181+
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
182+
assert os.stat(tempdir + "/termination-log").st_size > 0
183+
184+
185+
def test_bad_base_model_path():
186+
with tempfile.TemporaryDirectory() as tempdir:
187+
setup_env(tempdir)
188+
DQ_KWARGS = {
189+
**BASE_DQ_KWARGS,
190+
**{"model_name_or_path": "/wrong/path", "output_dir": tempdir},
191+
}
192+
serialized_args = serialize_args(DQ_KWARGS)
193+
os.environ["FMS_MO_CONFIG_JSON_ENV_VAR"] = serialized_args
194+
with pytest.raises(SystemExit) as pytest_wrapped_e:
195+
main()
196+
assert pytest_wrapped_e.type == SystemExit
197+
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
198+
assert os.stat(tempdir + "/termination-log").st_size > 0
199+
200+
201+
def test_config_parsing_error():
202+
with tempfile.TemporaryDirectory() as tempdir:
203+
setup_env(tempdir)
204+
DQ_KWARGS = {**BASE_DQ_KWARGS, **{"nbits_w": "eight", "output_dir": tempdir}} # Intentional type error
205+
serialized_args = serialize_args(DQ_KWARGS)
206+
os.environ["FMS_MO_CONFIG_JSON_ENV_VAR"] = serialized_args
207+
with pytest.raises(SystemExit) as pytest_wrapped_e:
208+
main()
209+
assert pytest_wrapped_e.type == SystemExit
210+
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
211+
assert os.stat(tempdir + "/termination-log").st_size > 0
212+
213+
214+
def _validate_termination_files_when_quantization_succeeds(base_dir):
215+
# Check termination log and .complete files exist
216+
assert os.path.exists(os.path.join(base_dir, "/termination-log")) is False
217+
assert os.path.exists(os.path.join(base_dir, ".complete")) is True
218+
# assert os.path.exists(os.path.join(base_dir, training_logs_filename)) is True
219+
220+
221+
def _validate_quantization_output(base_dir, quant_method):
222+
# Check tokenizer files exist
223+
assert os.path.exists(os.path.join(base_dir, "tokenizer.json")) is True
224+
assert os.path.exists(os.path.join(base_dir, "special_tokens_map.json")) is True
225+
assert os.path.exists(os.path.join(base_dir, "tokenizer_config.json")) is True
226+
assert os.path.exists(os.path.join(base_dir, "tokenizer.model")) is True
227+
228+
# Check quantized model files exist
229+
if quant_method == "gptq":
230+
assert len(glob.glob(os.path.join(base_dir, "gptq_model-*.safetensors"))) > 0
231+
assert os.path.exists(os.path.join(base_dir, "quantize_config.json")) is True
232+
assert os.path.exists(os.path.join(base_dir, "config.json")) is True
233+
234+
elif quant_method == "fp8":
235+
assert len(glob.glob(os.path.join(base_dir, "model*.safetensors"))) > 0
236+
assert os.path.exists(os.path.join(base_dir, "generation_config.json")) is True
237+
assert os.path.exists(os.path.join(base_dir, "config.json")) is True
238+
assert os.path.exists(os.path.join(base_dir, "recipe.yaml")) is True
239+
240+
elif quant_method == "dq":
241+
assert len(glob.glob(os.path.join(base_dir, "model*.safetensors"))) > 0
242+
assert os.path.exists(os.path.join(base_dir, "generation_config.json")) is True
243+
assert os.path.exists(os.path.join(base_dir, "config.json")) is True
244+
80245

81-
assert main() == 0
246+
def test_cleanup():
247+
# Runs to unset env variables that could disrupt other tests
248+
cleanup_env()
249+
assert True

0 commit comments

Comments
 (0)