Commit 59e1f37

chore: linting
Signed-off-by: Brandon Groth <[email protected]>
1 parent 83ac2ea

3 files changed: +9, -13 lines changed

tests/models/conftest.py

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@
 import os
 
 # Third Party
+from torch.utils.data import DataLoader, TensorDataset
 from torchvision.io import read_image
 from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16
 from transformers import (
@@ -1346,7 +1347,7 @@ def input_tiny() -> DataLoader:
         dataset,
         batch_size=batch_size,
         shuffle=False,
-        collate_fn=lambda batch: tuple(torch.stack(samples) for samples in zip(*batch))
+        collate_fn=lambda batch: tuple(torch.stack(samples) for samples in zip(*batch)),
     )
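For context, the collate_fn in this hunk transposes a list of (input, label) samples into per-field tuples and stacks each field into one batched tensor; the commit only adds a trailing comma (lint), so behavior is unchanged. A minimal runnable sketch, with shapes that are illustrative rather than taken from the test suite:

# Third Party
import torch
from torch.utils.data import DataLoader, TensorDataset

inputs = torch.randn(8, 4)  # 8 samples, 4 features each (assumed shapes)
labels = torch.arange(8)    # one integer label per sample

loader = DataLoader(
    TensorDataset(inputs, labels),
    batch_size=2,
    shuffle=False,
    # zip(*batch) groups the samples field-by-field; torch.stack batches each field
    collate_fn=lambda batch: tuple(torch.stack(samples) for samples in zip(*batch)),
)

x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([2, 4]) torch.Size([2])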

tests/models/test_model_utils.py

Lines changed: 0 additions & 1 deletion
@@ -238,4 +238,3 @@ def check_linear_dtypes(state_dict: dict, linear_names: list):
         for k, v in state_dict.items()
         if all(n not in k for n in linear_names) or not k.endswith(".weight")
     )
-
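For context, the predicate in this hunk selects state-dict entries that are not quantized-linear weights: a key survives if it matches none of the linear names, or if it is not a ".weight" tensor. A self-contained sketch with made-up keys and names (not from the test suite):

linear_names = ["query", "key", "value"]
state_dict = {
    "encoder.query.weight": 1,  # linear weight: filtered out
    "encoder.query.bias": 2,    # not a .weight: kept
    "embeddings.weight": 3,     # matches no linear name: kept
}
non_linear_keys = [
    k
    for k in state_dict
    if all(n not in k for n in linear_names) or not k.endswith(".weight")
]
print(non_linear_keys)  # ['encoder.query.bias', 'embeddings.weight']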

tests/models/test_save_aiu.py

Lines changed: 7 additions & 11 deletions
@@ -3,11 +3,7 @@
 import pytest
 
 # Local
-from .test_model_utils import (
-    check_linear_dtypes,
-    delete_config,
-    load_state_dict,
-)
+from .test_model_utils import check_linear_dtypes, delete_config, load_state_dict
 from fms_mo import qmodel_prep
 from fms_mo.utils.aiu_utils import save_for_aiu
 
@@ -61,32 +57,33 @@ def test_large_outlier_bert(
         qcfg_bert (dict): Fake tiny input
         bert_linear_names (list): Quantized config for Bert
     """
+    # Third Party
     import torch
 
     # Break every tensor channel with a large magnitude outlier
-    for k,v in model_tiny_bert.state_dict().items():
+    for k, v in model_tiny_bert.state_dict().items():
         if k.endswith(".weight") and any(n in k for n in bert_linear_names):
-            v[:,0] = 1.21
+            v[:, 0] = 1.21
 
     # Set recomputation for narrow weights and prep
     qcfg_bert["recompute_narrow_weights"] = True
     qmodel_prep(model_tiny_bert, input_tiny, qcfg_bert, use_dynamo=True)
 
     # Qmax should break the quantization with an outlier to have skinny distribution
     layer2stdev: dict[str, torch.Tensor] = {}
-    for k,v in model_tiny_bert.state_dict().items():
+    for k, v in model_tiny_bert.state_dict().items():
         if k.endswith(".weight") and any(n in k for n in bert_linear_names):
             layer2stdev[k] = v.to(torch.float32).std(dim=-1)
 
     save_for_aiu(model_tiny_bert, qcfg=qcfg_bert, verbose=True)
     state_dict = load_state_dict()
 
     # Loaded model w/ recomputed SAWB should have widened channel quantization stdev
-    for k,v in state_dict.items():
+    for k, v in state_dict.items():
         if k.endswith(".weight") and any(n in k for n in bert_linear_names):
             perCh_stdev_model = layer2stdev.get(k)
             perCh_stdev_loaded = v.to(torch.float32).std(dim=-1)
-
+
             assert torch.all(perCh_stdev_loaded >= perCh_stdev_model)
 
 
@@ -136,4 +133,3 @@ def test_save_model_granite(
     # Fetch saved state dict
     state_dict = load_state_dict()
     check_linear_dtypes(state_dict, granite_linear_names)
-
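For context, test_large_outlier_bert above plants a large-magnitude outlier in column 0 of every quantized-linear weight, records each channel's standard deviation, and asserts that the reloaded (SAWB-recomputed) weights have at least that per-channel spread. A toy sketch of the check with hypothetical tensors, standing in for the fms_mo save/load round trip:

# Third Party
import torch

w = torch.randn(4, 16) * 0.02  # narrow per-channel weight distributions
w[:, 0] = 1.21                 # large-magnitude outlier in every channel

stdev_model = w.to(torch.float32).std(dim=-1)

# Stand-in for save_for_aiu + load_state_dict; we fake the mild widening
# that recomputing narrow weights is expected to produce.
w_loaded = w * 1.1
stdev_loaded = w_loaded.to(torch.float32).std(dim=-1)

# Mirrors the test's assertion: per-channel stdev never shrinks
assert torch.all(stdev_loaded >= stdev_model)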
