Commit 9bab835

lint and ruff fix
Signed-off-by: cliu-us <[email protected]>
1 parent 45dd501 commit 9bab835
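
Note: the diff below is whitespace and formatting only (comma spacing, collapsed call arguments, blank lines), consistent with an autofixer pass such as `ruff check --fix` or `ruff format`; the exact lint invocation is an assumption, as the commit does not record it.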

File tree

fms_mo/utils/qconfig_utils.py
tests/models/conftest.py

2 files changed: +17 −8 lines

fms_mo/utils/qconfig_utils.py

Lines changed: 3 additions & 5 deletions
@@ -342,7 +342,7 @@ def set_mx_specs(

     # Check args for any mx_specs vars
     use_mx_args = args is not None and any(
-        hasattr(args, key) for key,_ in fms_defaults.items()
+        hasattr(args, key) for key, _ in fms_defaults.items()
     )

     # Lastly, check for BMM consistency to enable QBmmMX
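
Side note on the fixed line: since the loop discards the values, iterating the dict directly visits the same keys as .items(). A minimal standalone sketch; the defaults and args object here are hypothetical, not the repo's real data:

# Minimal sketch, hypothetical values; not the repo's real fms_defaults.
fms_defaults = {"w_elem_format": "int8", "a_elem_format": "int8"}

class Args:
    w_elem_format = "fp8_e4m3"  # one overlapping attribute is enough

args = Args()

# `any(hasattr(args, key) for key, _ in fms_defaults.items())` visits the
# same keys as iterating the dict directly:
use_mx_args = args is not None and any(hasattr(args, key) for key in fms_defaults)
print(use_mx_args)  # True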
@@ -403,7 +403,7 @@ def set_mx_specs(
     if config["qa_mode"].startswith(mx_prefix):
         mx_specs["a_elem_format"] = config["qa_mode"].replace(mx_prefix, "")

-    for mx_var,_ in fms_defaults.items():
+    for mx_var, _ in fms_defaults.items():
         fms_var = "mx_" + mx_var
         # Only update if its in config; default values already set
         if fms_var in config:
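
For context, the loop being touched implements a prefixed-key override: each mx_specs variable can be overridden by an "mx_"-prefixed entry in config, and the default is kept otherwise. A minimal sketch of that pattern; the keys and values here are hypothetical:

# Sketch of the prefix-lookup pattern above; keys and values are hypothetical.
fms_defaults = {"block_size": 32, "custom_cuda": False}
config = {"mx_block_size": 64}  # only this default gets overridden

mx_specs = dict(fms_defaults)
for mx_var in fms_defaults:
    fms_var = "mx_" + mx_var
    # Only update if it's in config; default values are already set
    if fms_var in config:
        mx_specs[mx_var] = config[fms_var]

print(mx_specs)  # {'block_size': 64, 'custom_cuda': False}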
@@ -1194,8 +1194,6 @@ def check_config(config, model_dtype=None):
         raise ValueError("MX mapping for nn.Linear is not QLinearMX")

     if mapping["matmul_or_bmm"].func is QBmmMX:
-        raise ValueError(
-            "MX mapping for matmul_or_bmm is not QBmmMX"
-        )
+        raise ValueError("MX mapping for matmul_or_bmm is not QBmmMX")

     # End mx_specs checks

tests/models/conftest.py

Lines changed: 14 additions & 3 deletions
@@ -312,6 +312,7 @@ def not_which2patch_contextmanager_settings():
     """
     return ["torch.vmm", "torch.natnul", "None"]

+
 @pytest.fixture(scope="session")
 def bad_mx_specs_settings():
     """
@@ -330,6 +331,7 @@ def bad_mx_specs_settings():
         ("custom_cuda", "yes"),
     ]

+
 @pytest.fixture(scope="session")
 def bad_mx_config_settings():
     """
@@ -348,6 +350,7 @@ def bad_mx_config_settings():
         ("mx_custom_cuda", "custom_cuda", "yes", "yes"),
     ]

+
 ################################
 # Toy Model Classes + Fixtures #
 ################################
@@ -464,6 +467,7 @@ def forward(self, input_tensor):
         out = self.fourth_layer(out)
         return out

+
 model_fp32_params = [
     ToyModel1(),
     ToyModel2(),
@@ -819,11 +823,12 @@ def config_fp32(request):
     qconfig = request.param
     return deepcopy(qconfig)

+
 @pytest.fixture(scope="function", params=default_config_params)
 def config_fp32_mx(request):
     """
     Create fp32 qconfig w/ mx_specs vars set in qconfig.
-
+
     Args:
         request (dict): qconfig_init

@@ -856,11 +861,12 @@ def config_fp32_mx(request):

     return qconfig

+
 @pytest.fixture(scope="function", params=mx_config_params)
 def config_fp32_mx_specs(request):
     """
     Create fp32 qconfig w/ mx_specs.
-
+

     Args:
         request (dict): qconfig_init
@@ -1176,7 +1182,7 @@ def model_bert():
     """
     return BertModel.from_pretrained("google-bert/bert-base-uncased", torchscript=True)

-
+
 @pytest.fixture(scope="function")
 def model_bert_eager():
     """
@@ -1192,10 +1198,12 @@ def model_bert_eager():

 # MX reference class for quantization
 if torch.cuda.is_available():
+
     class ResidualMLP(torch.nn.Module):
         """
         Test Linear model for MX library
         """
+
         def __init__(self, hidden_size, device="cuda"):
             super().__init__()


@@ -1230,8 +1238,10 @@ def forward(self, inputs):

             return outputs

+
     mx_format_params = ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]

+
     @pytest.fixture(scope="session", params=mx_format_params)
     def mx_format(request):
         """
@@ -1254,6 +1264,7 @@ def input_residualMLP():
         x = np.random.randn(16, 128)
         return torch.tensor(x, dtype=torch.float32, device="cuda")

+
     @pytest.fixture(scope="function")
     def model_residualMLP():
         """
