|
3 | 3 | import struct |
4 | 4 | import unittest |
5 | 5 | from pathlib import Path |
6 | | -from typing import Any, TypeVar |
| 6 | +from typing import Any, Dict, Optional, TypeVar |
7 | 7 | from unittest.mock import Mock, patch |
8 | 8 |
|
9 | 9 | import pytest |
@@ -53,10 +53,16 @@ def __init__( |
53 | 53 | self.num_classes = num_classes |
54 | 54 | self.state = state |
55 | 55 | self.not_jsonable = not_jsonable |
| 56 | + |
| 57 | + class DummyModelWithConfigAndKwargs(nn.Module, PyTorchModelHubMixin): |
| 58 | + def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[Dict] = None, **kwargs): |
| 59 | + super().__init__() |
| 60 | + |
56 | 61 | else: |
57 | 62 | DummyModel = None |
58 | 63 | DummyModelWithTags = None |
59 | 64 | DummyModelNoConfig = None |
| 65 | + DummyModelWithConfigAndKwargs = None |
60 | 66 |
|
61 | 67 |
|
62 | 68 | @requires("torch") |
@@ -346,3 +352,12 @@ def forward(self, x): |
346 | 352 | b_bias_ptr = state_dict["b.bias"].storage().data_ptr() |
347 | 353 | assert a_weight_ptr == b_weight_ptr |
348 | 354 | assert a_bias_ptr == b_bias_ptr |
| 355 | + |
| 356 | + def test_save_pretrained_when_config_and_kwargs_are_passed(self): |
| 357 | + # Test creating model with config and kwargs => all values are saved together in config.json |
| 358 | + model = DummyModelWithConfigAndKwargs(num_classes=50, state="layernorm", config={"a": 1}, b=2, c=3) |
| 359 | + model.save_pretrained(self.cache_dir) |
| 360 | + assert model._hub_mixin_config == {"num_classes": 50, "state": "layernorm", "a": 1, "b": 2, "c": 3} |
| 361 | + |
| 362 | + reloaded = DummyModelWithConfigAndKwargs.from_pretrained(self.cache_dir) |
| 363 | + assert reloaded._hub_mixin_config == model._hub_mixin_config |
0 commit comments