Skip to content

Commit 624718d

Browse files
committed
fix: Removed tiny model configs to params of model fixture
Signed-off-by: Brandon Groth <[email protected]>
1 parent 59e1f37 commit 624718d

File tree

1 file changed

+9
-42
lines changed

1 file changed

+9
-42
lines changed

tests/models/conftest.py

Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,19 +1369,8 @@ def input_tiny() -> DataLoader:
13691369
]
13701370

13711371

1372-
@pytest.fixture(scope="session", params=tiny_bert_config_params)
1373-
def config_tiny_bert(request) -> BertConfig:
1374-
"""
1375-
Get a tiny Bert config
1376-
1377-
Returns:
1378-
BertConfig: Trimmed Tiny Bert config
1379-
"""
1380-
return request.param
1381-
1382-
1383-
@pytest.fixture(scope="function")
1384-
def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
1372+
@pytest.fixture(scope="function", params=tiny_bert_config_params)
1373+
def model_tiny_bert(request) -> BertModel:
13851374
"""
13861375
Get a tiny Llama Model based on the config
13871376
@@ -1391,7 +1380,7 @@ def model_tiny_bert(config_tiny_bert: BertConfig) -> BertModel:
13911380
Returns:
13921381
BertConfig: Tiny Bert model
13931382
"""
1394-
model = BertModel(config=config_tiny_bert)
1383+
model = BertModel(config=request.param)
13951384
return model
13961385

13971386

@@ -1485,19 +1474,8 @@ def bert_linear_names() -> list:
14851474
]
14861475

14871476

1488-
@pytest.fixture(scope="session", params=tiny_llama_config_params)
1489-
def config_tiny_llama(request) -> LlamaConfig:
1490-
"""
1491-
Get a tiny Llama config
1492-
1493-
Returns:
1494-
LlamaConfig: Trimmed Tiny Llama config
1495-
"""
1496-
return request.param
1497-
1498-
1499-
@pytest.fixture(scope="function")
1500-
def model_tiny_llama(config_tiny_llama: LlamaConfig) -> LlamaModel:
1477+
@pytest.fixture(scope="function", params=tiny_llama_config_params)
1478+
def model_tiny_llama(request) -> LlamaModel:
15011479
"""
15021480
Get a tiny Llama Model based on the config
15031481
@@ -1507,7 +1485,7 @@ def model_tiny_llama(config_tiny_llama: LlamaConfig) -> LlamaModel:
15071485
Returns:
15081486
LlamaModel: Tiny Llama model
15091487
"""
1510-
model = LlamaModel(config=config_tiny_llama)
1488+
model = LlamaModel(config=request.param)
15111489
return model
15121490

15131491

@@ -1585,19 +1563,8 @@ def llama_linear_names() -> list:
15851563
]
15861564

15871565

1588-
@pytest.fixture(scope="session", params=tiny_granite_config_params)
1589-
def config_tiny_granite(request) -> GraniteConfig:
1590-
"""
1591-
Get a tiny Granite config
1592-
1593-
Returns:
1594-
GraniteConfig: Tiny Granite config
1595-
"""
1596-
return request.param
1597-
1598-
1599-
@pytest.fixture(scope="function")
1600-
def model_tiny_granite(config_tiny_granite: GraniteConfig) -> GraniteModel:
1566+
@pytest.fixture(scope="function", params=tiny_granite_config_params)
1567+
def model_tiny_granite(request) -> GraniteModel:
16011568
"""
16021569
Get a tiny Granite Model based on the config
16031570
@@ -1607,7 +1574,7 @@ def model_tiny_granite(config_tiny_granite: GraniteConfig) -> GraniteModel:
16071574
Returns:
16081575
GraniteModel: Tiny Granite model
16091576
"""
1610-
model = GraniteModel(config=config_tiny_granite)
1577+
model = GraniteModel(config=request.param)
16111578
return model
16121579

16131580

0 commit comments

Comments
 (0)