Skip to content

Commit 08d3355

Browse files
lvwerra, adrinjalali, osanseviero, julien-c, LysandreJik
committed
ENH Add update metadata to repocard (#844)
* add `metadata_update` function * add tests * add docstring * Apply suggestions from code review Co-authored-by: Adrin Jalali <[email protected]> * refactor `_update_metadata_model_index` * Apply suggestions from code review Co-authored-by: Adrin Jalali <[email protected]> * fix style and imports * switch to deepcopy everywhere * load repo in repocard test into tmp folder * simplify results and metrics checks when updating metadata * run black * Apply suggestions from code review Co-authored-by: Omar Sanseviero <[email protected]> * fix pyyaml version to work with `sort_keys` kwarg * don't allow empty commits if file hasn't changed * switch order of updates to first check model-index for easier readability * expose repocard functions through `__init__` * fix init * make style & quality * revert to for-loop * Apply suggestions from code review Co-authored-by: Julien Chaumond <[email protected]> Co-authored-by: Lysandre Debut <[email protected]> * post suggestion fixes * add example * add type to list Co-authored-by: Adrin Jalali <[email protected]> Co-authored-by: Omar Sanseviero <[email protected]> Co-authored-by: Julien Chaumond <[email protected]> Co-authored-by: Lysandre Debut <[email protected]>
1 parent 0c36bf6 commit 08d3355

File tree

4 files changed

+361
-3
lines changed

4 files changed

+361
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def get_version() -> str:
1515
"filelock",
1616
"requests",
1717
"tqdm",
18-
"pyyaml",
18+
"pyyaml>=5.1",
1919
"typing-extensions>=3.7.4.3", # to be able to import TypeAlias
2020
"importlib_metadata;python_version<'3.8'",
2121
"packaging>=20.9",

src/huggingface_hub/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@
7373
push_to_hub_keras,
7474
save_pretrained_keras,
7575
)
76+
from .repocard import (
77+
metadata_eval_result,
78+
metadata_load,
79+
metadata_save,
80+
metadata_update,
81+
)
7682
from .repository import Repository
7783
from .snapshot_download import snapshot_download
7884
from .utils import logging

src/huggingface_hub/repocard.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from typing import Any, Dict, Optional, Union
66

77
import yaml
8+
from huggingface_hub.file_download import hf_hub_download
9+
from huggingface_hub.hf_api import HfApi
810
from huggingface_hub.repocard_types import (
911
ModelIndex,
1012
SingleMetric,
@@ -13,10 +15,15 @@
1315
SingleResultTask,
1416
)
1517

18+
from .constants import REPOCARD_NAME
19+
1620

1721
# exact same regex as in the Hub server. Please keep in sync.
1822
REGEX_YAML_BLOCK = re.compile(r"---[\n\r]+([\S\s]*?)[\n\r]+---[\n\r]")
1923

24+
UNIQUE_RESULT_FEATURES = ["dataset", "task"]
25+
UNIQUE_METRIC_FEATURES = ["name", "type"]
26+
2027

2128
def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]:
2229
content = Path(local_path).read_text()
@@ -99,3 +106,178 @@ def metadata_eval_result(
99106
model_index, dict_factory=lambda x: {k: v for (k, v) in x if v is not None}
100107
)
101108
return {"model-index": [data]}
109+
110+
111+
def metadata_update(
    repo_id: str,
    metadata: Dict,
    *,
    repo_type: Optional[str] = None,
    overwrite: bool = False,
    token: Optional[str] = None,
) -> str:
    """
    Updates the metadata in the README.md of a repository on the Hugging Face Hub.

    The repo card is downloaded, the passed `metadata` is merged into the
    metadata already present in the card, the merged result is written back to
    the downloaded file, and that file is uploaded to the repository.

    Example:
        >>> from huggingface_hub import metadata_update
        >>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF',
        ...             'results': [{'dataset': {'name': 'ReactionGIF',
        ...                                      'type': 'julien-c/reactiongif'},
        ...                           'metrics': [{'name': 'Recall',
        ...                                        'type': 'recall',
        ...                                        'value': 0.7762102282047272}],
        ...                          'task': {'name': 'Text Classification',
        ...                                   'type': 'text-classification'}}]}]}
        >>> metadata_update("julien-c/reactiongif-roberta", metadata)

    Args:
        repo_id (`str`):
            The name of the repository.
        metadata (`dict`):
            A dictionary containing the metadata to be updated.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if updating to a dataset or space,
            `None` or `"model"` if updating to a model. Default is `None`.
        overwrite (`bool`, *optional*, defaults to `False`):
            If set to `True` an existing field can be overwritten, otherwise
            attempting to overwrite an existing field will cause an error.
        token (`str`, *optional*):
            The Hugging Face authentication token.

    Returns:
        `str`: URL of the commit which updated the card metadata.

    Raises:
        `ValueError`: If `overwrite` is `False` and `metadata` carries a value
            that differs from the one already stored for the same field (or
            for the same metric of the same result in the `model-index`).
    """
    # always fetch a fresh copy so the merge is based on the current card
    filepath = hf_hub_download(
        repo_id,
        filename=REPOCARD_NAME,
        repo_type=repo_type,
        use_auth_token=token,
        force_download=True,
    )

    # `metadata_load` is annotated as returning `Optional[Dict]`; fall back to
    # an empty dict so the membership tests below do not fail on `None`.
    # NOTE(review): assumes `metadata_save` can add a metadata block to a
    # README that has none -- confirm.
    existing_metadata = metadata_load(filepath) or {}

    for key, new_value in metadata.items():
        # update model index containing the evaluation results
        if key == "model-index":
            if "model-index" not in existing_metadata:
                existing_metadata["model-index"] = new_value
            else:
                # the model-index contains a list of results as used by PwC but
                # only has one element thus we take the first one
                existing_metadata["model-index"][0][
                    "results"
                ] = _update_metadata_model_index(
                    existing_metadata["model-index"][0]["results"],
                    new_value[0]["results"],
                    overwrite=overwrite,
                )
        # update all fields except model index
        elif key in existing_metadata and not overwrite:
            # re-sending an identical value is accepted silently
            if existing_metadata[key] != new_value:
                raise ValueError(
                    f"""You passed a new value for the existing meta data field '{key}'. Set `overwrite=True` to overwrite existing metadata."""
                )
        else:
            existing_metadata[key] = new_value

    # save and push to hub
    metadata_save(filepath, existing_metadata)

    # NOTE(review): `identical_ok=False` presumably makes the upload fail
    # rather than create an empty commit when the card is unchanged -- confirm
    # against `HfApi.upload_file`.
    return HfApi().upload_file(
        path_or_fileobj=filepath,
        path_in_repo=REPOCARD_NAME,
        repo_id=repo_id,
        repo_type=repo_type,
        identical_ok=False,
        token=token,
    )
