@@ -1204,19 +1204,8 @@ def input_tiny() -> DataLoader:
12041204]
12051205
12061206
1207- @pytest .fixture (scope = "session" , params = tiny_bert_config_params )
1208- def config_tiny_bert (request ) -> BertConfig :
1209- """
1210- Get a tiny Bert config
1211-
1212- Returns:
1213- BertConfig: Trimmed Tiny Bert config
1214- """
1215- return request .param
1216-
1217-
1218- @pytest .fixture (scope = "function" )
1219- def model_tiny_bert (config_tiny_bert : BertConfig ) -> BertModel :
1207+ @pytest .fixture (scope = "function" , params = tiny_bert_config_params )
1208+ def model_tiny_bert (request ) -> BertModel :
12201209 """
12211210 Get a tiny Llama Model based on the config
12221211
@@ -1226,7 +1215,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
12261215 Returns:
12271216 BertConfig: Tiny Bert model
12281217 """
1229- model = BertModel (config = config_tiny_bert )
1218+ model = BertModel (config = request . param )
12301219 return model
12311220
12321221
@@ -1320,19 +1309,8 @@ def bert_linear_names() -> list:
13201309]
13211310
13221311
1323- @pytest .fixture (scope = "session" , params = tiny_llama_config_params )
1324- def config_tiny_llama (request ) -> LlamaConfig :
1325- """
1326- Get a tiny Llama config
1327-
1328- Returns:
1329- LlamaConfig: Trimmed Tiny Llama config
1330- """
1331- return request .param
1332-
1333-
1334- @pytest .fixture (scope = "function" )
1335- def model_tiny_llama (config_tiny_llama : LlamaConfig ) -> LlamaModel :
1312+ @pytest .fixture (scope = "function" , params = tiny_llama_config_params )
1313+ def model_tiny_llama (request ) -> LlamaModel :
13361314 """
13371315 Get a tiny Llama Model based on the config
13381316
@@ -1342,7 +1320,7 @@ def model_tiny_llama(config_tiny_llama: LlamaConfig) -> LlamaModel:
13421320 Returns:
13431321 LlamaModel: Tiny Llama model
13441322 """
1345- model = LlamaModel (config = config_tiny_llama )
1323+ model = LlamaModel (config = request . param )
13461324 return model
13471325
13481326
@@ -1420,19 +1398,8 @@ def llama_linear_names() -> list:
14201398]
14211399
14221400
1423- @pytest .fixture (scope = "session" , params = tiny_granite_config_params )
1424- def config_tiny_granite (request ) -> GraniteConfig :
1425- """
1426- Get a tiny Granite config
1427-
1428- Returns:
1429- GraniteConfig: Tiny Granite config
1430- """
1431- return request .param
1432-
1433-
1434- @pytest .fixture (scope = "function" )
1435- def model_tiny_granite (config_tiny_granite : GraniteConfig ) -> GraniteModel :
1401+ @pytest .fixture (scope = "function" , params = tiny_granite_config_params )
1402+ def model_tiny_granite (request ) -> GraniteModel :
14361403 """
14371404 Get a tiny Granite Model based on the config
14381405
@@ -1442,7 +1409,7 @@ def model_tiny_granite(config_tiny_granite: GraniteConfig) -> GraniteModel:
14421409 Returns:
14431410 GraniteModel: Tiny Granite model
14441411 """
1445- model = GraniteModel (config = config_tiny_granite )
1412+ model = GraniteModel (config = request . param )
14461413 return model
14471414
14481415
0 commit comments