
Commit d4c7ab7

sayakpaul, patrickvonplaten, and Wauplin authored
[Hub] feat: explicitly tag to diffusers when using push_to_hub (#6678)
* feat: explicitly tag to diffusers when using push_to_hub
* remove tags.
* reset repo.
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>
* fix: tests
* fix: push_to_hub behaviour for tagging from save_pretrained
* Apply suggestions from code review
  Co-authored-by: Lucain <[email protected]>
* Apply suggestions from code review
  Co-authored-by: Lucain <[email protected]>
* import fixes.
* add library name to existing model card.
* add: standalone test for generate_model_card
* fix tests for standalone method
* moved library_name to a better place.
* merge create_model_card and generate_model_card.
* fix test
* address lucain's comments
* fix return identation
* Apply suggestions from code review
  Co-authored-by: Lucain <[email protected]>
* address further comments.
* Update src/diffusers/pipelines/pipeline_utils.py
  Co-authored-by: Lucain <[email protected]>

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Lucain <[email protected]>
1 parent ea9dc3f commit d4c7ab7
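
In practice, after this commit anything pushed through push_to_hub (or through save_pretrained(..., push_to_hub=True)) ends up with a model card tagged library_name: diffusers, unless the existing card already declares another library. A minimal sketch of the behaviour the new tests in this commit verify; the checkpoint, repo id, and token below are hypothetical placeholders:

    from diffusers import DiffusionPipeline
    from huggingface_hub import ModelCard

    # Hypothetical checkpoint, repo id, and token -- substitute your own.
    pipe = DiffusionPipeline.from_pretrained("some-user/some-tiny-pipeline")
    pipe.push_to_hub("some-user/my-pipeline-copy", token="hf_...")

    # The README.md created or updated on the Hub now carries the diffusers tag.
    card = ModelCard.load("some-user/my-pipeline-copy", token="hf_...")
    assert card.data.library_name == "diffusers"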

File tree

7 files changed: +100 −124 lines


src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 1 deletion
@@ -42,7 +42,7 @@
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import PushToHubMixin
+from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card


 logger = logging.get_logger(__name__)
@@ -377,6 +377,11 @@ def save_pretrained(
         logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")

         if push_to_hub:
+            # Create a new empty model card and eventually tag it
+            model_card = load_or_create_model_card(repo_id, token=token)
+            model_card = populate_model_card(model_card)
+            model_card.save(os.path.join(save_directory, "README.md"))
+
             self._upload_folder(
                 save_directory,
                 repo_id,
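
For models, the same path is triggered from save_pretrained when push_to_hub=True. A rough sketch of the call site (the repo_id and token keyword arguments are assumed to be forwarded through **kwargs by save_pretrained, as in the surrounding implementation; all values below are placeholders):

    from diffusers import UNet2DModel

    model = UNet2DModel()  # default config is enough for illustration

    # With this change, a README.md tagged `library_name: diffusers` is written
    # into the save directory before the folder is uploaded.
    model.save_pretrained(
        "my-unet",                     # local save directory
        push_to_hub=True,              # triggers the new model-card path above
        repo_id="some-user/my-unet",   # hypothetical repo id
        token="hf_...",                # hypothetical token
    )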

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 6 additions & 0 deletions
@@ -60,6 +60,7 @@
     logging,
     numpy_to_pil,
 )
+from ..utils.hub_utils import load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import is_compiled_module


@@ -725,6 +726,11 @@ def is_saveable_module(name, value):
         self.save_config(save_directory)

         if push_to_hub:
+            # Create a new empty model card and eventually tag it
+            model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
+            model_card = populate_model_card(model_card)
+            model_card.save(os.path.join(save_directory, "README.md"))
+
             self._upload_folder(
                 save_directory,
                 repo_id,

src/diffusers/utils/hub_utils.py

Lines changed: 39 additions & 42 deletions
@@ -28,7 +28,6 @@
     ModelCard,
     ModelCardData,
     create_repo,
-    get_full_repo_name,
     hf_hub_download,
     upload_folder,
 )
@@ -67,7 +66,6 @@
 logger = get_logger(__name__)


-MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
 SESSION_ID = uuid4().hex


@@ -95,53 +93,45 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     return ua


-def create_model_card(args, model_name):
+def load_or_create_model_card(
+    repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False
+) -> ModelCard:
+    """
+    Loads or creates a model card.
+
+    Args:
+        repo_id (`str`):
+            The repo_id where to look for the model card.
+        token (`str`, *optional*):
+            Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
+        is_pipeline (`bool`, *optional*):
+            Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
+    """
     if not is_jinja_available():
         raise ValueError(
             "Modelcard rendering is based on Jinja templates."
             " Please make sure to have `jinja` installed before using `create_model_card`."
             " To install it, please run `pip install Jinja2`."
         )

-    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
-        return
-
-    hub_token = args.hub_token if hasattr(args, "hub_token") else None
-    repo_name = get_full_repo_name(model_name, token=hub_token)
-
-    model_card = ModelCard.from_template(
-        card_data=ModelCardData(  # Card metadata object that will be converted to YAML block
-            language="en",
-            license="apache-2.0",
-            library_name="diffusers",
-            tags=[],
-            datasets=args.dataset_name,
-            metrics=[],
-        ),
-        template_path=MODEL_CARD_TEMPLATE_PATH,
-        model_name=model_name,
-        repo_name=repo_name,
-        dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
-        learning_rate=args.learning_rate,
-        train_batch_size=args.train_batch_size,
-        eval_batch_size=args.eval_batch_size,
-        gradient_accumulation_steps=(
-            args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
-        ),
-        adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
-        adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
-        adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
-        adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
-        lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
-        lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
-        ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
-        ema_power=args.ema_power if hasattr(args, "ema_power") else None,
-        ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
-        mixed_precision=args.mixed_precision,
-    )
-
-    card_path = os.path.join(args.output_dir, "README.md")
-    model_card.save(card_path)
+    try:
+        # Check if the model card is present on the remote repo
+        model_card = ModelCard.load(repo_id_or_path, token=token)
+    except EntryNotFoundError:
+        # Otherwise create a simple model card from template
+        component = "pipeline" if is_pipeline else "model"
+        model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
+        card_data = ModelCardData()
+        model_card = ModelCard.from_template(card_data, model_description=model_description)
+
+    return model_card
+
+
+def populate_model_card(model_card: ModelCard) -> ModelCard:
+    """Populates the `model_card` with library name."""
+    if model_card.data.library_name is None:
+        model_card.data.library_name = "diffusers"
+    return model_card


 def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
@@ -435,6 +425,10 @@ def push_to_hub(
         """
         repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id

+        # Create a new empty model card and eventually tag it
+        model_card = load_or_create_model_card(repo_id, token=token)
+        model_card = populate_model_card(model_card)
+
         # Save all files.
         save_kwargs = {"safe_serialization": safe_serialization}
         if "Scheduler" not in self.__class__.__name__:
@@ -443,6 +437,9 @@ def push_to_hub(
         with tempfile.TemporaryDirectory() as tmpdir:
             self.save_pretrained(tmpdir, **save_kwargs)

+            # Update model card if needed:
+            model_card.save(os.path.join(tmpdir, "README.md"))
+
             return self._upload_folder(
                 tmpdir,
                 repo_id,
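
The two new helpers can also be used on their own, which is what the reworked test in tests/others/test_hub_utils.py below relies on. A small sketch against a local card file (no Hub access; note that load_or_create_model_card requires Jinja2 to be installed):

    from pathlib import Path
    from tempfile import TemporaryDirectory

    from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card

    with TemporaryDirectory() as tmpdir:
        # An existing card that already declares a library_name is left untouched.
        card_path = Path(tmpdir) / "README.md"
        card_path.write_text("---\nlibrary_name: foo\n---\nSome content\n")
        card = populate_model_card(load_or_create_model_card(card_path))
        assert card.data.library_name == "foo"

    # If the repo exists but has no README.md yet (EntryNotFoundError), a fresh card
    # is built from the default template and populate_model_card tags it "diffusers".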

src/diffusers/utils/model_card_template.md

Lines changed: 0 additions & 50 deletions
This file was deleted.

tests/models/test_modeling_common.py

Lines changed: 25 additions & 1 deletion
@@ -24,7 +24,8 @@
 import numpy as np
 import requests_mock
 import torch
-from huggingface_hub import delete_repo
+from huggingface_hub import ModelCard, delete_repo
+from huggingface_hub.utils import is_jinja_available
 from requests.exceptions import HTTPError

 from diffusers.models import UNet2DConditionModel
@@ -732,3 +733,26 @@ def test_push_to_hub_in_organization(self):

         # Reset repo
         delete_repo(self.org_repo_id, token=TOKEN)
+
+    @unittest.skipIf(
+        not is_jinja_available(),
+        reason="Model card tests cannot be performed without Jinja installed.",
+    )
+    def test_push_to_hub_library_name(self):
+        model = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=32,
+        )
+        model.push_to_hub(self.repo_id, token=TOKEN)
+
+        model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
+        assert model_card.library_name == "diffusers"
+
+        # Reset repo
+        delete_repo(self.repo_id, token=TOKEN)

tests/others/test_hub_utils.py

Lines changed: 7 additions & 29 deletions
@@ -15,37 +15,15 @@
 import unittest
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from unittest.mock import Mock, patch

-import diffusers.utils.hub_utils
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card


 class CreateModelCardTest(unittest.TestCase):
-    @patch("diffusers.utils.hub_utils.get_full_repo_name")
-    def test_create_model_card(self, repo_name_mock: Mock) -> None:
-        repo_name_mock.return_value = "full_repo_name"
+    def test_generate_model_card_with_library_name(self):
         with TemporaryDirectory() as tmpdir:
-            # Dummy args values
-            args = Mock()
-            args.output_dir = tmpdir
-            args.local_rank = 0
-            args.hub_token = "hub_token"
-            args.dataset_name = "dataset_name"
-            args.learning_rate = 0.01
-            args.train_batch_size = 100000
-            args.eval_batch_size = 10000
-            args.gradient_accumulation_steps = 0.01
-            args.adam_beta1 = 0.02
-            args.adam_beta2 = 0.03
-            args.adam_weight_decay = 0.0005
-            args.adam_epsilon = 0.000001
-            args.lr_scheduler = 1
-            args.lr_warmup_steps = 10
-            args.ema_inv_gamma = 0.001
-            args.ema_power = 0.1
-            args.ema_max_decay = 0.2
-            args.mixed_precision = True
-
-            # Model card mush be rendered and saved
-            diffusers.utils.hub_utils.create_model_card(args, model_name="model_name")
-            self.assertTrue((Path(tmpdir) / "README.md").is_file())
+            file_path = Path(tmpdir) / "README.md"
+            file_path.write_text("---\nlibrary_name: foo\n---\nContent\n")
+            model_card = load_or_create_model_card(file_path)
+            populate_model_card(model_card)
+            assert model_card.data.library_name == "foo"

tests/pipelines/test_pipelines_common.py

Lines changed: 17 additions & 1 deletion
@@ -13,7 +13,8 @@
 import numpy as np
 import PIL.Image
 import torch
-from huggingface_hub import delete_repo
+from huggingface_hub import ModelCard, delete_repo
+from huggingface_hub.utils import is_jinja_available
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 import diffusers
@@ -1142,6 +1143,21 @@ def test_push_to_hub_in_organization(self):
         # Reset repo
         delete_repo(self.org_repo_id, token=TOKEN)

+    @unittest.skipIf(
+        not is_jinja_available(),
+        reason="Model card tests cannot be performed without Jinja installed.",
+    )
+    def test_push_to_hub_library_name(self):
+        components = self.get_pipeline_components()
+        pipeline = StableDiffusionPipeline(**components)
+        pipeline.push_to_hub(self.repo_id, token=TOKEN)
+
+        model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
+        assert model_card.library_name == "diffusers"
+
+        # Reset repo
+        delete_repo(self.repo_id, token=TOKEN)
+

 # For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders
 # and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()`
