@@ -1369,19 +1369,8 @@ def input_tiny() -> DataLoader:
13691369]
13701370
13711371
1372- @pytest .fixture (scope = "session" , params = tiny_bert_config_params )
1373- def config_tiny_bert (request ) -> BertConfig :
1374- """
1375- Get a tiny Bert config
1376-
1377- Returns:
1378- BertConfig: Trimmed Tiny Bert config
1379- """
1380- return request .param
1381-
1382-
1383- @pytest .fixture (scope = "function" )
1384- def model_tiny_bert (config_tiny_bert : BertConfig ) -> BertModel :
1372+ @pytest .fixture (scope = "function" , params = tiny_bert_config_params )
1373+ def model_tiny_bert (request ) -> BertModel :
13851374 """
13861375 Get a tiny Llama Model based on the config
13871376
@@ -1391,7 +1380,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
13911380 Returns:
13921381 BertConfig: Tiny Bert model
13931382 """
1394- model = BertModel (config = config_tiny_bert )
1383+ model = BertModel (config = request . param )
13951384 return model
13961385
13971386
@@ -1485,19 +1474,8 @@ def bert_linear_names() -> list:
14851474]
14861475
14871476
1488- @pytest .fixture (scope = "session" , params = tiny_llama_config_params )
1489- def config_tiny_llama (request ) -> LlamaConfig :
1490- """
1491- Get a tiny Llama config
1492-
1493- Returns:
1494- LlamaConfig: Trimmed Tiny Llama config
1495- """
1496- return request .param
1497-
1498-
1499- @pytest .fixture (scope = "function" )
1500- def model_tiny_llama (config_tiny_llama : LlamaConfig ) -> LlamaModel :
1477+ @pytest .fixture (scope = "function" , params = tiny_llama_config_params )
1478+ def model_tiny_llama (request ) -> LlamaModel :
15011479 """
15021480 Get a tiny Llama Model based on the config
15031481
@@ -1507,7 +1485,7 @@ def model_tiny_llama(config_tiny_llama: LlamaConfig) -> LlamaModel:
15071485 Returns:
15081486 LlamaModel: Tiny Llama model
15091487 """
1510- model = LlamaModel (config = config_tiny_llama )
1488+ model = LlamaModel (config = request . param )
15111489 return model
15121490
15131491
@@ -1585,19 +1563,8 @@ def llama_linear_names() -> list:
15851563]
15861564
15871565
1588- @pytest .fixture (scope = "session" , params = tiny_granite_config_params )
1589- def config_tiny_granite (request ) -> GraniteConfig :
1590- """
1591- Get a tiny Granite config
1592-
1593- Returns:
1594- GraniteConfig: Tiny Granite config
1595- """
1596- return request .param
1597-
1598-
1599- @pytest .fixture (scope = "function" )
1600- def model_tiny_granite (config_tiny_granite : GraniteConfig ) -> GraniteModel :
1566+ @pytest .fixture (scope = "function" , params = tiny_granite_config_params )
1567+ def model_tiny_granite (request ) -> GraniteModel :
16011568 """
16021569 Get a tiny Granite Model based on the config
16031570
@@ -1607,7 +1574,7 @@ def model_tiny_granite(config_tiny_granite: GraniteConfig) -> GraniteModel:
16071574 Returns:
16081575 GraniteModel: Tiny Granite model
16091576 """
1610- model = GraniteModel (config = config_tiny_granite )
1577+ model = GraniteModel (config = request . param )
16111578 return model
16121579
16131580
0 commit comments