1+ import pytest
2+ from transformers import PretrainedConfig , PreTrainedModel
3+ from ltsm .models import register_model , get_model , model_dict
4+
5+ def test_register_model (mocker ):
6+ mock_model = mocker .MagicMock (spec = PreTrainedModel )
7+ register_model (mock_model , "MockModel1" )
8+ assert "MockModel1" in model_dict
9+ assert model_dict ["MockModel1" ] == mock_model
10+
11+ with pytest .raises (AssertionError , match = "Reader MockModel1 already registered" ):
12+ register_model (mock_model , "MockModel1" )
13+
14+ def test_get_model (mocker ):
15+ mock_model = mocker .MagicMock (spec = PreTrainedModel )
16+ mock_config = mocker .MagicMock (spec = PretrainedConfig )
17+ register_model (mock_model , "MockModel2" )
18+
19+ instance = get_model (mock_config , "MockModel2" )
20+ mock_model .assert_called_once_with (mock_config )
21+ assert isinstance (instance , mocker .MagicMock )
22+
23+ def test_get_model_invalid_name ():
24+ with pytest .raises (ValueError , match = "Model NonExistentModel is not registered" ):
25+ get_model (PretrainedConfig (), "NonExistentModel" )
26+
27+ def test_get_model_local_pretrain (mocker ):
28+ mock_from_pretrained = mocker .patch ("transformers.PretrainedConfig.from_pretrained" )
29+ mock_model = mocker .MagicMock (spec = PreTrainedModel )
30+ register_model (mock_model , "MockModel3" )
31+
32+ mock_from_pretrained .return_value = mocker .MagicMock ()
33+ instance = get_model (PretrainedConfig (), "MockModel3" , local_pretrain = "path/to/pretrained" )
34+ mock_model .from_pretrained .assert_called_once_with ("path/to/pretrained" , mock_from_pretrained .return_value )
35+ assert isinstance (instance , mocker .MagicMock )
36+
37+ def test_get_model_hf_hub (mocker ):
38+ mock_from_pretrained = mocker .patch ("transformers.PreTrainedModel.from_pretrained" )
39+ mock_model = mocker .MagicMock (spec = PreTrainedModel )
40+ register_model (mock_model , "MockModel4" )
41+
42+ instance = get_model (PretrainedConfig (), "MockModel4" , hf_hub_model = "mock-hub-model" )
43+ mock_model .from_pretrained .assert_called_once_with ("mock-hub-model" , PretrainedConfig ())
44+ assert isinstance (instance , mocker .MagicMock )
0 commit comments