Skip to content

Commit a42c629

Browse files
authored
ModelHubMixin overwrite config if preexistant (#2142)
1 parent 3252e27 commit a42c629

File tree

3 files changed

+31
-7
lines changed

3 files changed

+31
-7
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,12 @@ def save_pretrained(
262262
save_directory = Path(save_directory)
263263
save_directory.mkdir(parents=True, exist_ok=True)
264264

265+
# Remove config.json if it already exists. After `_save_pretrained` we don't want to overwrite config.json
266+
# as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite
267+
# an existing config.json if it was not saved by `_save_pretrained`.
268+
config_path = save_directory / CONFIG_NAME
269+
config_path.unlink(missing_ok=True)
270+
265271
# save model weights/files (framework-specific)
266272
self._save_pretrained(save_directory)
267273

@@ -271,7 +277,6 @@ def save_pretrained(
271277
if config is not None:
272278
if is_dataclass(config):
273279
config = asdict(config) # type: ignore[arg-type]
274-
config_path = save_directory / CONFIG_NAME
275280
if not config_path.exists():
276281
config_str = json.dumps(config, sort_keys=True, indent=2)
277282
config_path.write_text(config_str)

tests/test_hub_mixin.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,10 +302,29 @@ def test_push_to_hub(self):
302302
# Delete repo
303303
self._api.delete_repo(repo_id=repo_id)
304304

305-
def test_save_pretrained_do_not_overwrite_config(self):
306-
"""Regression test for https://github.com/huggingface/huggingface_hub/issues/2102."""
305+
def test_save_pretrained_do_not_overwrite_new_config(self):
306+
"""Regression test for https://github.com/huggingface/huggingface_hub/issues/2102.
307+
308+
If `_save_pretrained` does save a config file, we should not overwrite it.
309+
"""
307310
model = DummyModelSavingConfig()
308311
model.save_pretrained(self.cache_dir)
309312
# config.json is not overwritten
310313
with open(self.cache_dir / "config.json") as f:
311314
assert json.load(f) == {"custom_config": "custom_config"}
315+
316+
def test_save_pretrained_does_overwrite_legacy_config(self):
317+
"""Regression test for https://github.com/huggingface/huggingface_hub/issues/2142.
318+
319+
If a config file already exists before saving, it should be overwritten.
320+
"""
321+
# Something existing in the cache dir
322+
(self.cache_dir / "config.json").write_text(json.dumps({"something_legacy": 123}))
323+
324+
# Save model
325+
model = DummyModelWithKwargs(a=1, b=2)
326+
model.save_pretrained(self.cache_dir)
327+
328+
# config.json IS overwritten
329+
with open(self.cache_dir / "config.json") as f:
330+
assert json.load(f) == {"a": 1, "b": 2}

tests/test_hub_mixin_pytorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,10 +346,10 @@ def forward(self, x):
346346

347347
# Linear layers should share weights and biases in memory
348348
state_dict = reloaded.state_dict()
349-
a_weight_ptr = state_dict["a.weight"].storage().data_ptr()
350-
b_weight_ptr = state_dict["b.weight"].storage().data_ptr()
351-
a_bias_ptr = state_dict["a.bias"].storage().data_ptr()
352-
b_bias_ptr = state_dict["b.bias"].storage().data_ptr()
349+
a_weight_ptr = state_dict["a.weight"].untyped_storage().data_ptr()
350+
b_weight_ptr = state_dict["b.weight"].untyped_storage().data_ptr()
351+
a_bias_ptr = state_dict["a.bias"].untyped_storage().data_ptr()
352+
b_bias_ptr = state_dict["b.bias"].untyped_storage().data_ptr()
353353
assert a_weight_ptr == b_weight_ptr
354354
assert a_bias_ptr == b_bias_ptr
355355

0 commit comments

Comments (0)