Commit caa6524

fix: Changed input_tiny to dataloader to work with qmodel_prep
Signed-off-by: Brandon Groth <[email protected]>
1 parent: 9bb39e2

1 file changed: +16 −10 lines

tests/models/conftest.py
@@ -25,7 +25,6 @@
 from torchvision.io import read_image
 from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16
 from transformers import (
-    BatchEncoding,
     BertConfig,
     BertModel,
     BertTokenizer,
@@ -36,6 +35,7 @@
 )
 import pytest
 import torch
+from torch.utils.data import TensorDataset, DataLoader
 
 # Local
 # fms_mo imports
@@ -1163,25 +1163,27 @@ def model_bert_eager():
 size = (batch_size, max_position_embeddings)
 
 
-@pytest.fixture(scope="session")
-def input_tiny() -> BatchEncoding:
+@pytest.fixture(scope="function")
+def input_tiny() -> DataLoader:
     """
     Create a fake input for tiny models w/ fixed vocab_size and max_position_embeddings
 
     Returns:
-        BatchEncoding: Fake Encoding for a Tokenizer
+        DataLoader: Fake Encoding for a Tokenizer
     """
     # Random tokens and attention mask == 1
     random_tokens = torch.randint(low=0, high=vocab_size, size=size)
     attention_mask = torch.ones(size)
 
-    fake_tokenizer_output = BatchEncoding(
-        {
-            "input_ids": random_tokens,
-            "attention_mask": attention_mask,
-        }
+    dataset = TensorDataset(random_tokens, attention_mask)
+    # qmodel_prep expects dataloader batch=tuple(tensor, tensor)
+    # Without collate_fn, it returns batch=list(tensor, tensor)
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=lambda batch: tuple(torch.stack(samples) for samples in zip(*batch))
     )
-    return fake_tokenizer_output
 
 
 #############################
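
The custom collate_fn is the crux of the fix. A minimal standalone sketch (plain PyTorch only; the shapes are hypothetical stand-ins for the fixture's vocab_size and max_position_embeddings) showing the batch-type difference the in-diff comment describes:

import torch
from torch.utils.data import TensorDataset, DataLoader

tokens = torch.randint(0, 100, (8, 16))  # hypothetical (num_samples, seq_len)
masks = torch.ones((8, 16))
dataset = TensorDataset(tokens, masks)

# Default collation: TensorDataset yields tuple samples, and for tuples
# default_collate stacks each field but returns the batch as a list.
plain = DataLoader(dataset, batch_size=4)
print(type(next(iter(plain))))   # <class 'list'>

# The fixture's collate_fn stacks each field and keeps the batch a tuple,
# the form the comment above says qmodel_prep expects.
tupled = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=lambda batch: tuple(torch.stack(samples) for samples in zip(*batch)),
)
print(type(next(iter(tupled))))  # <class 'tuple'>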
@@ -1234,6 +1236,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
         "nbits_w": 8,
         "qa_mode": "pertokenmax",
         "qw_mode": "max",
+        "qmodel_calibration": 1,
         "smoothq": False,
         "smoothq_scale_layers": [],
         "qskip_layer_name": [
@@ -1250,6 +1253,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
         "nbits_w": 8,
         "qa_mode": "maxsym",
         "qw_mode": "maxperCh",
+        "qmodel_calibration": 1,
         "smoothq": False,
         "smoothq_scale_layers": [],
         "qskip_layer_name": [
@@ -1348,6 +1352,7 @@ def model_tiny_llama(config_tiny_llama: LlamaConfig) -> LlamaModel:
         "nbits_w": 8,
         "qa_mode": "pertokenmax",
         "qw_mode": "max",
+        "qmodel_calibration": 1,
         "smoothq": False,
         "smoothq_scale_layers": [],
         "qskip_layer_name": [
@@ -1447,6 +1452,7 @@ def model_tiny_granite(config_tiny_granite: GraniteConfig) -> GraniteModel:
         "nbits_w": 8,
         "qa_mode": "pertokenmax",
         "qw_mode": "maxperCh",
+        "qmodel_calibration": 1,
         "smoothq": False,
         "smoothq_scale_layers": ["k_proj", "v_proj", "gate_proj", "up_proj"],
         "qskip_layer_name": ["lm_head"],
