Skip to content

Commit 131fd35

Browse files
authored
FIX overwriting metadata when both verified and unverified reported values (#1186)
* FIX overwriting metadata when both verified and unverified reported values
* Fix error message
1 parent 711f688 commit 131fd35

File tree

3 files changed

+113
-73
lines changed

3 files changed

+113
-73
lines changed

src/huggingface_hub/repocard.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -788,47 +788,38 @@ def metadata_update(
788788
else:
789789
existing_results = card.data.eval_results
790790

791+
# Iterate over new results
792+
# Iterate over existing results
793+
# If both results describe the same metric but value is different:
794+
# If overwrite=True: overwrite the metric value
795+
# Else: raise ValueError
796+
# Else: append new result to existing ones.
791797
for new_result in new_results:
792798
result_found = False
793-
for existing_result_index, existing_result in enumerate(
794-
existing_results
795-
):
796-
if all(
797-
[
798-
new_result.dataset_name == existing_result.dataset_name,
799-
new_result.dataset_type == existing_result.dataset_type,
800-
new_result.task_type == existing_result.task_type,
801-
new_result.task_name == existing_result.task_name,
802-
new_result.metric_name == existing_result.metric_name,
803-
new_result.metric_type == existing_result.metric_type,
804-
]
805-
):
806-
if (
807-
new_result.metric_value != existing_result.metric_value
808-
and not overwrite
809-
):
810-
existing_str = (
811-
f"name: {new_result.metric_name}, type:"
812-
f" {new_result.metric_type}"
813-
)
799+
for existing_result in existing_results:
800+
if new_result.is_equal_except_value(existing_result):
801+
if new_result != existing_result and not overwrite:
814802
raise ValueError(
815803
"You passed a new value for the existing metric"
816-
f" '{existing_str}'. Set `overwrite=True` to"
817-
" overwrite existing metrics."
804+
f" 'name: {new_result.metric_name}, type: "
805+
f"{new_result.metric_type}'. Set `overwrite=True`"
806+
" to overwrite existing metrics."
818807
)
819808
result_found = True
820-
card.data.eval_results[existing_result_index] = new_result
809+
existing_result.metric_value = new_result.metric_value
821810
if not result_found:
822811
card.data.eval_results.append(new_result)
823812
else:
813+
# Any metadata that is not a result metric
824814
if (
825815
hasattr(card.data, key)
826816
and getattr(card.data, key) is not None
827817
and not overwrite
828818
and getattr(card.data, key) != value
829819
):
830820
raise ValueError(
831-
f"""You passed a new value for the existing meta data field '{key}'. Set `overwrite=True` to overwrite existing metadata."""
821+
f"You passed a new value for the existing meta data field '{key}'."
822+
" Set `overwrite=True` to overwrite existing metadata."
832823
)
833824
else:
834825
setattr(card.data, key, value)

src/huggingface_hub/repocard_data.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,18 @@ class EvalResult:
121121
# A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not.
122122
verify_token: Optional[str] = None
123123

124+
def is_equal_except_value(self, other: "EvalResult") -> bool:
    """
    Return True if `self` and `other` describe the same metric, ignoring the value.

    Every attribute stored in `self.__dict__`, except `metric_value`, is compared
    with the attribute of the same name on `other`. The outcome does not depend on
    `metric_value` at all: the method returns True whether the two values are
    equal or different, as long as all the other attributes match.

    Args:
        other: The result to compare against. It must expose every attribute
            that `self` has.

    Returns:
        True if all attributes other than `metric_value` are equal, False as
        soon as one of them differs.
    """
    # Only the attribute names are needed here; values are read with `getattr`
    # on both objects.
    for key in self.__dict__:
        if key == "metric_value":
            # The value itself is deliberately left out of the comparison.
            continue
        if getattr(self, key) != getattr(other, key):
            return False
    return True
135+
124136

125137
@dataclass
126138
class CardData:

tests/test_repocard.py

Lines changed: 85 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,43 @@
136136
---
137137
"""
138138

139+
DUMMY_MODELCARD_EVAL_RESULT_BOTH_VERIFIED_AND_UNVERIFIED = """---
140+
model-index:
141+
- name: RoBERTa fine-tuned on ReactionGIF
142+
results:
143+
- task:
144+
type: text-classification
145+
name: Text Classification
146+
dataset:
147+
name: ReactionGIF
148+
type: julien-c/reactiongif
149+
config: default
150+
split: test
151+
metrics:
152+
- type: accuracy
153+
value: 0.2662102282047272
154+
name: Accuracy
155+
config: default
156+
verified: false
157+
- task:
158+
type: text-classification
159+
name: Text Classification
160+
dataset:
161+
name: ReactionGIF
162+
type: julien-c/reactiongif
163+
config: default
164+
split: test
165+
metrics:
166+
- type: accuracy
167+
value: 0.6666666666666666
168+
name: Accuracy
169+
config: default
170+
verified: true
171+
---
172+
173+
This is a test model card.
174+
"""
175+
139176
logger = logging.get_logger(__name__)
140177

141178
REPOCARD_DIR = os.path.join(
@@ -240,17 +277,18 @@ def setUpClass(cls):
240277
def setUp(self) -> None:
241278
self.repo_path = Path(tempfile.mkdtemp())
242279
self.REPO_NAME = repo_name()
243-
self._api.create_repo(f"{USER}/{self.REPO_NAME}")
280+
self.repo_id = f"{USER}/{self.REPO_NAME}"
281+
self._api.create_repo(self.repo_id)
244282
self._api.upload_file(
245283
path_or_fileobj=DUMMY_MODELCARD_EVAL_RESULT.encode(),
246-
repo_id=f"{USER}/{self.REPO_NAME}",
284+
repo_id=self.repo_id,
247285
path_in_repo="README.md",
248286
commit_message="Add README to main branch",
249287
)
250288

251289
self.repo = Repository(
252290
self.repo_path / self.REPO_NAME,
253-
clone_from=f"{USER}/{self.REPO_NAME}",
291+
clone_from=self.repo_id,
254292
use_auth_token=self._token,
255293
git_user="ci",
256294
git_email="[email protected]",
@@ -260,14 +298,12 @@ def setUp(self) -> None:
260298
)
261299

262300
def tearDown(self) -> None:
263-
self._api.delete_repo(repo_id=f"{self.REPO_NAME}")
301+
self._api.delete_repo(repo_id=self.repo_id)
264302
shutil.rmtree(self.repo_path)
265303

266304
def test_update_dataset_name(self):
267305
new_datasets_data = {"datasets": ["test/test_dataset"]}
268-
metadata_update(
269-
f"{USER}/{self.REPO_NAME}", new_datasets_data, token=self._token
270-
)
306+
metadata_update(self.repo_id, new_datasets_data, token=self._token)
271307

272308
self.repo.git_pull()
273309
updated_metadata = metadata_load(self.repo_path / self.REPO_NAME / "README.md")
@@ -280,9 +316,7 @@ def test_update_existing_result_with_overwrite(self):
280316
new_metadata["model-index"][0]["results"][0]["metrics"][0][
281317
"value"
282318
] = 0.2862102282047272
283-
metadata_update(
284-
f"{USER}/{self.REPO_NAME}", new_metadata, token=self._token, overwrite=True
285-
)
319+
metadata_update(self.repo_id, new_metadata, token=self._token, overwrite=True)
286320

287321
self.repo.git_pull()
288322
updated_metadata = metadata_load(self.repo_path / self.REPO_NAME / "README.md")
@@ -293,14 +327,12 @@ def test_metadata_update_upstream(self):
293327
new_metadata["model-index"][0]["results"][0]["metrics"][0]["value"] = 0.1
294328

295329
path = hf_hub_download(
296-
f"{USER}/{self.REPO_NAME}",
330+
self.repo_id,
297331
filename=REPOCARD_NAME,
298332
use_auth_token=self._token,
299333
)
300334

301-
metadata_update(
302-
f"{USER}/{self.REPO_NAME}", new_metadata, token=self._token, overwrite=True
303-
)
335+
metadata_update(self.repo_id, new_metadata, token=self._token, overwrite=True)
304336

305337
self.assertNotEqual(metadata_load(path), new_metadata)
306338
self.assertEqual(metadata_load(path), self.existing_metadata)
@@ -319,17 +351,12 @@ def test_update_existing_result_without_overwrite(self):
319351
),
320352
):
321353
metadata_update(
322-
f"{USER}/{self.REPO_NAME}",
323-
new_metadata,
324-
token=self._token,
325-
overwrite=False,
354+
self.repo_id, new_metadata, token=self._token, overwrite=False
326355
)
327356

328357
def test_update_existing_field_without_overwrite(self):
329358
new_datasets_data = {"datasets": "['test/test_dataset']"}
330-
metadata_update(
331-
f"{USER}/{self.REPO_NAME}", new_datasets_data, token=self._token
332-
)
359+
metadata_update(self.repo_id, new_datasets_data, token=self._token)
333360

334361
with pytest.raises(
335362
ValueError,
@@ -340,7 +367,7 @@ def test_update_existing_field_without_overwrite(self):
340367
):
341368
new_datasets_data = {"datasets": "['test/test_dataset_2']"}
342369
metadata_update(
343-
f"{USER}/{self.REPO_NAME}",
370+
self.repo_id,
344371
new_datasets_data,
345372
token=self._token,
346373
overwrite=False,
@@ -362,9 +389,7 @@ def test_update_new_result_existing_dataset(self):
362389
dataset_split="test",
363390
)
364391

365-
metadata_update(
366-
f"{USER}/{self.REPO_NAME}", new_result, token=self._token, overwrite=False
367-
)
392+
metadata_update(self.repo_id, new_result, token=self._token, overwrite=False)
368393

369394
expected_metadata = copy.deepcopy(self.existing_metadata)
370395
expected_metadata["model-index"][0]["results"][0]["metrics"].append(
@@ -391,9 +416,7 @@ def test_update_new_result_new_dataset(self):
391416
dataset_split="test",
392417
)
393418

394-
metadata_update(
395-
f"{USER}/{self.REPO_NAME}", new_result, token=self._token, overwrite=False
396-
)
419+
metadata_update(self.repo_id, new_result, token=self._token, overwrite=False)
397420

398421
expected_metadata = copy.deepcopy(self.existing_metadata)
399422
expected_metadata["model-index"][0]["results"].append(
@@ -412,7 +435,7 @@ def test_update_metadata_on_empty_text_content(self) -> None:
412435
with self.repo.commit("Add README to main branch"):
413436
with open("README.md", "w") as f:
414437
f.write(DUMMY_MODELCARD_NO_TEXT_CONTENT)
415-
metadata_update(f"{USER}/{self.REPO_NAME}", {"tag": "test"}, token=self._token)
438+
metadata_update(self.repo_id, {"tag": "test"}, token=self._token)
416439

417440
# Check update went fine
418441
self.repo.git_pull()
@@ -426,43 +449,57 @@ def test_update_with_existing_name(self):
426449
new_metadata["model-index"][0]["results"][0]["metrics"][0][
427450
"value"
428451
] = 0.2862102282047272
452+
metadata_update(self.repo_id, new_metadata, token=self._token, overwrite=True)
429453

430-
metadata_update(
431-
f"{USER}/{self.REPO_NAME}",
432-
new_metadata,
433-
token=self._token,
434-
overwrite=True,
435-
)
436-
437-
card_data = ModelCard.load(f"{USER}/{self.REPO_NAME}", token=self._token)
438-
454+
card_data = ModelCard.load(self.repo_id, token=self._token)
439455
self.assertEqual(
440456
card_data.data.model_name, self.existing_metadata["model-index"][0]["name"]
441457
)
442458

443459
def test_update_without_existing_name(self):
444-
445460
# delete existing metadata
446461
self._api.upload_file(
447462
path_or_fileobj="# Test".encode(),
448-
repo_id=f"{USER}/{self.REPO_NAME}",
463+
repo_id=self.repo_id,
449464
path_in_repo="README.md",
450-
commit_message="Add README to main branch",
451465
)
452466

453467
new_metadata = copy.deepcopy(self.existing_metadata)
454468
new_metadata["model-index"][0].pop("name")
455469

456-
metadata_update(
457-
f"{USER}/{self.REPO_NAME}",
458-
new_metadata,
459-
token=self._token,
460-
overwrite=True,
470+
metadata_update(self.repo_id, new_metadata, token=self._token, overwrite=True)
471+
472+
card_data = ModelCard.load(self.repo_id, token=self._token)
473+
474+
self.assertEqual(card_data.data.model_name, self.repo_id)
475+
476+
def test_update_with_both_verified_and_unverified_metric(self):
477+
"""Regression test for #1185.
478+
479+
See https://github.com/huggingface/huggingface_hub/issues/1185.
480+
"""
481+
self._api.upload_file(
482+
path_or_fileobj=DUMMY_MODELCARD_EVAL_RESULT_BOTH_VERIFIED_AND_UNVERIFIED.encode(),
483+
repo_id=self.repo_id,
484+
path_in_repo="README.md",
461485
)
486+
card = ModelCard.load(self.repo_id)
487+
metadata = card.data.to_dict()
488+
metadata_update(self.repo_id, metadata=metadata, overwrite=True, token=TOKEN)
489+
490+
card_data = ModelCard.load(self.repo_id, token=self._token)
491+
492+
self.assertEqual(len(card_data.data.eval_results), 2)
493+
first_result = card_data.data.eval_results[0]
494+
second_result = card_data.data.eval_results[1]
462495

463-
card_data = ModelCard.load(f"{USER}/{self.REPO_NAME}", token=self._token)
496+
# One is verified, the other not
497+
self.assertFalse(first_result.verified)
498+
self.assertTrue(second_result.verified)
464499

465-
self.assertEqual(card_data.data.model_name, f"{USER}/{self.REPO_NAME}")
500+
# Result values are different
501+
self.assertEqual(first_result.metric_value, 0.2662102282047272)
502+
self.assertEqual(second_result.metric_value, 0.6666666666666666)
466503

467504

468505
class TestMetadataUpdateOnMissingCard(unittest.TestCase):

0 commit comments

Comments
 (0)