196+
197+
198+
def _update_metadata_model_index(existing_results, new_results, overwrite=False):
    """
    Merges new evaluation results into the existing model-index results.

    A new result is matched against every existing result on the fields listed
    in `UNIQUE_RESULT_FEATURES`. The metrics of each matching existing result
    are merged with the metrics of the new result; a new result without any
    match is appended. `existing_results` is modified in place.

    Args:
        existing_results (`List[dict]`):
            List of existing metadata results.
        new_results (`List[dict]`):
            List of new metadata results.
        overwrite (`bool`, *optional*, defaults to `False`):
            Forwarded to the metric merge: if set to `True`, existing metric
            values can be overwritten, otherwise attempting to overwrite an
            existing value will cause an error.

    Returns:
        `list`: List of updated metadata results (the same list object that
        was passed as `existing_results`).
    """
    for incoming in new_results:
        merged = False
        for stored in existing_results:
            same_result = all(
                incoming[feature] == stored[feature]
                for feature in UNIQUE_RESULT_FEATURES
            )
            if not same_result:
                continue
            merged = True
            # no early exit: every existing result that matches gets merged
            stored["metrics"] = _update_metadata_results_metric(
                incoming["metrics"],
                stored["metrics"],
                overwrite=overwrite,
            )
        if not merged:
            existing_results.append(incoming)
    return existing_results
234+
235+
236+
def _update_metadata_results_metric(new_metrics, existing_metrics, overwrite=False):
    """
    Merges new metrics into the metrics list of a single result.

    A new metric is matched against every existing metric on the fields listed
    in `UNIQUE_METRIC_FEATURES`. For a match, the stored `value` is replaced
    when `overwrite=True`; otherwise the values must be equal. A new metric
    without any match is appended. `existing_metrics` is modified in place.

    Args:
        new_metrics (`list`):
            List of new metrics.
        existing_metrics (`list`):
            List of existing metrics.
        overwrite (`bool`, *optional*, defaults to `False`):
            If set to `True`, existing metric values can be overwritten,
            otherwise attempting to overwrite an existing value will cause
            an error.

    Returns:
        `list`: List of updated metrics (the same list object that was passed
        as `existing_metrics`).

    Raises:
        `ValueError`: If a matching metric holds a different value and
            `overwrite` is `False`.
    """
    for candidate in new_metrics:
        matched = False
        for current in existing_metrics:
            same_metric = all(
                candidate[feature] == current[feature]
                for feature in UNIQUE_METRIC_FEATURES
            )
            if not same_metric:
                continue
            matched = True
            if overwrite:
                current["value"] = candidate["value"]
                continue
            # without the overwrite flag only an identical value is accepted
            if current["value"] != candidate["value"]:
                existing_str = ", ".join(
                    f"{feature}: {candidate[feature]}"
                    for feature in UNIQUE_METRIC_FEATURES
                )
                raise ValueError(
                    "You passed a new value for the existing metric"
                    f" '{existing_str}'. Set `overwrite=True` to overwrite"
                    " existing metrics."
                )
        if not matched:
            existing_metrics.append(candidate)
    return existing_metrics

0 commit comments

Comments
 (0)