@@ -374,6 +374,7 @@ def save_pretrained(
374374 config : Optional [Union [dict , "DataclassInstance" ]] = None ,
375375 repo_id : Optional [str ] = None ,
376376 push_to_hub : bool = False ,
377+ model_card_kwargs : Optional [Dict [str , Any ]] = None ,
377378 ** push_to_hub_kwargs ,
378379 ) -> Optional [str ]:
379380 """
@@ -389,7 +390,9 @@ def save_pretrained(
389390 repo_id (`str`, *optional*):
390391 ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
391392 not provided.
392- kwargs:
393+ model_card_kwargs (`Dict[str, Any]`, *optional*):
394+ Additional arguments passed to the model card template to customize the model card.
395+ push_to_hub_kwargs:
393396 Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
394397 Returns:
395398 `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
@@ -418,8 +421,9 @@ def save_pretrained(
418421
419422 # save model card
420423 model_card_path = save_directory / "README.md"
424+ model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {}
421425 if not model_card_path .exists (): # do not overwrite if already exists
422- self .generate_model_card ().save (save_directory / "README.md" )
426+ self .generate_model_card (** model_card_kwargs ).save (save_directory / "README.md" )
423427
424428 # push to the Hub if required
425429 if push_to_hub :
@@ -428,7 +432,7 @@ def save_pretrained(
428432 kwargs ["config" ] = config
429433 if repo_id is None :
430434 repo_id = save_directory .name # Defaults to `save_directory` name
431- return self .push_to_hub (repo_id = repo_id , ** kwargs )
435+ return self .push_to_hub (repo_id = repo_id , model_card_kwargs = model_card_kwargs , ** kwargs )
432436 return None
433437
434438 def _save_pretrained (self , save_directory : Path ) -> None :
@@ -637,6 +641,7 @@ def push_to_hub(
637641 allow_patterns : Optional [Union [List [str ], str ]] = None ,
638642 ignore_patterns : Optional [Union [List [str ], str ]] = None ,
639643 delete_patterns : Optional [Union [List [str ], str ]] = None ,
644+ model_card_kwargs : Optional [Dict [str , Any ]] = None ,
640645 ) -> str :
641646 """
642647 Upload model checkpoint to the Hub.
@@ -667,6 +672,8 @@ def push_to_hub(
667672 If provided, files matching any of the patterns are not pushed.
668673 delete_patterns (`List[str]` or `str`, *optional*):
669674 If provided, remote files matching any of the patterns will be deleted from the repo.
675+ model_card_kwargs (`Dict[str, Any]`, *optional*):
676+ Additional arguments passed to the model card template to customize the model card.
670677
671678 Returns:
672679 The url of the commit of your model in the given repository.
@@ -677,7 +684,7 @@ def push_to_hub(
677684 # Push the files to the repo in a single commit
678685 with SoftTemporaryDirectory () as tmp :
679686 saved_path = Path (tmp ) / repo_id
680- self .save_pretrained (saved_path , config = config )
687+ self .save_pretrained (saved_path , config = config , model_card_kwargs = model_card_kwargs )
681688 return api .upload_folder (
682689 repo_id = repo_id ,
683690 repo_type = "model" ,
@@ -696,6 +703,7 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard:
696703 template_str = self ._hub_mixin_info .model_card_template ,
697704 repo_url = self ._hub_mixin_info .repo_url ,
698705 docs_url = self ._hub_mixin_info .docs_url ,
706+ ** kwargs ,
699707 )
700708 return card
701709
0 commit comments