2424# Third Party
2525from torchvision .io import read_image
2626from torchvision .models import ResNet50_Weights , ViT_B_16_Weights , resnet50 , vit_b_16
27- from transformers import BertModel , BertTokenizer
27+ from transformers import (
28+ BertConfig ,
29+ BertModel ,
30+ BertTokenizer ,
31+ LlamaConfig ,
32+ LlamaModel ,
33+ GraniteConfig ,
34+ GraniteModel ,
35+ )
2836import numpy as np
2937import pytest
3038import torch
@@ -1299,3 +1307,130 @@ def model_residualMLP():
12991307 torch.nn.Module: _description_
13001308 """
13011309 return ResidualMLP (128 )
1310+
1311+
1312+ #############################
1313+ # Tiny BERT Model Fixtures #
1314+ #############################
1315+
# Parameter list for the ``config_tiny_bert`` fixture: one heavily trimmed
# BertConfig. The trailing number on each line is the value noted by the
# original author (presumably the full-size BERT setting - confirm).
tiny_bert_config_params = [
    BertConfig(
        # Vocabulary / embedding tables
        vocab_size=512,  # 30522
        type_vocab_size=1,  # 2
        max_position_embeddings=512,  # 512
        # Encoder stack
        num_hidden_layers=2,  # 12
        num_attention_heads=2,  # 12
        hidden_size=128,  # 768
        intermediate_size=512,  # 3072
    ),
]
1327+
1328+
@pytest.fixture(scope="session", params=tiny_bert_config_params)
def config_tiny_bert(request):
    """
    Provide a tiny Bert config, one per entry of ``tiny_bert_config_params``

    Args:
        request: pytest request object carrying the current parameter

    Returns:
        BertConfig: Trimmed Tiny Bert config
    """
    tiny_config = request.param
    return tiny_config
1338+
1339+
@pytest.fixture(scope="function")
def model_tiny_bert(config_tiny_bert):
    """
    Get a tiny Bert Model based on the config

    The model is constructed directly from the config (no pretrained
    checkpoint is loaded here), and a new one is built for every test
    function that requests this fixture.

    Args:
        config_tiny_bert (BertConfig): Trimmed Tiny Bert config

    Returns:
        BertModel: Tiny Bert model
    """
    # NOTE(review): the deepcopy is kept from the original; presumably it
    # detaches the model from the session-scoped config object so a test
    # cannot mutate state shared with other tests - confirm before removing.
    model = deepcopy(BertModel(config_tiny_bert))
    return model
1353+
1354+
1355+ #############################
1356+ # Tiny Llama Model Fixtures #
1357+ #############################
1358+
# Parameter list for the ``config_tiny_llama`` fixture: one heavily trimmed
# LlamaConfig. The trailing number on each line is the value noted by the
# original author (presumably the full-size Llama setting - confirm).
tiny_llama_config_params = [
    LlamaConfig(
        # Vocabulary / context length
        vocab_size=1024,  # 32000
        max_position_embeddings=256,  # 2048
        # Decoder stack
        num_hidden_layers=2,  # 32
        num_attention_heads=2,  # 32
        hidden_size=128,  # 4096
        intermediate_size=256,  # 11008
    ),
]
1369+
1370+
@pytest.fixture(scope="session", params=tiny_llama_config_params)
def config_tiny_llama(request):
    """
    Provide a tiny Llama config, one per entry of ``tiny_llama_config_params``

    Args:
        request: pytest request object carrying the current parameter

    Returns:
        LlamaConfig: Trimmed Tiny Llama config
    """
    tiny_config = request.param
    return tiny_config
1380+
1381+
@pytest.fixture(scope="function")
def model_tiny_llama(config_tiny_llama):
    """
    Build a tiny Llama Model from the given config

    A new model is constructed for every test function that requests this
    fixture; no pretrained checkpoint is loaded here.

    Args:
        config_tiny_llama (LlamaConfig): Trimmed Tiny Llama config

    Returns:
        LlamaModel: Tiny Llama model
    """
    fresh_model = LlamaModel(config_tiny_llama)
    # Hand back a deep copy, exactly as the original fixture did.
    return deepcopy(fresh_model)
1395+
1396+
1397+ ###############################
1398+ # Tiny Granite Model Fixtures #
1399+ ###############################
1400+
# Parameter list for the ``config_tiny_granite`` fixture: one heavily trimmed
# GraniteConfig. NOTE(review): the trailing numbers are identical to the ones
# on the Llama config in this file; verify they are the intended Granite
# reference values.
tiny_granite_config_params = [
    GraniteConfig(
        # Vocabulary / context length
        vocab_size=1024,  # 32000
        max_position_embeddings=256,  # 2048
        # Decoder stack
        num_hidden_layers=2,  # 32
        num_attention_heads=2,  # 32
        hidden_size=128,  # 4096
        intermediate_size=256,  # 11008
    ),
]
1411+
1412+
@pytest.fixture(scope="session", params=tiny_granite_config_params)
def config_tiny_granite(request):
    """
    Provide a tiny Granite config, one per entry of
    ``tiny_granite_config_params``

    Args:
        request: pytest request object carrying the current parameter

    Returns:
        GraniteConfig: Tiny Granite config
    """
    tiny_config = request.param
    return tiny_config
1422+
1423+
@pytest.fixture(scope="function")
def model_tiny_granite(config_tiny_granite):
    """
    Build a tiny Granite Model from the given config

    A new model is constructed for every test function that requests this
    fixture; no pretrained checkpoint is loaded here.

    Args:
        config_tiny_granite (GraniteConfig): Trimmed Tiny Granite config

    Returns:
        GraniteModel: Tiny Granite model
    """
    fresh_model = GraniteModel(config_tiny_granite)
    # Hand back a deep copy, exactly as the original fixture did.
    return deepcopy(fresh_model)
0 commit comments