Skip to content

Commit 30e5192

Browse files
Wauplin and qubvel
committed
ModelHubMixin: Fix attributes lost in inheritance (#2305)
* ModelHubMixin: Fix attributes lost in inheritance * make style * deprecate * style * Update src/huggingface_hub/hub_mixin.py Co-authored-by: Pavel Iakubovskii <[email protected]> --------- Co-authored-by: Pavel Iakubovskii <[email protected]>
1 parent 83efe51 commit 30e5192

File tree

4 files changed

+100
-45
lines changed

4 files changed

+100
-45
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
import json
33
import os
4+
import warnings
45
from dataclasses import asdict, dataclass, is_dataclass
56
from pathlib import Path
67
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, get_args
@@ -85,8 +86,8 @@ class ModelHubMixin:
8586
URL of the library documentation. Used to generate model card.
8687
model_card_template (`str`, *optional*):
8788
Template of the model card. Used to generate model card. Defaults to a generic template.
88-
languages (`List[str]`, *optional*):
89-
Languages supported by the library. Used to generate model card.
89+
language (`str` or `List[str]`, *optional*):
90+
Language supported by the library. Used to generate model card.
9091
library_name (`str`, *optional*):
9192
Name of the library integrating ModelHubMixin. Used to generate model card.
9293
license (`str`, *optional*):
@@ -191,7 +192,7 @@ def __init_subclass__(
191192
# Model card template
192193
model_card_template: str = DEFAULT_MODEL_CARD,
193194
# Model card metadata
194-
languages: Optional[List[str]] = None,
195+
language: Optional[List[str]] = None,
195196
library_name: Optional[str] = None,
196197
license: Optional[str] = None,
197198
license_name: Optional[str] = None,
@@ -205,27 +206,55 @@ def __init_subclass__(
205206
# Value is a tuple (encoder, decoder).
206207
# Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
207208
] = None,
209+
# Deprecated arguments
210+
languages: Optional[List[str]] = None,
208211
) -> None:
209212
"""Inspect __init__ signature only once when subclassing + handle modelcard."""
210213
super().__init_subclass__()
211214

212215
# Will be reused when creating modelcard
213216
tags = tags or []
214217
tags.append("model_hub_mixin")
215-
cls._hub_mixin_info = MixinInfo(
216-
model_card_template=model_card_template,
217-
repo_url=repo_url,
218-
docs_url=docs_url,
219-
model_card_data=ModelCardData(
220-
languages=languages,
221-
library_name=library_name,
222-
license=license,
223-
license_name=license_name,
224-
license_link=license_link,
225-
pipeline_tag=pipeline_tag,
226-
tags=tags,
227-
),
228-
)
218+
219+
# Initialize MixinInfo if not existent
220+
if not hasattr(cls, "_hub_mixin_info"):
221+
cls._hub_mixin_info = MixinInfo(
222+
model_card_template=model_card_template,
223+
model_card_data=ModelCardData(),
224+
)
225+
info = cls._hub_mixin_info
226+
227+
if languages is not None:
228+
warnings.warn(
229+
"The `languages` argument is deprecated. Use `language` instead. This will be removed in `huggingface_hub>=0.27.0`.",
230+
DeprecationWarning,
231+
)
232+
language = languages
233+
234+
# Update MixinInfo with metadata
235+
if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD:
236+
info.model_card_template = model_card_template
237+
if repo_url is not None:
238+
info.repo_url = repo_url
239+
if docs_url is not None:
240+
info.docs_url = docs_url
241+
if language is not None:
242+
info.model_card_data.language = language
243+
if library_name is not None:
244+
info.model_card_data.library_name = library_name
245+
if license is not None:
246+
info.model_card_data.license = license
247+
if license_name is not None:
248+
info.model_card_data.license_name = license_name
249+
if license_link is not None:
250+
info.model_card_data.license_link = license_link
251+
if pipeline_tag is not None:
252+
info.model_card_data.pipeline_tag = pipeline_tag
253+
if tags is not None:
254+
if info.model_card_data.tags is not None:
255+
info.model_card_data.tags.extend(tags)
256+
else:
257+
info.model_card_data.tags = tags
229258

230259
# Handle encoders/decoders for args
231260
cls._hub_mixin_coders = coders or {}

src/huggingface_hub/repocard_data.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -242,37 +242,43 @@ class ModelCardData(CardData):
242242
"""Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
243243
244244
Args:
245-
language (`Union[str, List[str]]`, *optional*):
246-
Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or
247-
639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`.
248-
license (`str`, *optional*):
249-
License of this model. Example: apache-2.0 or any license from
250-
https://huggingface.co/docs/hub/repositories-licenses. Defaults to None.
251-
library_name (`str`, *optional*):
252-
Name of library used by this model. Example: keras or any library from
253-
https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries.ts.
254-
Defaults to None.
255-
tags (`List[str]`, *optional*):
256-
List of tags to add to your model that can be used when filtering on the Hugging
257-
Face Hub. Defaults to None.
258245
base_model (`str` or `List[str]`, *optional*):
259246
The identifier of the base model from which the model derives. This is applicable for example if your model is a
260247
fine-tune or adapter of an existing model. The value must be the ID of a model on the Hub (or a list of IDs
261248
if your model derives from multiple models). Defaults to None.
262249
datasets (`List[str]`, *optional*):
263250
List of datasets that were used to train this model. Should be a dataset ID
264251
found on https://hf.co/datasets. Defaults to None.
265-
metrics (`List[str]`, *optional*):
266-
List of metrics used to evaluate this model. Should be a metric name that can be found
267-
at https://hf.co/metrics. Example: 'accuracy'. Defaults to None.
268252
eval_results (`Union[List[EvalResult], EvalResult]`, *optional*):
269253
List of `huggingface_hub.EvalResult` that define evaluation results of the model. If provided,
270254
`model_name` is used as a name on PapersWithCode's leaderboards. Defaults to `None`.
255+
language (`Union[str, List[str]]`, *optional*):
256+
Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or
257+
639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`.
258+
library_name (`str`, *optional*):
259+
Name of library used by this model. Example: keras or any library from
260+
https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries.ts.
261+
Defaults to None.
262+
license (`str`, *optional*):
263+
License of this model. Example: apache-2.0 or any license from
264+
https://huggingface.co/docs/hub/repositories-licenses. Defaults to None.
265+
license_name (`str`, *optional*):
266+
Name of the license of this model. Defaults to None. To be used in conjunction with `license_link`.
267+
Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a name. In that case, use `license` instead.
268+
license_link (`str`, *optional*):
269+
Link to the license of this model. Defaults to None. To be used in conjunction with `license_name`.
270+
Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a link. In that case, use `license` instead.
271+
metrics (`List[str]`, *optional*):
272+
List of metrics used to evaluate this model. Should be a metric name that can be found
273+
at https://hf.co/metrics. Example: 'accuracy'. Defaults to None.
271274
model_name (`str`, *optional*):
272275
A name for this model. It is used along with
273276
`eval_results` to construct the `model-index` within the card's metadata. The name
274277
you supply here is what will be used on PapersWithCode's leaderboards. If None is provided
275278
then the repo name is used as a default. Defaults to None.
279+
tags (`List[str]`, *optional*):
280+
List of tags to add to your model that can be used when filtering on the Hugging
281+
Face Hub. Defaults to None.
276282
ignore_metadata_errors (`bool`):
277283
If True, errors while parsing the metadata section will be ignored. Some information might be lost during
278284
the process. Use it at your own risk.
@@ -297,27 +303,33 @@ class ModelCardData(CardData):
297303
def __init__(
298304
self,
299305
*,
300-
language: Optional[Union[str, List[str]]] = None,
301-
license: Optional[str] = None,
302-
library_name: Optional[str] = None,
303-
tags: Optional[List[str]] = None,
304306
base_model: Optional[Union[str, List[str]]] = None,
305307
datasets: Optional[List[str]] = None,
306-
metrics: Optional[List[str]] = None,
307308
eval_results: Optional[List[EvalResult]] = None,
309+
language: Optional[Union[str, List[str]]] = None,
310+
library_name: Optional[str] = None,
311+
license: Optional[str] = None,
312+
license_name: Optional[str] = None,
313+
license_link: Optional[str] = None,
314+
metrics: Optional[List[str]] = None,
308315
model_name: Optional[str] = None,
316+
pipeline_tag: Optional[str] = None,
317+
tags: Optional[List[str]] = None,
309318
ignore_metadata_errors: bool = False,
310319
**kwargs,
311320
):
312-
self.language = language
313-
self.license = license
314-
self.library_name = library_name
315-
self.tags = _to_unique_list(tags)
316321
self.base_model = base_model
317322
self.datasets = datasets
318-
self.metrics = metrics
319323
self.eval_results = eval_results
324+
self.language = language
325+
self.library_name = library_name
326+
self.license = license
327+
self.license_name = license_name
328+
self.license_link = license_link
329+
self.metrics = metrics
320330
self.model_name = model_name
331+
self.pipeline_tag = pipeline_tag
332+
self.tags = _to_unique_list(tags)
321333

322334
model_index = kwargs.pop("model-index", None)
323335
if model_index:

tests/test_hub_mixin.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,14 @@ def _from_pretrained(
9090
return cls(**kwargs)
9191

9292

93+
class BaseModelForInheritance(ModelHubMixin, repo_url="https://hf.co/my-repo", library_name="my-cool-library"):
94+
pass
95+
96+
97+
class DummyModelInherited(BaseModelForInheritance):
98+
pass
99+
100+
93101
class DummyModelSavingConfig(ModelHubMixin):
94102
def _save_pretrained(self, save_directory: Path) -> None:
95103
"""Implementation that uses `config.json` to serialize the config.
@@ -414,3 +422,9 @@ def test_from_cls_with_custom_type(self):
414422
assert model_reloaded.bar == "bar"
415423
assert model_reloaded.custom.value == "custom"
416424
assert model_reloaded.custom_default.value == "default"
425+
426+
def test_inherited_class(self):
427+
"""Test MixinInfo attributes are inherited from the parent class."""
428+
model = DummyModelInherited()
429+
assert model._hub_mixin_info.repo_url == "https://hf.co/my-repo"
430+
assert model._hub_mixin_info.model_card_data.library_name == "my-cool-library"

tests/test_hub_mixin_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,11 @@ def test_push_to_hub(self):
282282
def test_generate_model_card(self):
283283
model = DummyModelWithModelCard()
284284
card = model.generate_model_card()
285-
assert card.data.languages == ["en", "zh"]
285+
assert card.data.language == ["en", "zh"]
286286
assert card.data.library_name == "my-dummy-lib"
287287
assert card.data.license == "apache-2.0"
288288
assert card.data.pipeline_tag == "text-classification"
289-
assert card.data.tags == ["tag1", "tag2", "pytorch_model_hub_mixin", "model_hub_mixin"]
289+
assert card.data.tags == ["model_hub_mixin", "pytorch_model_hub_mixin", "tag1", "tag2"]
290290

291291
# Model card template has been used
292292
assert "This is a dummy model card." in str(card)

0 commit comments

Comments
 (0)