Commit d4e713a

committed
fix: Changed test_mx to use new delete_file over delete_config + formatting
Signed-off-by: Brandon Groth <[email protected]>
1 parent daf9f97 commit d4e713a

File tree

2 files changed: +24 -9 lines changed


tests/models/conftest.py

Lines changed: 0 additions & 1 deletion

@@ -38,7 +38,6 @@
 import pytest
 import torch
 import torch.nn.functional as F
-from torch.utils.data import TensorDataset, DataLoader

 # Local
 # fms_mo imports

tests/models/test_mx.py

Lines changed: 24 additions & 8 deletions

@@ -6,7 +6,7 @@
 from fms_mo import qmodel_prep
 from fms_mo.utils.import_utils import available_packages
 from fms_mo.utils.qconfig_utils import check_config, set_mx_specs
-from tests.models.test_model_utils import delete_config, qmodule_error
+from tests.models.test_model_utils import delete_file, qmodule_error

 if available_packages["mx"]:
     # Local
@@ -19,6 +19,15 @@
         QBmmMX,
     ]

+
+@pytest.fixture(autouse=True)
+def delete_files():
+    """
+    Delete any known files lingering before starting test
+    """
+    delete_file("qcfg.json")
+
+
 @pytest.mark.skipif(
     not available_packages["mx"],
     reason="Skipping mx_specs error test; No package found",
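The delete_file helper itself is defined in tests/models/test_model_utils.py and is not shown in this diff. Judging from its call sites here (delete_file("qcfg.json") in the fixture, bare delete_file() later in test_residualMLP), it presumably takes a filename that defaults to "qcfg.json". A minimal sketch under that assumption:

```python
# Hypothetical sketch only; the real helper lives in
# tests/models/test_model_utils.py, which this diff does not show.
import os


def delete_file(fname: str = "qcfg.json") -> None:
    """Remove a leftover test artifact, silently ignoring a missing file."""
    if os.path.exists(fname):
        os.remove(fname)
```

Because the delete_files fixture above is marked autouse=True, pytest runs it before every test in this module, so a stale qcfg.json from an earlier run cannot leak into later tests.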
@@ -42,7 +51,7 @@ def test_config_mx_specs_error(
     assert "mx_specs" in config_fp32_mx_specs
     mx_specs_temp = config_fp32_mx_specs.get("mx_specs")

-    for key,bad_val in bad_mx_specs_settings:
+    for key, bad_val in bad_mx_specs_settings:
         # Every time we change the value, we must reset mx_specs
         config_fp32_mx_specs["mx_specs"][key] = bad_val
         set_mx_specs(config_fp32_mx_specs)
@@ -53,6 +62,7 @@ def test_config_mx_specs_error(
         # Reset to saved value
         config_fp32_mx_specs["mx_specs"] = mx_specs_temp

+
 @pytest.mark.skipif(
     not available_packages["mx"],
     reason="Skipping mx_specs error test; No package found",
@@ -75,7 +85,12 @@ def test_config_mx_error(

     assert "mx_specs" not in config_fp32_mx

-    for config_key, mx_specs_key, config_bad_val, mx_specs_bad_val in bad_mx_config_settings:
+    for (
+        config_key,
+        mx_specs_key,
+        config_bad_val,
+        mx_specs_bad_val,
+    ) in bad_mx_config_settings:
         # Second check config w/ "mx_" prefix
         mx_temp = config_fp32_mx[config_key]

@@ -95,8 +110,7 @@ def test_config_mx_error(


 @pytest.mark.skipif(
-    not torch.cuda.is_available()
-    or not available_packages["mx"],
+    not torch.cuda.is_available() or not available_packages["mx"],
     reason="Skipped because CUDA or MX library was not available",
 )
 def test_residualMLP(
@@ -115,19 +129,21 @@ def test_residualMLP(
         mx_format (str): MX format for quantization.
     """
     # Remove any saved qcfg.json
-    delete_config()
+    delete_file()

     config_fp32_mx_specs["mx_specs"]["w_elem_format"] = mx_format
     config_fp32_mx_specs["mx_specs"]["a_elem_format"] = mx_format
     set_mx_specs(config_fp32_mx_specs)

-    qmodel_prep(model_residualMLP, input_residualMLP, config_fp32_mx_specs, use_dynamo=True)
+    qmodel_prep(
+        model_residualMLP, input_residualMLP, config_fp32_mx_specs, use_dynamo=True
+    )
     qmodule_error(model_residualMLP, 2, 1)

     # One layer should be QLinearMX
     found_qmodule_mx = False
     for _, module in model_residualMLP.named_modules():
-        if any( isinstance(module, qmodule_mx) for qmodule_mx in mx_qmodules ):
+        if any(isinstance(module, qmodule_mx) for qmodule_mx in mx_qmodules):
             found_qmodule_mx = True
             # Check that the desired mx format was propagated to class
             assert module.mx_specs["w_elem_format"] == mx_format
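The module scan at the end of test_residualMLP is a generic pattern for verifying that qmodel_prep actually swapped quantized layers into the model. A self-contained sketch of the same pattern, using plain torch.nn classes as stand-ins for the fms_mo QLinearMX/QBmmMX classes:

```python
# Stand-alone illustration of the named_modules()/isinstance scan;
# nn.Linear stands in for the fms_mo QLinearMX/QBmmMX classes.
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
expected_qmodules = [nn.Linear]  # in the test: mx_qmodules = [QLinearMX, QBmmMX]

found_qmodule = False
for _, module in model.named_modules():
    if any(isinstance(module, qmodule) for qmodule in expected_qmodules):
        found_qmodule = True
        break

assert found_qmodule, "no expected quantized module found after conversion"
```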
