|
5 | 5 | from typing import Any, Dict, Optional, Union |
6 | 6 |
|
7 | 7 | import yaml |
| 8 | +from huggingface_hub.file_download import hf_hub_download |
| 9 | +from huggingface_hub.hf_api import HfApi |
8 | 10 | from huggingface_hub.repocard_types import ( |
9 | 11 | ModelIndex, |
10 | 12 | SingleMetric, |
|
13 | 15 | SingleResultTask, |
14 | 16 | ) |
15 | 17 |
|
| 18 | +from .constants import REPOCARD_NAME |
| 19 | + |
16 | 20 |
|
17 | 21 | # exact same regex as in the Hub server. Please keep in sync. |
18 | 22 | REGEX_YAML_BLOCK = re.compile(r"---[\n\r]+([\S\s]*?)[\n\r]+---[\n\r]") |
19 | 23 |
|
| 24 | +UNIQUE_RESULT_FEATURES = ["dataset", "task"] |
| 25 | +UNIQUE_METRIC_FEATURES = ["name", "type"] |
| 26 | + |
20 | 27 |
|
21 | 28 | def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]: |
22 | 29 | content = Path(local_path).read_text() |
@@ -99,3 +106,178 @@ def metadata_eval_result( |
99 | 106 | model_index, dict_factory=lambda x: {k: v for (k, v) in x if v is not None} |
100 | 107 | ) |
101 | 108 | return {"model-index": [data]} |
| 109 | + |
| 110 | + |
def metadata_update(
    repo_id: str,
    metadata: Dict,
    *,
    repo_type: Optional[str] = None,
    overwrite: bool = False,
    token: Optional[str] = None,
) -> str:
    """
    Updates the metadata in the README.md of a repository on the Hugging Face Hub.

    Example:
        >>> from huggingface_hub import metadata_update
        >>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF',
        ...             'results': [{'dataset': {'name': 'ReactionGIF',
        ...                                      'type': 'julien-c/reactiongif'},
        ...                          'metrics': [{'name': 'Recall',
        ...                                       'type': 'recall',
        ...                                       'value': 0.7762102282047272}],
        ...                          'task': {'name': 'Text Classification',
        ...                                   'type': 'text-classification'}}]}]}
        >>> metadata_update("julien-c/reactiongif-roberta", metadata)

    Args:
        repo_id (`str`):
            The name of the repository.
        metadata (`dict`):
            A dictionary containing the metadata to be updated.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if updating to a dataset or space,
            `None` or `"model"` if updating to a model. Default is `None`.
        overwrite (`bool`, *optional*, defaults to `False`):
            If set to `True` an existing field can be overwritten, otherwise
            attempting to overwrite an existing field will cause an error.
        token (`str`, *optional*):
            The Hugging Face authentication token.

    Returns:
        `str`: URL of the commit which updated the card metadata.

    Raises:
        ValueError: If `metadata` contains a differing value for an existing
            field and `overwrite` is `False`.
    """

    # force_download so we always merge into the latest card on the Hub
    filepath = hf_hub_download(
        repo_id,
        filename=REPOCARD_NAME,
        repo_type=repo_type,
        use_auth_token=token,
        force_download=True,
    )
    # metadata_load returns None when the card has no YAML front matter;
    # fall back to an empty dict so new metadata can still be added
    existing_metadata = metadata_load(filepath) or {}

    for key in metadata:
        # update model index containing the evaluation results
        if key == "model-index":
            if "model-index" not in existing_metadata:
                existing_metadata["model-index"] = metadata["model-index"]
            else:
                # the model-index contains a list of results as used by PwC but
                # only has one element thus we take the first one
                existing_metadata["model-index"][0][
                    "results"
                ] = _update_metadata_model_index(
                    existing_metadata["model-index"][0]["results"],
                    metadata["model-index"][0]["results"],
                    overwrite=overwrite,
                )
        # update all fields except model index
        elif key in existing_metadata and not overwrite:
            # an identical value is a no-op; only a differing value requires
            # the explicit overwrite flag
            if existing_metadata[key] != metadata[key]:
                raise ValueError(
                    f"""You passed a new value for the existing meta data field '{key}'. Set `overwrite=True` to overwrite existing metadata."""
                )
        else:
            existing_metadata[key] = metadata[key]

    # save and push to hub
    metadata_save(filepath, existing_metadata)

    return HfApi().upload_file(
        path_or_fileobj=filepath,
        path_in_repo=REPOCARD_NAME,
        repo_id=repo_id,
        repo_type=repo_type,
        identical_ok=False,
        token=token,
    )
| 196 | + |
| 197 | + |
def _update_metadata_model_index(existing_results, new_results, overwrite=False):
    """
    Updates the model-index fields in the metadata. If results with same unique
    features exist they are updated, else a new result is appended. Updating existing
    values is only possible if `overwrite=True`.

    Args:
        existing_results (`List[dict]`):
            List of existing metadata results.
        new_results (`List[dict]`):
            List of new metadata results.
        overwrite (`bool`, *optional*, defaults to `False`):
            If set to `True`, existing metric values can be overwritten, otherwise
            attempting to overwrite an existing field will cause an error.

    Returns:
        `list`: List of updated metadata results
    """
    for new_result in new_results:
        result_found = False
        for existing_result in existing_results:
            # two results describe the same evaluation iff all unique result
            # features (dataset and task) match
            if all(
                new_result[feat] == existing_result[feat]
                for feat in UNIQUE_RESULT_FEATURES
            ):
                result_found = True
                # merge the metric lists in place on the matching result dict
                existing_result["metrics"] = _update_metadata_results_metric(
                    new_result["metrics"],
                    existing_result["metrics"],
                    overwrite=overwrite,
                )
        if not result_found:
            existing_results.append(new_result)
    return existing_results
| 234 | + |
| 235 | + |
def _update_metadata_results_metric(new_metrics, existing_metrics, overwrite=False):
    """
    Updates the metrics list of a result in the metadata. If metrics with same unique
    features exist their values are updated, else a new metric is appended. Updating
    existing values is only possible if `overwrite=True`.

    Args:
        new_metrics (`list`):
            List of new metrics.
        existing_metrics (`list`):
            List of existing metrics.
        overwrite (`bool`, *optional*, defaults to `False`):
            If set to `True`, existing metric values can be overwritten, otherwise
            attempting to overwrite an existing field will cause an error.

    Returns:
        `list`: List of updated metrics

    Raises:
        ValueError: If a metric with the same unique features already exists
            with a different value and `overwrite` is `False`.
    """
    for new_metric in new_metrics:
        metric_exists = False
        for existing_metric in existing_metrics:
            # two metrics are the same iff all unique metric features
            # (name and type) match
            if all(
                new_metric[feat] == existing_metric[feat]
                for feat in UNIQUE_METRIC_FEATURES
            ):
                if overwrite:
                    existing_metric["value"] = new_metric["value"]
                elif existing_metric["value"] != new_metric["value"]:
                    # metric exists with a different value and overwriting was
                    # not requested: fail loudly instead of silently dropping
                    existing_str = ", ".join(
                        f"{feat}: {new_metric[feat]}"
                        for feat in UNIQUE_METRIC_FEATURES
                    )
                    raise ValueError(
                        "You passed a new value for the existing metric"
                        f" '{existing_str}'. Set `overwrite=True` to overwrite"
                        " existing metrics."
                    )
                metric_exists = True
        if not metric_exists:
            existing_metrics.append(new_metric)
    return existing_metrics
0 commit comments