@@ -312,6 +312,7 @@ def not_which2patch_contextmanager_settings():
     """
     return ["torch.vmm", "torch.natnul", "None"]
 
+
 @pytest.fixture(scope="session")
 def bad_mx_specs_settings():
     """
@@ -330,6 +331,7 @@ def bad_mx_specs_settings():
         ("custom_cuda", "yes"),
     ]
 
+
 @pytest.fixture(scope="session")
 def bad_mx_config_settings():
     """
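Note on the two settings fixtures in the hunks above: each returns a plain list of (setting name, deliberately invalid value) tuples, so a single session-scoped instance is safe to share across tests. A minimal sketch of how a test might sanity-check `bad_mx_specs_settings` (the test name and assertion are illustrative, not part of this PR):

def test_bad_mx_specs_settings_are_pairs(bad_mx_specs_settings):
    # Each entry pairs a setting name with an invalid value,
    # e.g. ("custom_cuda", "yes"); negative tests iterate over these.
    for entry in bad_mx_specs_settings:
        assert len(entry) == 2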
@@ -348,6 +350,7 @@ def bad_mx_config_settings():
         ("mx_custom_cuda", "custom_cuda", "yes", "yes"),
     ]
 
+
 ################################
 # Toy Model Classes + Fixtures #
 ################################
@@ -464,6 +467,7 @@ def forward(self, input_tensor):
         out = self.fourth_layer(out)
         return out
 
+
 model_fp32_params = [
     ToyModel1(),
     ToyModel2(),
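`model_fp32_params` collects pre-built toy model instances at module level. A sketch of the parametrized-fixture pattern such a list typically feeds, mirroring the deepcopy-per-test style that `config_fp32` uses below (the fixture name here is an assumption; the actual fixture body is not shown in this diff):

from copy import deepcopy

import pytest

@pytest.fixture(scope="function", params=model_fp32_params)
def model_fp32(request):
    # Copy per test so in-place mutation of one toy model
    # cannot leak into later parametrized runs.
    return deepcopy(request.param)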
@@ -819,11 +823,12 @@ def config_fp32(request):
     qconfig = request.param
     return deepcopy(qconfig)
 
+
 @pytest.fixture(scope="function", params=default_config_params)
 def config_fp32_mx(request):
     """
     Create fp32 qconfig w/ mx_specs vars set in qconfig.
-
+
     Args:
         request (dict): qconfig_init
 
@@ -856,11 +861,12 @@ def config_fp32_mx(request):
 
     return qconfig
 
+
 @pytest.fixture(scope="function", params=mx_config_params)
 def config_fp32_mx_specs(request):
     """
     Create fp32 qconfig w/ mx_specs.
-
+
 
     Args:
         request (dict): qconfig_init
@@ -1176,7 +1182,7 @@ def model_bert():
     """
     return BertModel.from_pretrained("google-bert/bert-base-uncased", torchscript=True)
 
-
+
 @pytest.fixture(scope="function")
 def model_bert_eager():
     """
@@ -1192,10 +1198,12 @@ def model_bert_eager():
 
 # MX reference class for quantization
 if torch.cuda.is_available():
+
     class ResidualMLP(torch.nn.Module):
         """
         Test Linear model for MX library
         """
+
         def __init__(self, hidden_size, device="cuda"):
             super().__init__()
 
@@ -1230,8 +1238,10 @@ def forward(self, inputs):
 
         return outputs
 
+
 mx_format_params = ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]
 
+
 @pytest.fixture(scope="session", params=mx_format_params)
 def mx_format(request):
     """
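Because `mx_format` is session-scoped and parametrized over `mx_format_params`, every test that requests it runs once per element type. A minimal illustrative consumer (hypothetical test, not part of this PR):

def test_mx_format_values(mx_format):
    # One invocation per entry: "int8", "int4", "fp8_e4m3",
    # "fp8_e5m2", "fp4_e2m1".
    assert mx_format in mx_format_params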
@@ -1254,6 +1264,7 @@ def input_residualMLP():
     x = np.random.randn(16, 128)
     return torch.tensor(x, dtype=torch.float32, device="cuda")
 
+
 @pytest.fixture(scope="function")
 def model_residualMLP():
     """
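The two ResidualMLP fixtures pair naturally: `input_residualMLP` builds a (16, 128) CUDA tensor, and `model_residualMLP` (its body is cut off by this hunk) presumably returns an instance of the `ResidualMLP` class defined above. A hedged usage sketch, assuming the model's residual connections preserve the input shape:

import pytest
import torch

@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="ResidualMLP is defined only when CUDA is present",
)
def test_residual_mlp_forward(model_residualMLP, input_residualMLP):
    out = model_residualMLP(input_residualMLP)
    # Shape preservation is an assumption based on the residual design.
    assert out.shape == input_residualMLP.shape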