
Commit 83ac2ea

fix: Changed input_tiny to dataloader to work with qmodel_prep
Signed-off-by: Brandon Groth <[email protected]>
1 parent d36e61f commit 83ac2ea

File tree: 1 file changed (+16 −10 lines)


tests/models/conftest.py

Lines changed: 16 additions & 10 deletions
@@ -25,7 +25,6 @@
 from torchvision.io import read_image
 from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16
 from transformers import (
-    BatchEncoding,
     BertConfig,
     BertModel,
     BertTokenizer,
@@ -38,6 +37,7 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from torch.utils.data import TensorDataset, DataLoader

 # Local
 # fms_mo imports
@@ -1327,25 +1327,27 @@ def model_residualMLP():
     size = (batch_size, max_position_embeddings)


-@pytest.fixture(scope="session")
-def input_tiny() -> BatchEncoding:
+@pytest.fixture(scope="function")
+def input_tiny() -> DataLoader:
     """
     Create a fake input for tiny models w/ fixed vocab_size and max_position_embeddings

     Returns:
-        BatchEncoding: Fake Encoding for a Tokenizer
+        DataLoader: Fake Encoding for a Tokenizer
     """
     # Random tokens and attention mask == 1
     random_tokens = torch.randint(low=0, high=vocab_size, size=size)
     attention_mask = torch.ones(size)

-    fake_tokenizer_output = BatchEncoding(
-        {
-            "input_ids": random_tokens,
-            "attention_mask": attention_mask,
-        }
+    dataset = TensorDataset(random_tokens, attention_mask)
+    # qmodel_prep expects dataloader batch=tuple(tensor, tensor)
+    # Without collate_fn, it returns batch=list(tensor, tensor)
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=lambda batch: tuple(torch.stack(samples) for samples in zip(*batch))
     )
-    return fake_tokenizer_output


 #############################
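
The custom collate_fn is the crux of this hunk: TensorDataset yields a (input_ids, attention_mask) tuple per sample, but PyTorch's default collate returns each batch of tuples as a list of stacked tensors, not a tuple. A minimal, standalone sketch of the difference; the sizes and batch size below are illustrative only, not taken from the fixture:

import torch
from torch.utils.data import TensorDataset, DataLoader

# Illustrative sizes; the real fixture uses vocab_size and
# max_position_embeddings defined earlier in conftest.py.
tokens = torch.randint(low=0, high=100, size=(8, 16))
mask = torch.ones((8, 16))
dataset = TensorDataset(tokens, mask)

# Default collate: a batch of (tensor, tensor) tuples comes back as a list.
default_loader = DataLoader(dataset, batch_size=4)
print(type(next(iter(default_loader))))  # <class 'list'>

# Custom collate: transpose the per-sample tuples, stack each field,
# and return a tuple, the shape the fixture's consumer expects.
tuple_loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=lambda batch: tuple(torch.stack(s) for s in zip(*batch)),
)
print(type(next(iter(tuple_loader))))  # <class 'tuple'>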
@@ -1398,6 +1400,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
     "nbits_w": 8,
     "qa_mode": "pertokenmax",
     "qw_mode": "max",
+    "qmodel_calibration": 1,
     "smoothq": False,
     "smoothq_scale_layers": [],
     "qskip_layer_name": [
@@ -1414,6 +1417,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
     "nbits_w": 8,
     "qa_mode": "maxsym",
     "qw_mode": "maxperCh",
+    "qmodel_calibration": 1,
     "smoothq": False,
     "smoothq_scale_layers": [],
     "qskip_layer_name": [
@@ -1512,6 +1516,7 @@ def model_tiny_llama(config_tiny_llama: LlamaConfig) -> LlamaModel:
     "nbits_w": 8,
     "qa_mode": "pertokenmax",
     "qw_mode": "max",
+    "qmodel_calibration": 1,
     "smoothq": False,
     "smoothq_scale_layers": [],
     "qskip_layer_name": [
@@ -1611,6 +1616,7 @@ def model_tiny_granite(config_tiny_granite: GraniteConfig) -> GraniteModel:
     "nbits_w": 8,
     "qa_mode": "pertokenmax",
     "qw_mode": "maxperCh",
+    "qmodel_calibration": 1,
     "smoothq": False,
     "smoothq_scale_layers": ["k_proj", "v_proj", "gate_proj", "up_proj"],
     "qskip_layer_name": ["lm_head"],
