2828 ModelCard ,
2929 ModelCardData ,
3030 create_repo ,
31- get_full_repo_name ,
3231 hf_hub_download ,
3332 upload_folder ,
3433)
6766logger = get_logger (__name__ )
6867
6968
70- MODEL_CARD_TEMPLATE_PATH = Path (__file__ ).parent / "model_card_template.md"
7169SESSION_ID = uuid4 ().hex
7270
7371
@@ -95,53 +93,45 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
9593 return ua
9694
9795
98- def create_model_card (args , model_name ):
96+ def load_or_create_model_card (
97+ repo_id_or_path : Optional [str ] = None , token : Optional [str ] = None , is_pipeline : bool = False
98+ ) -> ModelCard :
99+ """
100+ Loads or creates a model card.
101+
102+ Args:
103+ repo_id (`str`):
104+ The repo_id where to look for the model card.
105+ token (`str`, *optional*):
106+ Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
107+ is_pipeline (`bool`, *optional*):
108+ Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
109+ """
99110 if not is_jinja_available ():
100111 raise ValueError (
101112 "Modelcard rendering is based on Jinja templates."
102113 " Please make sure to have `jinja` installed before using `create_model_card`."
103114 " To install it, please run `pip install Jinja2`."
104115 )
105116
106- if hasattr (args , "local_rank" ) and args .local_rank not in [- 1 , 0 ]:
107- return
108-
109- hub_token = args .hub_token if hasattr (args , "hub_token" ) else None
110- repo_name = get_full_repo_name (model_name , token = hub_token )
111-
112- model_card = ModelCard .from_template (
113- card_data = ModelCardData ( # Card metadata object that will be converted to YAML block
114- language = "en" ,
115- license = "apache-2.0" ,
116- library_name = "diffusers" ,
117- tags = [],
118- datasets = args .dataset_name ,
119- metrics = [],
120- ),
121- template_path = MODEL_CARD_TEMPLATE_PATH ,
122- model_name = model_name ,
123- repo_name = repo_name ,
124- dataset_name = args .dataset_name if hasattr (args , "dataset_name" ) else None ,
125- learning_rate = args .learning_rate ,
126- train_batch_size = args .train_batch_size ,
127- eval_batch_size = args .eval_batch_size ,
128- gradient_accumulation_steps = (
129- args .gradient_accumulation_steps if hasattr (args , "gradient_accumulation_steps" ) else None
130- ),
131- adam_beta1 = args .adam_beta1 if hasattr (args , "adam_beta1" ) else None ,
132- adam_beta2 = args .adam_beta2 if hasattr (args , "adam_beta2" ) else None ,
133- adam_weight_decay = args .adam_weight_decay if hasattr (args , "adam_weight_decay" ) else None ,
134- adam_epsilon = args .adam_epsilon if hasattr (args , "adam_epsilon" ) else None ,
135- lr_scheduler = args .lr_scheduler if hasattr (args , "lr_scheduler" ) else None ,
136- lr_warmup_steps = args .lr_warmup_steps if hasattr (args , "lr_warmup_steps" ) else None ,
137- ema_inv_gamma = args .ema_inv_gamma if hasattr (args , "ema_inv_gamma" ) else None ,
138- ema_power = args .ema_power if hasattr (args , "ema_power" ) else None ,
139- ema_max_decay = args .ema_max_decay if hasattr (args , "ema_max_decay" ) else None ,
140- mixed_precision = args .mixed_precision ,
141- )
142-
143- card_path = os .path .join (args .output_dir , "README.md" )
144- model_card .save (card_path )
117+ try :
118+ # Check if the model card is present on the remote repo
119+ model_card = ModelCard .load (repo_id_or_path , token = token )
120+ except EntryNotFoundError :
121+ # Otherwise create a simple model card from template
122+ component = "pipeline" if is_pipeline else "model"
123+ model_description = f"This is the model card of a 🧨 diffusers { component } that has been pushed on the Hub. This model card has been automatically generated."
124+ card_data = ModelCardData ()
125+ model_card = ModelCard .from_template (card_data , model_description = model_description )
126+
127+ return model_card
128+
129+
130+ def populate_model_card (model_card : ModelCard ) -> ModelCard :
131+ """Populates the `model_card` with library name."""
132+ if model_card .data .library_name is None :
133+ model_card .data .library_name = "diffusers"
134+ return model_card
145135
146136
147137def extract_commit_hash (resolved_file : Optional [str ], commit_hash : Optional [str ] = None ):
@@ -435,6 +425,10 @@ def push_to_hub(
435425 """
436426 repo_id = create_repo (repo_id , private = private , token = token , exist_ok = True ).repo_id
437427
428+ # Create a new empty model card and eventually tag it
429+ model_card = load_or_create_model_card (repo_id , token = token )
430+ model_card = populate_model_card (model_card )
431+
438432 # Save all files.
439433 save_kwargs = {"safe_serialization" : safe_serialization }
440434 if "Scheduler" not in self .__class__ .__name__ :
@@ -443,6 +437,9 @@ def push_to_hub(
443437 with tempfile .TemporaryDirectory () as tmpdir :
444438 self .save_pretrained (tmpdir , ** save_kwargs )
445439
440+ # Update model card if needed:
441+ model_card .save (os .path .join (tmpdir , "README.md" ))
442+
446443 return self ._upload_folder (
447444 tmpdir ,
448445 repo_id ,
0 commit comments