@@ -177,6 +177,26 @@ def _save_pretrained(cls, save_directory: Path):
177177 return
178178
179179
180+ @dataclass
181+ class DummyDataclass :
182+ foo : int
183+ bar : str
184+
185+
186+ class DummyWithDataclassInputs (ModelHubMixin ):
187+ def __init__ (self , arg1 : DummyDataclass , arg2 : DummyDataclass ):
188+ self .arg1 = arg1
189+ self .arg2 = arg2
190+
191+ @classmethod
192+ def _from_pretrained (cls , ** kwargs ):
193+ return cls (arg1 = kwargs ["arg1" ], arg2 = kwargs ["arg2" ])
194+
195+ @classmethod
196+ def _save_pretrained (cls , save_directory : Path ):
197+ return
198+
199+
180200@pytest .mark .usefixtures ("fx_cache_dir" )
181201class HubMixinTest (unittest .TestCase ):
182202 cache_dir : Path
@@ -501,3 +521,22 @@ def method_with_hints(self, x: int) -> str:
501521 # Test method type hints on instance
502522 model = ModelWithHints ()
503523 assert get_type_hints (model .method_with_hints ) == {"x" : int , "return" : str }
524+
525+ def test_with_dataclass_inputs (self ):
526+ model = DummyWithDataclassInputs (
527+ arg1 = DummyDataclass (foo = 1 , bar = "1" ),
528+ arg2 = DummyDataclass (foo = 2 , bar = "2" ),
529+ )
530+ model .save_pretrained (self .cache_dir )
531+
532+ config = json .loads ((self .cache_dir / "config.json" ).read_text ())
533+ assert config == {
534+ "arg1" : {"foo" : 1 , "bar" : "1" },
535+ "arg2" : {"foo" : 2 , "bar" : "2" },
536+ }
537+
538+ model_reloaded = DummyWithDataclassInputs .from_pretrained (self .cache_dir )
539+ assert model_reloaded .arg1 .foo == 1
540+ assert model_reloaded .arg1 .bar == "1"
541+ assert model_reloaded .arg2 .foo == 2
542+ assert model_reloaded .arg2 .bar == "2"
0 commit comments