
Commit 3d40f94

fix pretrained_config save dtype (#2587)
1 parent f2cf3b2 commit 3d40f94

File tree

1 file changed: +5 -1 lines


paddleformers/transformers/configuration_utils.py

Lines changed: 5 additions & 1 deletion

@@ -667,6 +667,8 @@ def __init__(self, **kwargs):
                 "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
                 "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
             )
+        self._save_to_hf = kwargs.pop("save_to_hf", False)
+        self._unsavable_keys.add("_save_to_hf")

         # Additional attributes without default values
         for key, value in kwargs.items():
@@ -759,6 +761,8 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):

         os.makedirs(save_directory, exist_ok=True)

+        self._save_to_hf = kwargs.pop("save_to_hf", False)
+
         # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
         # loaded from the Hub.
         if self._auto_class is not None:
@@ -1068,7 +1072,7 @@ def to_dict(self, saving_file=False) -> Dict[str, Any]:
             del output["_auto_class"]
         if "moe_group" in output:
             del output["moe_group"]
-        if "dtype" in output:
+        if self._save_to_hf and "dtype" in output:
             output["torch_dtype"] = str(output["dtype"])
             del output["dtype"]

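For context, a minimal usage sketch of the new flag (not part of the commit): passing save_to_hf=True to save_pretrained sets _save_to_hf, so to_dict() renames the "dtype" entry to "torch_dtype" in the exported config.json, while the default now leaves "dtype" untouched. The AutoConfig import path and the checkpoint/output paths below are assumptions for illustration, not taken from this diff.

    # Sketch only; assumes paddleformers exposes AutoConfig under paddleformers.transformers.
    from paddleformers.transformers import AutoConfig

    # Hypothetical local checkpoint path.
    config = AutoConfig.from_pretrained("./my_paddle_model")

    # Default after this fix: "dtype" is kept as "dtype" in the saved config.json.
    config.save_pretrained("./export_paddle")

    # Hugging Face-compatible export: "dtype" is written out as "torch_dtype".
    config.save_pretrained("./export_hf", save_to_hf=True)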