Skip to content

Commit ae67c8c

Browse files
committed
test: Moved file deletion to autouse fixture, fixed torch.load warning, and changed qconfig_save warning to logger.info
Signed-off-by: Brandon Groth <[email protected]>
1 parent 624718d commit ae67c8c

File tree

6 files changed

+43
-84
lines changed

6 files changed

+43
-84
lines changed

fms_mo/utils/qconfig_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -816,8 +816,7 @@ def qconfig_save(
816816

817817
# Save config as json
818818
if os.path.isfile(fname):
819-
message = f"{fname} already exist, will overwrite."
820-
warnings.warn(message, UserWarning)
819+
logger.info(f"{fname} already exist, will overwrite.")
821820
with open(fname, "w", encoding="utf-8") as outfile:
822821
json.dump(temp_qcfg, outfile, indent=4)
823822

tests/models/test_model_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,9 @@ def qmodule_error(
140140
###############################
141141

142142

143-
def delete_config(file_path: str = "qcfg.json"):
143+
def delete_file(file_path: str = "qcfg.json"):
144144
"""
145-
Delete a qconfig at the file path provided
145+
Delete a file at the file path provided
146146
147147
Args:
148148
file_path (str, optional): Qconfig file to delete. Defaults to "qcfg.json".
@@ -214,7 +214,7 @@ def load_state_dict(fname: str = "qmodel_for_aiu.pt") -> dict:
214214
Returns:
215215
dict: Model state dictionary
216216
"""
217-
return torch.load(fname)
217+
return torch.load(fname, weights_only=True)
218218

219219

220220
def check_linear_dtypes(state_dict: dict, linear_names: list):

tests/models/test_qmodelprep.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,16 @@
2727
from fms_mo import qconfig_init, qmodel_prep
2828
from fms_mo.prep import has_quantized_module
2929
from fms_mo.utils.utils import patch_torch_bmm
30-
from tests.models.test_model_utils import count_qmodules, delete_config, qmodule_error
30+
from tests.models.test_model_utils import count_qmodules, delete_file, qmodule_error
31+
32+
33+
@pytest.fixture(autouse=True)
34+
def delete_files():
35+
"""
36+
Delete any known files lingering before starting test
37+
"""
38+
delete_file("qcfg.json")
39+
3140

3241
################
3342
# Qmodel tests #
@@ -49,7 +58,6 @@ def test_model_quantized(
4958
sample_input_fp32 (torch.FloatTensor): Sample fp32 input for calibration.
5059
config_fp32 (dict): Config w/ fp32 settings
5160
"""
52-
delete_config()
5361
with pytest.raises(RuntimeError):
5462
qmodel_prep(model_quantized, sample_input_fp32, config_fp32)
5563

@@ -81,11 +89,9 @@ def test_double_qmodel_prep_assert(
8189
sample_input_fp32 (torch.FloatTensor): Sample fp32 input for calibration
8290
config_fp32 (dict): Config w/ fp32 settings
8391
"""
84-
delete_config()
85-
8692
# Run qmodel_prep once
8793
qmodel_prep(model_fp32, sample_input_fp32, config_fp32)
88-
delete_config()
94+
delete_file()
8995

9096
# If model now has a quantized node, ensure it has raises a RuntimeError
9197
if has_quantized_module(model_fp32):
@@ -107,7 +113,6 @@ def test_config_recipes_fp32(
107113
sample_input_fp32 (torch.FloatTensor): Sample fp32 input for calibration
108114
config (dict): Recipe Config w/ int8 settings
109115
"""
110-
delete_config()
111116
qmodel_prep(model_fp32, sample_input_fp32, config_int8)
112117

113118

@@ -124,7 +129,6 @@ def test_config_recipes_fp16(
124129
sample_input_fp16 (torch.FloatTensor): Sample fp16 input for calibration
125130
config (dict): Recipe Config w/ int8 settings
126131
"""
127-
delete_config()
128132
qmodel_prep(model_fp16, sample_input_fp16, config_int8)
129133

130134

@@ -145,7 +149,6 @@ def test_config_fp32_qmodes(
145149
qa_mode (str): Activation quantizer
146150
qw_mode (str): Weight quantizer
147151
"""
148-
delete_config()
149152
config_int8["qa_mode"] = qa_mode
150153
config_int8["qw_mode"] = qw_mode
151154
qmodel_prep(model_config_fp32, sample_input_fp32, config_int8)
@@ -170,7 +173,6 @@ def test_resnet50_torchscript(
170173
config (dict): Recipe Config w/ int8 settings
171174
"""
172175
# Run qmodel_prep w/ default torchscript tracer
173-
delete_config()
174176
qmodel_prep(model_resnet, batch_resnet, config_int8, use_dynamo=False)
175177
qmodule_error(model_resnet, 6, 48)
176178

@@ -189,7 +191,6 @@ def test_resnet50_dynamo(
189191
config (dict): Recipe Config w/ int8 settings
190192
"""
191193
# Run qmodel_prep w/ Dynamo tracer
192-
delete_config()
193194
qmodel_prep(model_resnet, batch_resnet, config_int8, use_dynamo=True)
194195
qmodule_error(model_resnet, 6, 48)
195196

@@ -209,7 +210,6 @@ def test_resnet50_dynamo_layers(
209210
config (dict): Recipe Config w/ int8 settings
210211
"""
211212
# Run qmodel_prep w/ qlayer_name_pattern + Dynamo tracer
212-
delete_config()
213213
config_int8["qlayer_name_pattern"] = ["layer[1,2,4]"] # allow regex
214214
qmodel_prep(model_resnet, batch_resnet, config_int8, use_dynamo=True)
215215
qmodule_error(model_resnet, 21, 33)
@@ -230,7 +230,6 @@ def test_vit_torchscript(
230230
config (dict): Recipe Config w/ int8 settings
231231
"""
232232
# Run qmodel_prep w/ default torchscript tracer
233-
delete_config()
234233
qmodel_prep(model_vit, batch_vit, config_int8, use_dynamo=False)
235234
qmodule_error(model_vit, 2, 36)
236235

@@ -249,7 +248,6 @@ def test_vit_dynamo(
249248
config (dict): Recipe Config w/ int8 settings
250249
"""
251250
# Run qmodel_prep w/ Dynamo tracer
252-
delete_config()
253251
qmodel_prep(model_vit, batch_vit, config_int8, use_dynamo=True)
254252
qmodule_error(model_vit, 2, 36)
255253

@@ -268,7 +266,6 @@ def test_bert_dynamo(
268266
config (dict): Recipe Config w/ int8 settings
269267
"""
270268
# Run qmodel_prep w/ Dynamo tracer
271-
delete_config()
272269
qmodel_prep(model_bert, input_bert, config_int8, use_dynamo=True)
273270
qmodule_error(model_bert, 1, 72)
274271

@@ -295,7 +292,6 @@ def test_bert_dynamo_wi_qbmm(
295292
input_bert (torch.FloatTensor): Tokenized input for BERT
296293
config (dict): Recipe Config w/ int8 settings
297294
"""
298-
delete_config()
299295
config_int8["nbits_bmm1"] = 8
300296
config_int8["nbits_bmm2"] = 8
301297
qmodel_prep(model_bert_eager, input_bert, config_int8, use_dynamo=True)

tests/models/test_save_aiu.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
# Local
6-
from .test_model_utils import check_linear_dtypes, delete_config, load_state_dict
6+
from .test_model_utils import check_linear_dtypes, delete_file, load_state_dict
77
from fms_mo import qmodel_prep
88
from fms_mo.utils.aiu_utils import save_for_aiu
99

@@ -13,9 +13,9 @@ def delete_files():
1313
"""
1414
Delete any known files lingering before starting test
1515
"""
16-
delete_config("qcfg.json")
17-
delete_config("keys_to_save.json")
18-
delete_config("qmodel_for_aiu.pt")
16+
delete_file("qcfg.json")
17+
delete_file("keys_to_save.json")
18+
delete_file("qmodel_for_aiu.pt")
1919

2020

2121
def test_save_model_bert(

tests/models/test_saveconfig.py

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,22 @@
2222
# Local
2323
from fms_mo.utils.qconfig_utils import qconfig_load, qconfig_save
2424
from tests.models.test_model_utils import (
25-
delete_config,
25+
delete_file,
2626
load_json,
2727
save_json,
2828
save_serialized_json,
2929
)
3030

31+
32+
@pytest.fixture(autouse=True)
33+
def delete_files():
34+
"""
35+
Delete any known files lingering before starting test
36+
"""
37+
delete_file("qcfg.json")
38+
delete_file("keys_to_save.json")
39+
40+
3141
#########
3242
# Tests #
3343
#########
@@ -45,7 +55,6 @@ def test_save_config_warn_bad_pair(
4555
bad_pair (tuple): A pair that can't be serialized for qconfig_save
4656
"""
4757
key, val = bad_pair
48-
delete_config()
4958

5059
# Add bad key,val pair and save ; should generate UserWarning(s) for removing bad pair
5160
config_fp32[key] = val
@@ -56,8 +65,6 @@ def test_save_config_warn_bad_pair(
5665
loaded_config = load_json("qcfg.json") # load json as is - do not modify
5766
assert key not in loaded_config
5867

59-
delete_config()
60-
6168

6269
def test_save_config_wanted_pairs(
6370
config_fp32: dict,
@@ -71,7 +78,6 @@ def test_save_config_wanted_pairs(
7178
wanted_pair (tuple): A pair that needs to be re-init if not present in qconfig_save
7279
"""
7380
key, default_val = wanted_pair
74-
delete_config()
7581

7682
# Delete wanted pair from config and save ; should be reset to default
7783
if key in config_fp32:
@@ -82,8 +88,6 @@ def test_save_config_wanted_pairs(
8288
loaded_config = load_json()
8389
assert loaded_config.get(key) == default_val
8490

85-
delete_config()
86-
8791

8892
def test_save_config_with_qcfg_save(
8993
config_fp32: dict,
@@ -96,7 +100,6 @@ def test_save_config_with_qcfg_save(
96100
config_fp32 (dict): Config for fp32 quantization
97101
save_list (list): List of variables to save in a quantized config.
98102
"""
99-
delete_config()
100103
config_fp32["keys_to_save"] = save_list
101104

102105
qconfig_save(config_fp32, minimal=False)
@@ -114,7 +117,6 @@ def test_save_config_with_qcfg_save(
114117
assert key in loaded_config
115118
assert loaded_config.get(key) == config_fp32.get(key)
116119

117-
delete_config()
118120
del config_fp32["keys_to_save"]
119121

120122

@@ -129,10 +131,6 @@ def test_save_config_with_recipe_save(
129131
config_fp32 (dict): Config for fp32 quantization
130132
save_list (list): List of variables to save in a quantized config.
131133
"""
132-
# Delete both qcfg and the save.json before starting
133-
delete_config()
134-
delete_config("keys_to_save.json")
135-
136134
# Save new "save.json"
137135
save_path = "keys_to_save.json"
138136
save_json(save_list, file_path=save_path)
@@ -153,9 +151,6 @@ def test_save_config_with_recipe_save(
153151
assert key in loaded_config
154152
assert loaded_config.get(key) == config_fp32.get(key)
155153

156-
delete_config()
157-
delete_config("keys_to_save.json")
158-
159154

160155
def test_save_config_minimal(
161156
config_fp32: dict,
@@ -166,8 +161,6 @@ def test_save_config_minimal(
166161
Args:
167162
config_fp32 (dict): Config for fp32 quantization
168163
"""
169-
delete_config()
170-
171164
qconfig_save(config_fp32, minimal=True)
172165

173166
# Check that saved qcfg matches
@@ -180,8 +173,6 @@ def test_save_config_minimal(
180173
# No items should exist - default config should be completely removed
181174
assert len(loaded_config) == 0
182175

183-
delete_config()
184-
185176

186177
def test_double_qconfig_save(
187178
config_fp32: dict,
@@ -192,14 +183,8 @@ def test_double_qconfig_save(
192183
Args:
193184
config_fp32 (dict): Config for fp32 quantization
194185
"""
195-
delete_config()
196-
197-
# Creating a qcfg, then saving again will cause a warning -> ignore it
198-
with pytest.warns(UserWarning, match="qcfg.json already exist, will overwrite."):
199-
qconfig_save(config_fp32, minimal=False)
200-
qconfig_save(config_fp32, minimal=False)
201-
202-
delete_config()
186+
qconfig_save(config_fp32, minimal=False)
187+
qconfig_save(config_fp32, minimal=False)
203188

204189

205190
def test_qconfig_save_list_as_dict(
@@ -211,7 +196,7 @@ def test_qconfig_save_list_as_dict(
211196
Args:
212197
config_fp32 (dict): Config for fp32 quantization
213198
"""
214-
delete_config()
199+
delete_file()
215200

216201
# Fill in keys_to_save as dict with nonsense val
217202
config_fp32["keys_to_save"] = {
@@ -226,8 +211,6 @@ def test_qconfig_save_list_as_dict(
226211
with pytest.raises(ValueError):
227212
qconfig_save(config_fp32, minimal=True)
228213

229-
delete_config()
230-
231214

232215
def test_qconfig_save_recipe_as_dict(
233216
config_fp32: dict,
@@ -238,8 +221,6 @@ def test_qconfig_save_recipe_as_dict(
238221
Args:
239222
config_fp32 (dict): Config for fp32 quantization
240223
"""
241-
delete_config()
242-
243224
# Fill in keys_to_save as dict with nonsense val
244225
save_dict = {
245226
"qa_mode": None,
@@ -254,8 +235,6 @@ def test_qconfig_save_recipe_as_dict(
254235
with pytest.raises(ValueError):
255236
qconfig_save(config_fp32, recipe="keys_to_save.json", minimal=True)
256237

257-
delete_config()
258-
259238

260239
def test_qconfig_load_with_recipe_as_list(
261240
config_fp32: dict,
@@ -266,17 +245,13 @@ def test_qconfig_load_with_recipe_as_list(
266245
Args:
267246
config_fp32 (dict): Config for fp32 quantization
268247
"""
269-
delete_config()
270-
271248
config_list = list(config_fp32.keys())
272249

273250
save_json(config_list, file_path="qcfg.json")
274251

275252
with pytest.raises(ValueError):
276253
_ = qconfig_load(fname="qcfg.json")
277254

278-
delete_config()
279-
280255

281256
def test_load_config_restored_pair(
282257
config_fp32: dict,
@@ -290,7 +265,6 @@ def test_load_config_restored_pair(
290265
wanted_pair (tuple): A pair that needs to be re-init if not present in qconfig_load
291266
"""
292267
key, default_val = wanted_pair
293-
delete_config()
294268

295269
if key in config_fp32:
296270
del config_fp32[key]
@@ -302,8 +276,6 @@ def test_load_config_restored_pair(
302276
loaded_config = qconfig_load("qcfg.json")
303277
assert loaded_config.get(key) == default_val
304278

305-
delete_config()
306-
307279

308280
def test_load_config_required_pair(
309281
config_fp32: dict,
@@ -317,7 +289,6 @@ def test_load_config_required_pair(
317289
required_pair (tuple): A pair that needs to be re-init if not present in qconfig_load
318290
"""
319291
key, default_val = required_pair
320-
delete_config()
321292

322293
if key in config_fp32:
323294
del config_fp32[key]

0 commit comments

Comments
 (0)