Skip to content

Commit 9a6e7de

Browse files
committed
fix: Removed tiny model configs to params of model fixture
Signed-off-by: Brandon Groth <[email protected]>
1 parent 307027d commit 9a6e7de

File tree

1 file changed

+9
-42
lines changed

1 file changed

+9
-42
lines changed

tests/models/conftest.py

Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,19 +1204,8 @@ def input_tiny() -> DataLoader:
12041204
]
12051205

12061206

1207-
@pytest.fixture(scope="session", params=tiny_bert_config_params)
1208-
def config_tiny_bert(request) -> BertConfig:
1209-
"""
1210-
Get a tiny Bert config
1211-
1212-
Returns:
1213-
BertConfig: Trimmed Tiny Bert config
1214-
"""
1215-
return request.param
1216-
1217-
1218-
@pytest.fixture(scope="function")
1219-
def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
1207+
@pytest.fixture(scope="function", params=tiny_bert_config_params)
1208+
def model_tiny_bert(request) -> BertModel:
12201209
"""
12211210
Get a tiny Llama Model based on the config
12221211
@@ -1226,7 +1215,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
12261215
Returns:
12271216
BertConfig: Tiny Bert model
12281217
"""
1229-
model = BertModel(config=config_tiny_bert)
1218+
model = BertModel(config=request.param)
12301219
return model
12311220

12321221

@@ -1320,19 +1309,8 @@ def bert_linear_names() -> list:
13201309
]
13211310

13221311

1323-
@pytest.fixture(scope="session", params=tiny_llama_config_params)
1324-
def config_tiny_llama(request) -> LlamaConfig:
1325-
"""
1326-
Get a tiny Llama config
1327-
1328-
Returns:
1329-
LlamaConfig: Trimmed Tiny Llama config
1330-
"""
1331-
return request.param
1332-
1333-
1334-
@pytest.fixture(scope="function")
1335-
def model_tiny_llama(config_tiny_llama: LlamaConfig) -> LlamaModel:
1312+
@pytest.fixture(scope="function", params=tiny_llama_config_params)
1313+
def model_tiny_llama(request) -> LlamaModel:
13361314
"""
13371315
Get a tiny Llama Model based on the config
13381316
@@ -1342,7 +1320,7 @@ def model_tiny_llama(config_tiny_llama: LlamaConfig) -> LlamaModel:
13421320
Returns:
13431321
LlamaModel: Tiny Llama model
13441322
"""
1345-
model = LlamaModel(config=config_tiny_llama)
1323+
model = LlamaModel(config=request.param)
13461324
return model
13471325

13481326

@@ -1420,19 +1398,8 @@ def llama_linear_names() -> list:
14201398
]
14211399

14221400

1423-
@pytest.fixture(scope="session", params=tiny_granite_config_params)
1424-
def config_tiny_granite(request) -> GraniteConfig:
1425-
"""
1426-
Get a tiny Granite config
1427-
1428-
Returns:
1429-
GraniteConfig: Tiny Granite config
1430-
"""
1431-
return request.param
1432-
1433-
1434-
@pytest.fixture(scope="function")
1435-
def model_tiny_granite(config_tiny_granite: GraniteConfig) -> GraniteModel:
1401+
@pytest.fixture(scope="function", params=tiny_granite_config_params)
1402+
def model_tiny_granite(request) -> GraniteModel:
14361403
"""
14371404
Get a tiny Granite Model based on the config
14381405
@@ -1442,7 +1409,7 @@ def model_tiny_granite(config_tiny_granite: GraniteConfig) -> GraniteModel:
14421409
Returns:
14431410
GraniteModel: Tiny Granite model
14441411
"""
1445-
model = GraniteModel(config=config_tiny_granite)
1412+
model = GraniteModel(config=request.param)
14461413
return model
14471414

14481415

0 commit comments

Comments
 (0)