Skip to content

Commit 026201a

Browse files
Add unit tests for run_quant.py
Signed-off-by: Thara Palanivel <[email protected]>
1 parent 3833f21 commit 026201a

File tree

2 files changed

+137
-1
lines changed

2 files changed

+137
-1
lines changed

fms_mo/utils/dq_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def config_quantize_smooth_layers(qcfg):
3838
"granite-20b-code",
3939
"granite-20b-code",
4040
]
41-
if any(model in qcfg["model"] for model in llama_architecture):
41+
if any(model in qcfg["model"] for model in llama_architecture) or any(model in qcfg["model_type"] for model in llama_architecture):
4242
qcfg["qlayer_name_pattern"] = ["model.layers."]
4343
qcfg["scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
4444
qcfg["qskip_layer_name"] = []

tests/test_run_quant.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright The FMS Model Optimizer Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for run_quant.py"""
16+
17+
# Standard
18+
import copy
19+
import json
20+
import os
21+
22+
# Third Party
23+
import pytest
24+
25+
# Local
26+
from fms_mo.run_quant import get_parser, parse_arguments, quantize
27+
from fms_mo.training_args import (
28+
DataArguments,
29+
FMSMOArguments,
30+
FP8Arguments,
31+
GPTQArguments,
32+
ModelArguments,
33+
OptArguments,
34+
)
35+
from tests.artifacts.testdata import MODEL_NAME, WIKITEXT_TOKENIZED_DATA_JSON
36+
37+
# Shared argument fixtures reused by the quantization tests below.

# Model under test, loaded in half precision.
MODEL_ARGS = ModelArguments(
    model_name_or_path=MODEL_NAME, torch_dtype="float16"
)
# Calibration/training data: pre-tokenized wikitext JSON artifact.
DATA_ARGS = DataArguments(
    training_data_path=WIKITEXT_TOKENIZED_DATA_JSON,
)
# Direct-quantization ("dq") run writing into a scratch directory.
OPT_ARGS = OptArguments(quant_method="dq", output_dir="tmp")
GPTQ_ARGS = GPTQArguments(
    bits=4,
    group_size=64,
)
FP8_ARGS = FP8Arguments()
# FP8 e4m3 weight/activation quantization settings;
# qmodel_calibration_new=0 presumably requests no new calibration steps
# (keeps the tests fast) — confirm against FMSMOArguments.
DQ_ARGS = FMSMOArguments(
    nbits_w=8,
    nbits_a=8,
    nbits_kvcache=32,
    qa_mode="fp8_e4m3_scale",
    qw_mode="fp8_e4m3_scale",
    qmodel_calibration_new=0,
)
57+
58+
59+
def test_run_train_requires_output_dir():
    """quantize() must raise when no output directory is configured."""
    # Copy the shared fixture so other tests see it unmodified.
    opt_args_no_output = copy.deepcopy(OPT_ARGS)
    opt_args_no_output.output_dir = None
    with pytest.raises(TypeError):
        quantize(
            model_args=MODEL_ARGS,
            data_args=DATA_ARGS,
            opt_args=opt_args_no_output,
            fms_mo_args=DQ_ARGS,
        )
70+
71+
72+
def test_run_train_fails_training_data_path_not_exist():
    """quantize() must raise when the training data path does not exist."""
    # Copy the shared fixture so other tests see it unmodified.
    data_args_bad_path = copy.deepcopy(DATA_ARGS)
    data_args_bad_path.training_data_path = "fake/path"
    with pytest.raises(FileNotFoundError):
        quantize(
            model_args=MODEL_ARGS,
            data_args=data_args_bad_path,
            opt_args=OPT_ARGS,
            fms_mo_args=DQ_ARGS,
        )
83+
84+
85+
# Known-good job config artifact used by the arg-parsing tests below.
HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join(
    os.path.dirname(__file__), "artifacts", "configs", "dummy_job_config.json"
)
88+
89+
90+
# NOTE: parse_arguments mutates the dict it is given, so tests must
# deep-copy this session-scoped fixture before passing it along.
@pytest.fixture(name="job_config", scope="session")
def fixture_job_config():
    """Load the happy-path dummy job config from its JSON artifact."""
    with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as cfg_file:
        return json.load(cfg_file)
96+
97+
98+
############################# Arg Parsing Tests #############################
99+
100+
101+
def test_parse_arguments(job_config):
    """A full job config parses into the expected dataclass field values.

    Only the model/data/opt results are asserted here; the remaining
    parsed argument groups are intentionally discarded (prefixed with
    ``_``) so linters do not flag them as unused.
    """
    parser = get_parser()
    # Deep-copy: parse_arguments modifies the dict it receives.
    job_config_copy = copy.deepcopy(job_config)
    (
        model_args,
        data_args,
        opt_args,
        _fms_mo_args,
        _gptq_args,
        _fp8_args,
    ) = parse_arguments(parser, job_config_copy)
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert data_args.training_data_path == "data_train"
    assert opt_args.output_dir == "models/Maykeye/TinyLLama-v0-GPTQ"
    assert opt_args.quant_method == "gptq"
116+
117+
118+
def test_parse_arguments_defaults(job_config):
    """Fields absent from the job config fall back to their defaults.

    Unused parsed argument groups are discarded with a ``_`` prefix so
    linters do not flag them.
    """
    parser = get_parser()
    # Deep-copy: parse_arguments modifies the dict it receives.
    job_config_defaults = copy.deepcopy(job_config)
    # Pre-condition: the fields under test really are absent from the
    # config, so the assertions below exercise default resolution.
    assert "torch_dtype" not in job_config_defaults
    assert "max_seq_length" not in job_config_defaults
    assert "model_revision" not in job_config_defaults
    assert "nbits_kvcache" not in job_config_defaults
    (
        model_args,
        data_args,
        _opt_args,
        fms_mo_args,
        _gptq_args,
        _fp8_args,
    ) = parse_arguments(parser, job_config_defaults)
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert model_args.model_revision == "main"
    assert data_args.max_seq_length == 2048
    assert fms_mo_args.nbits_kvcache == 32

0 commit comments

Comments
 (0)