@@ -25,7 +25,6 @@
 from torchvision.io import read_image
 from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16
 from transformers import (
-    BatchEncoding,
     BertConfig,
     BertModel,
     BertTokenizer,
@@ -36,6 +35,7 @@
 )
 import pytest
 import torch
+from torch.utils.data import TensorDataset, DataLoader

 # Local
 # fms_mo imports
@@ -1163,25 +1163,27 @@ def model_bert_eager():
 size = (batch_size, max_position_embeddings)


-@pytest.fixture(scope="session")
-def input_tiny() -> BatchEncoding:
+@pytest.fixture(scope="function")
+def input_tiny() -> DataLoader:
     """
     Create a fake input for tiny models w/ fixed vocab_size and max_position_embeddings

     Returns:
-        BatchEncoding: Fake Encoding for a Tokenizer
+        DataLoader: DataLoader of fake tokenizer encodings
     """
     # Random tokens and attention mask == 1
     random_tokens = torch.randint(low=0, high=vocab_size, size=size)
     attention_mask = torch.ones(size)

-    fake_tokenizer_output = BatchEncoding(
-        {
-            "input_ids": random_tokens,
-            "attention_mask": attention_mask,
-        }
+    dataset = TensorDataset(random_tokens, attention_mask)
+    # qmodel_prep expects each dataloader batch to be a tuple(tensor, tensor);
+    # without a collate_fn, the default collation returns a list [tensor, tensor].
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=lambda batch: tuple(torch.stack(samples) for samples in zip(*batch))
     )
-    return fake_tokenizer_output


 #############################
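For reference, here is a minimal standalone sketch of what the new fixture's collate_fn yields per batch. The sizes below are made-up placeholders; the real fixture reads vocab_size, batch_size, and size from the tiny-model constants defined earlier in this conftest:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative stand-ins for the conftest constants (hypothetical values).
vocab_size, batch_size, max_position_embeddings = 32, 2, 8
size = (batch_size, max_position_embeddings)

random_tokens = torch.randint(low=0, high=vocab_size, size=size)
attention_mask = torch.ones(size)
dataset = TensorDataset(random_tokens, attention_mask)

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    # Re-stack per-sample tensors so every batch is a tuple(tensor, tensor),
    # instead of the list the default collation would return.
    collate_fn=lambda batch: tuple(torch.stack(samples) for samples in zip(*batch)),
)

batch = next(iter(loader))
assert isinstance(batch, tuple) and len(batch) == 2
input_ids, attn_mask = batch  # each of shape (batch_size, max_position_embeddings)
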
@@ -1234,6 +1236,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
         "nbits_w": 8,
         "qa_mode": "pertokenmax",
         "qw_mode": "max",
+        "qmodel_calibration": 1,
         "smoothq": False,
         "smoothq_scale_layers": [],
         "qskip_layer_name": [
@@ -1250,6 +1253,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
         "nbits_w": 8,
         "qa_mode": "maxsym",
         "qw_mode": "maxperCh",
+        "qmodel_calibration": 1,
         "smoothq": False,
         "smoothq_scale_layers": [],
         "qskip_layer_name": [
@@ -1348,6 +1352,7 @@ def model_tiny_llama(config_tiny_llama: LlamaConfig) -> LlamaModel:
         "nbits_w": 8,
         "qa_mode": "pertokenmax",
         "qw_mode": "max",
+        "qmodel_calibration": 1,
         "smoothq": False,
         "smoothq_scale_layers": [],
         "qskip_layer_name": [
@@ -1447,6 +1452,7 @@ def model_tiny_granite(config_tiny_granite: GraniteConfig) -> GraniteModel:
         "nbits_w": 8,
         "qa_mode": "pertokenmax",
         "qw_mode": "maxperCh",
+        "qmodel_calibration": 1,
         "smoothq": False,
         "smoothq_scale_layers": ["k_proj", "v_proj", "gate_proj", "up_proj"],
         "qskip_layer_name": ["lm_head"],