Skip to content

Commit 6fd2495

Browse files
committed
feat: Added tiny bert, llama, and granite models fixtures
Signed-off-by: Brandon Groth <[email protected]>
1 parent 05bb442 commit 6fd2495

File tree

1 file changed

+136
-1
lines changed

1 file changed

+136
-1
lines changed

tests/models/conftest.py

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,15 @@
2424
# Third Party
2525
from torchvision.io import read_image
2626
from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16
27-
from transformers import BertModel, BertTokenizer
27+
from transformers import (
28+
BertConfig,
29+
BertModel,
30+
BertTokenizer,
31+
LlamaConfig,
32+
LlamaModel,
33+
GraniteConfig,
34+
GraniteModel,
35+
)
2836
import numpy as np
2937
import pytest
3038
import torch
@@ -1299,3 +1307,130 @@ def model_residualMLP():
12991307
torch.nn.Module: _description_
13001308
"""
13011309
return ResidualMLP(128)
1310+
1311+
1312+
#############################
1313+
# Tiny BERT Model Fixtures #
1314+
#############################
1315+
1316+
tiny_bert_config_params = [
1317+
BertConfig(
1318+
vocab_size=512, # 30522
1319+
hidden_size=128, # 768
1320+
num_hidden_layers=2, # 12
1321+
num_attention_heads=2,# 12
1322+
intermediate_size=512, # 3072
1323+
max_position_embeddings=512, # 512
1324+
type_vocab_size=1, # 2
1325+
),
1326+
]
1327+
1328+
1329+
@pytest.fixture(scope="session", params=tiny_bert_config_params)
1330+
def config_tiny_bert(request):
1331+
"""
1332+
Get a tiny Bert config
1333+
1334+
Returns:
1335+
BertConfig: Trimmed Tiny Bert config
1336+
"""
1337+
return request.param
1338+
1339+
1340+
@pytest.fixture(scope="function")
1341+
def model_tiny_bert(config_tiny_bert):
1342+
"""
1343+
Get a tiny Bert Model based on the config
1344+
1345+
Args:
1346+
config_tiny_bert (BertConfig): Trimmed Tiny Bert config
1347+
1348+
Returns:
1349+
BertModel: Tiny Bert model
1350+
"""
1351+
model = deepcopy(BertModel(config_tiny_bert))
1352+
return model
1353+
1354+
1355+
#############################
1356+
# Tiny Llama Model Fixtures #
1357+
#############################
1358+
1359+
tiny_llama_config_params = [
1360+
LlamaConfig(
1361+
vocab_size=1024, # 32000
1362+
hidden_size=128, # 4096
1363+
intermediate_size=256, # 11008
1364+
num_hidden_layers=2, # 32
1365+
num_attention_heads=2,# 32
1366+
max_position_embeddings=256, # 2048
1367+
),
1368+
]
1369+
1370+
1371+
@pytest.fixture(scope="session", params=tiny_llama_config_params)
1372+
def config_tiny_llama(request):
1373+
"""
1374+
Get a tiny Llama config
1375+
1376+
Returns:
1377+
LlamaConfig: Trimmed Tiny Llama config
1378+
"""
1379+
return request.param
1380+
1381+
1382+
@pytest.fixture(scope="function")
1383+
def model_tiny_llama(config_tiny_llama):
1384+
"""
1385+
Get a tiny Llama Model based on the config
1386+
1387+
Args:
1388+
config_tiny_llama (LlamaConfig): Trimmed Tiny Llama config
1389+
1390+
Returns:
1391+
LlamaModel: Tiny Llama model
1392+
"""
1393+
model = deepcopy(LlamaModel(config_tiny_llama))
1394+
return model
1395+
1396+
1397+
###############################
1398+
# Tiny Granite Model Fixtures #
1399+
###############################
1400+
1401+
tiny_granite_config_params = [
1402+
GraniteConfig(
1403+
vocab_size=1024, # 32000
1404+
hidden_size=128, # 4096
1405+
intermediate_size=256, # 11008
1406+
num_hidden_layers=2, # 32
1407+
num_attention_heads=2,# 32
1408+
max_position_embeddings=256, # 2048
1409+
),
1410+
]
1411+
1412+
1413+
@pytest.fixture(scope="session", params=tiny_granite_config_params)
1414+
def config_tiny_granite(request):
1415+
"""
1416+
Get a tiny Granite config
1417+
1418+
Returns:
1419+
GraniteConfig: Tiny Granite config
1420+
"""
1421+
return request.param
1422+
1423+
1424+
@pytest.fixture(scope="function")
1425+
def model_tiny_granite(config_tiny_granite):
1426+
"""
1427+
Get a tiny Granite Model based on the config
1428+
1429+
Args:
1430+
config_tiny_granite (GraniteConfig): Trimmed Tiny Granite config
1431+
1432+
Returns:
1433+
GraniteModel: Tiny Granite model
1434+
"""
1435+
model = deepcopy(GraniteModel(config_tiny_granite))
1436+
return model

0 commit comments

Comments
 (0)