Skip to content

Commit a2dd4b8

Browse files
NielsRoggegithub-actions[bot]hanouticelina
authored
Add paper URL to hub mixin (#2917)
* First draft * Add test * Apply style fixes * fix tests * style --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Celina Hanouti <[email protected]>
1 parent 7a24bca commit a2dd4b8

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ class DataclassInstance(Protocol):
5858
---
5959
6060
This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
61-
- Library: {{ repo_url | default("[More Information Needed]", true) }}
61+
- Code: {{ repo_url | default("[More Information Needed]", true) }}
62+
- Paper: {{ paper_url | default("[More Information Needed]", true) }}
6263
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
6364
"""
6465

@@ -67,8 +68,9 @@ class DataclassInstance(Protocol):
6768
class MixinInfo:
6869
model_card_template: str
6970
model_card_data: ModelCardData
70-
repo_url: Optional[str] = None
7171
docs_url: Optional[str] = None
72+
paper_url: Optional[str] = None
73+
repo_url: Optional[str] = None
7274

7375

7476
class ModelHubMixin:
@@ -88,6 +90,8 @@ class ModelHubMixin:
8890
Args:
8991
repo_url (`str`, *optional*):
9092
URL of the library repository. Used to generate model card.
93+
paper_url (`str`, *optional*):
94+
URL of the library paper. Used to generate model card.
9195
docs_url (`str`, *optional*):
9296
URL of the library documentation. Used to generate model card.
9397
model_card_template (`str`, *optional*):
@@ -110,7 +114,7 @@ class ModelHubMixin:
110114
pipeline_tag (`str`, *optional*):
111115
Tag of the pipeline. Used to generate model card. E.g. "text-classification".
112116
tags (`List[str]`, *optional*):
113-
Tags to be added to the model card. Used to generate model card. E.g. ["x-custom-tag", "arxiv:2304.12244"]
117+
Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
114118
coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
115119
Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
116120
jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.
@@ -124,8 +128,9 @@ class ModelHubMixin:
124128
>>> class MyCustomModel(
125129
... ModelHubMixin,
126130
... library_name="my-library",
127-
... tags=["x-custom-tag", "arxiv:2304.12244"],
131+
... tags=["computer-vision"],
128132
... repo_url="https://github.com/huggingface/my-cool-library",
133+
... paper_url="https://arxiv.org/abs/2304.12244",
129134
... docs_url="https://huggingface.co/docs/my-cool-library",
130135
... # ^ optional metadata to generate model card
131136
... ):
@@ -194,6 +199,7 @@ def __init_subclass__(
194199
*,
195200
# Generic info for model card
196201
repo_url: Optional[str] = None,
202+
paper_url: Optional[str] = None,
197203
docs_url: Optional[str] = None,
198204
# Model card template
199205
model_card_template: str = DEFAULT_MODEL_CARD,
@@ -234,6 +240,7 @@ def __init_subclass__(
234240

235241
# Inherit other info
236242
info.docs_url = cls._hub_mixin_info.docs_url
243+
info.paper_url = cls._hub_mixin_info.paper_url
237244
info.repo_url = cls._hub_mixin_info.repo_url
238245
cls._hub_mixin_info = info
239246

@@ -242,6 +249,8 @@ def __init_subclass__(
242249
info.model_card_template = model_card_template
243250
if repo_url is not None:
244251
info.repo_url = repo_url
252+
if paper_url is not None:
253+
info.paper_url = paper_url
245254
if docs_url is not None:
246255
info.docs_url = docs_url
247256
if language is not None:
@@ -692,6 +701,7 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard:
692701
card_data=self._hub_mixin_info.model_card_data,
693702
template_str=self._hub_mixin_info.model_card_template,
694703
repo_url=self._hub_mixin_info.repo_url,
704+
paper_url=self._hub_mixin_info.paper_url,
695705
docs_url=self._hub_mixin_info.docs_url,
696706
**kwargs,
697707
)
@@ -718,6 +728,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
718728
... PyTorchModelHubMixin,
719729
... library_name="keras-nlp",
720730
... repo_url="https://github.com/keras-team/keras-nlp",
731+
... paper_url="https://arxiv.org/abs/2304.12244",
721732
... docs_url="https://keras.io/keras_nlp/",
722733
... # ^ optional metadata to generate model card
723734
... ):

tests/test_hub_mixin.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,12 @@ def _from_pretrained(
9191
return cls(**kwargs)
9292

9393

94-
class BaseModelForInheritance(ModelHubMixin, repo_url="https://hf.co/my-repo", library_name="my-cool-library"):
94+
class BaseModelForInheritance(
95+
ModelHubMixin,
96+
repo_url="https://hf.co/my-repo",
97+
paper_url="https://arxiv.org/abs/2304.12244",
98+
library_name="my-cool-library",
99+
):
95100
pass
96101

97102

@@ -452,6 +457,7 @@ def test_inherited_class(self):
452457
"""Test MixinInfo attributes are inherited from the parent class."""
453458
model = DummyModelInherited()
454459
assert model._hub_mixin_info.repo_url == "https://hf.co/my-repo"
460+
assert model._hub_mixin_info.paper_url == "https://arxiv.org/abs/2304.12244"
455461
assert model._hub_mixin_info.model_card_data.library_name == "my-cool-library"
456462

457463
def test_autocomplete_works_as_expected(self):

tests/test_hub_mixin_pytorch.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,12 @@
3636
---
3737
3838
This is a dummy model card with kwargs.
39-
Arxiv ID: 1234.56789
4039
4140
{{ custom_data }}
41+
42+
- Code: {{ repo_url }}
43+
- Paper: {{ paper_url }}
44+
- Docs: {{ docs_url }}
4245
"""
4346

4447
if is_torch_available():
@@ -93,6 +96,9 @@ class DummyModelWithModelCardAndCustomKwargs(
9396
nn.Module,
9497
PyTorchModelHubMixin,
9598
model_card_template=DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS,
99+
docs_url="https://hf.co/docs/my-repo",
100+
paper_url="https://arxiv.org/abs/2304.12244",
101+
repo_url="https://hf.co/my-repo",
96102
):
97103
def __init__(self, linear_layer: int = 4):
98104
super().__init__()
@@ -331,7 +337,6 @@ def test_generate_model_card(self):
331337
assert card.data.license == "apache-2.0"
332338
assert card.data.pipeline_tag == "text-classification"
333339
assert card.data.tags == ["model_hub_mixin", "pytorch_model_hub_mixin", "tag1", "tag2"]
334-
335340
# Model card template has been used
336341
assert "This is a dummy model card" in str(card)
337342

@@ -438,7 +443,9 @@ def test_model_card_with_custom_kwargs(self):
438443
model = DummyModelWithModelCardAndCustomKwargs()
439444
card = model.generate_model_card(**model_card_kwargs)
440445
assert model_card_kwargs["custom_data"] in str(card)
441-
446+
assert "Code: https://hf.co/my-repo" in str(card)
447+
assert "Paper: https://arxiv.org/abs/2304.12244" in str(card)
448+
assert "Docs: https://hf.co/docs/my-repo" in str(card)
442449
# Test saving card => model card is saved and restored with custom data
443450
model.save_pretrained(self.cache_dir, model_card_kwargs=model_card_kwargs)
444451
card_reloaded = ModelCard.load(self.cache_dir / "README.md")

0 commit comments

Comments
 (0)