from torchvision.io import read_image
from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16
from transformers import (
-    BatchEncoding,
    BertConfig,
    BertModel,
    BertTokenizer,
import pytest
import torch
import torch.nn.functional as F
+from torch.utils.data import TensorDataset, DataLoader

# Local
# fms_mo imports
@@ -1327,25 +1327,27 @@ def model_residualMLP():
size = (batch_size, max_position_embeddings)


-@pytest.fixture(scope="session")
-def input_tiny() -> BatchEncoding:
+@pytest.fixture(scope="function")
+def input_tiny() -> DataLoader:
    """
    Create a fake input for tiny models w/ fixed vocab_size and max_position_embeddings

    Returns:
-        BatchEncoding: Fake Encoding for a Tokenizer
+        DataLoader: DataLoader that yields fake tokenizer batches (input_ids, attention_mask)
    """
    # Random tokens and attention mask == 1
    random_tokens = torch.randint(low=0, high=vocab_size, size=size)
    attention_mask = torch.ones(size)

-    fake_tokenizer_output = BatchEncoding(
-        {
-            "input_ids": random_tokens,
-            "attention_mask": attention_mask,
-        }
+    dataset = TensorDataset(random_tokens, attention_mask)
+    # qmodel_prep expects each dataloader batch to be a tuple(tensor, tensor);
+    # without this collate_fn, default collation would yield a list [tensor, tensor]
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=lambda batch: tuple(torch.stack(samples) for samples in zip(*batch)),
    )
-    return fake_tokenizer_output


#############################
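For reference, a minimal sketch of the batch structure the reworked fixture produces (the helper name and assertions are illustrative, not part of the diff; batch_size and max_position_embeddings are the conftest globals used above):

# Illustrative only: with the custom collate_fn, each batch is a
# tuple(input_ids, attention_mask); the default collate would return a
# list instead, which qmodel_prep does not accept per the comment above.
def _describe_input_tiny(dataloader):
    batch = next(iter(dataloader))
    assert isinstance(batch, tuple)
    input_ids, attention_mask = batch
    # both tensors have shape (batch_size, max_position_embeddings)
    assert input_ids.shape == attention_mask.shape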
@@ -1398,6 +1400,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
13981400 "nbits_w" : 8 ,
13991401 "qa_mode" : "pertokenmax" ,
14001402 "qw_mode" : "max" ,
1403+ "qmodel_calibration" : 1 ,
14011404 "smoothq" : False ,
14021405 "smoothq_scale_layers" : [],
14031406 "qskip_layer_name" : [
@@ -1414,6 +1417,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
14141417 "nbits_w" : 8 ,
14151418 "qa_mode" : "maxsym" ,
14161419 "qw_mode" : "maxperCh" ,
1420+ "qmodel_calibration" : 1 ,
14171421 "smoothq" : False ,
14181422 "smoothq_scale_layers" : [],
14191423 "qskip_layer_name" : [
@@ -1512,6 +1516,7 @@ def model_tiny_llama(config_tiny_llama: LlamaConfig) -> LlamaModel:
15121516 "nbits_w" : 8 ,
15131517 "qa_mode" : "pertokenmax" ,
15141518 "qw_mode" : "max" ,
1519+ "qmodel_calibration" : 1 ,
15151520 "smoothq" : False ,
15161521 "smoothq_scale_layers" : [],
15171522 "qskip_layer_name" : [
@@ -1611,6 +1616,7 @@ def model_tiny_granite(config_tiny_granite: GraniteConfig) -> GraniteModel:
16111616 "nbits_w" : 8 ,
16121617 "qa_mode" : "pertokenmax" ,
16131618 "qw_mode" : "maxperCh" ,
1619+ "qmodel_calibration" : 1 ,
16141620 "smoothq" : False ,
16151621 "smoothq_scale_layers" : ["k_proj" , "v_proj" , "gate_proj" , "up_proj" ],
16161622 "qskip_layer_name" : ["lm_head" ],