Skip to content

Commit adf5672

Browse files
qubvelWauplin
authored andcommitted
Support custom kwargs for model card in save_pretrained (#2310)
* Support custom kwargs for model card in save_pretrained * Fix failing test * Fix test for pytorch mixin * Add test for model_card_kwargs * Fix style
1 parent 23d3bb4 commit adf5672

File tree

3 files changed

+49
-8
lines changed

3 files changed

+49
-8
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ def save_pretrained(
374374
config: Optional[Union[dict, "DataclassInstance"]] = None,
375375
repo_id: Optional[str] = None,
376376
push_to_hub: bool = False,
377+
model_card_kwargs: Optional[Dict[str, Any]] = None,
377378
**push_to_hub_kwargs,
378379
) -> Optional[str]:
379380
"""
@@ -389,7 +390,9 @@ def save_pretrained(
389390
repo_id (`str`, *optional*):
390391
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
391392
not provided.
392-
kwargs:
393+
model_card_kwargs (`Dict[str, Any]`, *optional*):
394+
Additional arguments passed to the model card template to customize the model card.
395+
push_to_hub_kwargs:
393396
Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
394397
Returns:
395398
`str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
@@ -418,8 +421,9 @@ def save_pretrained(
418421

419422
# save model card
420423
model_card_path = save_directory / "README.md"
424+
model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {}
421425
if not model_card_path.exists(): # do not overwrite if already exists
422-
self.generate_model_card().save(save_directory / "README.md")
426+
self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md")
423427

424428
# push to the Hub if required
425429
if push_to_hub:
@@ -428,7 +432,7 @@ def save_pretrained(
428432
kwargs["config"] = config
429433
if repo_id is None:
430434
repo_id = save_directory.name # Defaults to `save_directory` name
431-
return self.push_to_hub(repo_id=repo_id, **kwargs)
435+
return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs)
432436
return None
433437

434438
def _save_pretrained(self, save_directory: Path) -> None:
@@ -637,6 +641,7 @@ def push_to_hub(
637641
allow_patterns: Optional[Union[List[str], str]] = None,
638642
ignore_patterns: Optional[Union[List[str], str]] = None,
639643
delete_patterns: Optional[Union[List[str], str]] = None,
644+
model_card_kwargs: Optional[Dict[str, Any]] = None,
640645
) -> str:
641646
"""
642647
Upload model checkpoint to the Hub.
@@ -667,6 +672,8 @@ def push_to_hub(
667672
If provided, files matching any of the patterns are not pushed.
668673
delete_patterns (`List[str]` or `str`, *optional*):
669674
If provided, remote files matching any of the patterns will be deleted from the repo.
675+
model_card_kwargs (`Dict[str, Any]`, *optional*):
676+
Additional arguments passed to the model card template to customize the model card.
670677
671678
Returns:
672679
The url of the commit of your model in the given repository.
@@ -677,7 +684,7 @@ def push_to_hub(
677684
# Push the files to the repo in a single commit
678685
with SoftTemporaryDirectory() as tmp:
679686
saved_path = Path(tmp) / repo_id
680-
self.save_pretrained(saved_path, config=config)
687+
self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs)
681688
return api.upload_folder(
682689
repo_id=repo_id,
683690
repo_type="model",
@@ -696,6 +703,7 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard:
696703
template_str=self._hub_mixin_info.model_card_template,
697704
repo_url=self._hub_mixin_info.repo_url,
698705
docs_url=self._hub_mixin_info.docs_url,
706+
**kwargs,
699707
)
700708
return card
701709

tests/test_hub_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,11 +293,11 @@ def test_save_pretrained_with_push_to_hub(self):
293293

294294
# Push to hub with repo_id (config is pushed)
295295
mocked_model.save_pretrained(save_directory, push_to_hub=True, repo_id="CustomID")
296-
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=CONFIG_AS_DICT)
296+
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=CONFIG_AS_DICT, model_card_kwargs={})
297297

298298
# Push to hub with default repo_id (based on dir name)
299299
mocked_model.save_pretrained(save_directory, push_to_hub=True)
300-
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=CONFIG_AS_DICT)
300+
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=CONFIG_AS_DICT, model_card_kwargs={})
301301

302302
@patch.object(DummyModelNoConfig, "_from_pretrained")
303303
def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None:

tests/test_hub_mixin_pytorch.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@
2828
Arxiv ID: 1234.56789
2929
"""
3030

31+
DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS = """
32+
---
33+
{{ card_data }}
34+
---
35+
36+
This is a dummy model card with kwargs.
37+
Arxiv ID: 1234.56789
38+
39+
{{ custom_data }}
40+
"""
41+
3142
if is_torch_available():
3243
import torch
3344
import torch.nn as nn
@@ -76,11 +87,20 @@ class DummyModelWithConfigAndKwargs(nn.Module, PyTorchModelHubMixin):
7687
def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[Dict] = None, **kwargs):
7788
super().__init__()
7889

90+
class DummyModelWithModelCardAndCustomKwargs(
91+
nn.Module,
92+
PyTorchModelHubMixin,
93+
model_card_template=DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS,
94+
):
95+
def __init__(self, linear_layer: int = 4):
96+
super().__init__()
97+
7998
else:
8099
DummyModel = None
81100
DummyModelWithModelCard = None
82101
DummyModelNoConfig = None
83102
DummyModelWithConfigAndKwargs = None
103+
DummyModelWithModelCardAndCustomKwargs = None
84104

85105

86106
@requires("torch")
@@ -130,11 +150,11 @@ def test_save_pretrained_with_push_to_hub(self):
130150

131151
# Push to hub with repo_id
132152
mocked_model.save_pretrained(save_directory, push_to_hub=True, repo_id="CustomID", config=config)
133-
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=config)
153+
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=config, model_card_kwargs={})
134154

135155
# Push to hub with default repo_id (based on dir name)
136156
mocked_model.save_pretrained(save_directory, push_to_hub=True, config=config)
137-
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=config)
157+
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=config, model_card_kwargs={})
138158

139159
@patch.object(DummyModel, "_from_pretrained")
140160
def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None:
@@ -386,3 +406,16 @@ def test_save_pretrained_when_config_and_kwargs_are_passed(self):
386406

387407
reloaded = DummyModelWithConfigAndKwargs.from_pretrained(self.cache_dir)
388408
assert reloaded._hub_mixin_config == model._hub_mixin_config
409+
410+
def test_model_card_with_custom_kwargs(self):
411+
model_card_kwargs = {"custom_data": "This is a model custom data: 42."}
412+
413+
# Test creating model with custom kwargs => custom data is saved in model card
414+
model = DummyModelWithModelCardAndCustomKwargs()
415+
card = model.generate_model_card(**model_card_kwargs)
416+
assert model_card_kwargs["custom_data"] in str(card)
417+
418+
# Test saving card => model card is saved and restored with custom data
419+
model.save_pretrained(self.cache_dir, model_card_kwargs=model_card_kwargs)
420+
card_reloaded = ModelCard.load(self.cache_dir / "README.md")
421+
assert str(card) == str(card_reloaded)

0 commit comments

Comments
 (0)