66from fms_mo import qmodel_prep
77from fms_mo .utils .import_utils import available_packages
88from fms_mo .utils .qconfig_utils import check_config , set_mx_specs
9- from tests .models .test_model_utils import delete_config , qmodule_error
9+ from tests .models .test_model_utils import delete_file , qmodule_error
1010
1111if available_packages ["mx" ]:
1212 # Local
1919 QBmmMX ,
2020 ]
2121
22+
23+ @pytest .fixture (autouse = True )
24+ def delete_files ():
25+ """
26+ Delete any known files lingering before starting test
27+ """
28+ delete_file ("qcfg.json" )
29+
30+
2231@pytest .mark .skipif (
2332 not available_packages ["mx" ],
2433 reason = "Skipping mx_specs error test; No package found" ,
@@ -42,7 +51,7 @@ def test_config_mx_specs_error(
4251 assert "mx_specs" in config_fp32_mx_specs
4352 mx_specs_temp = config_fp32_mx_specs .get ("mx_specs" )
4453
45- for key ,bad_val in bad_mx_specs_settings :
54+ for key , bad_val in bad_mx_specs_settings :
4655 # Every time we change the value, we must reset mx_specs
4756 config_fp32_mx_specs ["mx_specs" ][key ] = bad_val
4857 set_mx_specs (config_fp32_mx_specs )
@@ -53,6 +62,7 @@ def test_config_mx_specs_error(
5362 # Reset to saved value
5463 config_fp32_mx_specs ["mx_specs" ] = mx_specs_temp
5564
65+
5666@pytest .mark .skipif (
5767 not available_packages ["mx" ],
5868 reason = "Skipping mx_specs error test; No package found" ,
@@ -75,7 +85,12 @@ def test_config_mx_error(
7585
7686 assert "mx_specs" not in config_fp32_mx
7787
78- for config_key , mx_specs_key , config_bad_val , mx_specs_bad_val in bad_mx_config_settings :
88+ for (
89+ config_key ,
90+ mx_specs_key ,
91+ config_bad_val ,
92+ mx_specs_bad_val ,
93+ ) in bad_mx_config_settings :
7994 # Second check config w/ "mx_" prefix
8095 mx_temp = config_fp32_mx [config_key ]
8196
@@ -95,8 +110,7 @@ def test_config_mx_error(
95110
96111
97112@pytest .mark .skipif (
98- not torch .cuda .is_available ()
99- or not available_packages ["mx" ],
113+ not torch .cuda .is_available () or not available_packages ["mx" ],
100114 reason = "Skipped because CUDA or MX library was not available" ,
101115)
102116def test_residualMLP (
@@ -115,19 +129,21 @@ def test_residualMLP(
115129 mx_format (str): MX format for quantization.
116130 """
117131 # Remove any saved qcfg.json
118- delete_config ()
132+ delete_file ()
119133
120134 config_fp32_mx_specs ["mx_specs" ]["w_elem_format" ] = mx_format
121135 config_fp32_mx_specs ["mx_specs" ]["a_elem_format" ] = mx_format
122136 set_mx_specs (config_fp32_mx_specs )
123137
124- qmodel_prep (model_residualMLP , input_residualMLP , config_fp32_mx_specs , use_dynamo = True )
138+ qmodel_prep (
139+ model_residualMLP , input_residualMLP , config_fp32_mx_specs , use_dynamo = True
140+ )
125141 qmodule_error (model_residualMLP , 2 , 1 )
126142
127143 # One layer should be QLinearMX
128144 found_qmodule_mx = False
129145 for _ , module in model_residualMLP .named_modules ():
130- if any ( isinstance (module , qmodule_mx ) for qmodule_mx in mx_qmodules ):
146+ if any (isinstance (module , qmodule_mx ) for qmodule_mx in mx_qmodules ):
131147 found_qmodule_mx = True
132148 # Check that the desired mx format was propagated to class
133149 assert module .mx_specs ["w_elem_format" ] == mx_format
0 commit comments