System Info
- transformers version: 5.2.0
- Platform: Linux-5.15.153.1-microsoft-standard-WSL2-aarch64-with-glibc2.35
- Python version: 3.10.16
- Huggingface_hub version: 1.4.1
- Safetensors version: 0.7.0
- Accelerate version: not installed
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.10.0+cpu (NA)
- Using distributed or parallel set-up in script?: no
Who can help?
model loading (from pretrained, etc): @Cyrilvallez
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Case 1. extra_state: dict[str, torch.Tensor]
extra_state of type dict[str, torch.Tensor] fails during PreTrainedModel.save_pretrained.
import random
import torch
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.state = {
            "a": random.randint(0, 100),
            "b": random.randint(0, 100),
            "c": random.randint(0, 100),
        }

    def get_extra_state(self) -> dict[str, torch.Tensor]:
        return {
            key: torch.tensor(val)
            for key, val in self.state.items()
        }

    def set_extra_state(self, state: dict[str, torch.Tensor]):
        self.state = {
            key: val.item()
            for key, val in state.items()
        }


class Qwen3ForCausalLM(Qwen3ForCausalLM):  # shadows the import; adds a submodule carrying extra state
    def __init__(self, config):
        super().__init__(config)
        self.mymodule = MyModule()


def main():
    config = Qwen3Config(
        vocab_size=10,
        hidden_size=32,
        intermediate_size=32,
        num_attention_heads=1,
        num_key_value_heads=1,
        num_hidden_layers=1,
    )
    model = Qwen3ForCausalLM(config)
    model.save_pretrained("./Qwen3ForCausalLM")
    model_ = Qwen3ForCausalLM.from_pretrained("./Qwen3ForCausalLM")
    assert model.mymodule.state == model_.mymodule.state


if __name__ == "__main__":
    main()

Traceback
Traceback (most recent call last):
File ".../state_dict.py", line 51, in <module>
main()
File ".../state_dict.py", line 45, in main
model.save_pretrained("./Qwen3QuantForCausalLM")
File ".../transformers/modeling_utils.py", line 3321, in save_pretrained
state_dict_split = split_torch_state_dict_into_shards(
File ".../huggingface_hub/serialization/_torch.py", line 351, in split_torch_state_dict_into_shards
return split_state_dict_into_shards_factory(
File ".../huggingface_hub/serialization/_base.py", line 105, in split_state_dict_into_shards_factory
storage_id = get_storage_id(tensor) # type: ignore[invalid-argument-type]
File ".../huggingface_hub/serialization/_torch.py", line 737, in get_torch_storage_id
if tensor.device.type == "meta":
AttributeError: 'dict' object has no attribute 'device'
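For reference, nn.Module itself handles a dict-typed extra state without issue: state_dict() stores whatever get_extra_state() returns under the reserved "_extra_state" key, and load_state_dict() hands that value back to set_extra_state(). A minimal plain-PyTorch round trip of the MyModule above (no transformers involved; the mymodule.pt path is just for this example):

import torch

m = MyModule()  # MyModule as defined in the Case 1 script above
torch.save(m.state_dict(), "mymodule.pt")  # state_dict() holds {"_extra_state": {...}}

m2 = MyModule()
# weights_only=False because extra state is, in general, an arbitrary pickled object
m2.load_state_dict(torch.load("mymodule.pt", weights_only=False))
assert m.state == m2.state  # the dict-typed extra state round-trips through plain PyTorch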
Case 2. extra_state: torch.Tensor
extra_state of type torch.Tensor fails during PreTrainedModel.from_pretrained.
(extra_state of type torch.Tensor was expected to be supported as of #38155.)
import random
import torch
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.state = [random.randint(0, 100) for _ in range(3)]

    def get_extra_state(self) -> torch.Tensor:
        return torch.tensor(self.state)

    def set_extra_state(self, state: torch.Tensor):
        self.state = state.tolist()


class Qwen3ForCausalLM(Qwen3ForCausalLM):  # shadows the import; adds a submodule carrying extra state
    def __init__(self, config):
        super().__init__(config)
        self.mymodule = MyModule()


def main():
    config = Qwen3Config(
        vocab_size=10,
        hidden_size=32,
        intermediate_size=32,
        num_attention_heads=1,
        num_key_value_heads=1,
        num_hidden_layers=1,
    )
    model = Qwen3ForCausalLM(config)
    model.save_pretrained("./Qwen3ForCausalLM")
    model_ = Qwen3ForCausalLM.from_pretrained("./Qwen3ForCausalLM")
    assert model.mymodule.state == model_.mymodule.state


if __name__ == "__main__":
    main()

Traceback
Traceback (most recent call last):
File ".../state_dict_2.py", line 41, in <module>
main()
File ".../state_dict_2.py", line 36, in main
model_ = Qwen3ForCausalLM.from_pretrained("./Qwen3ForCausalLM")
File ".../transformers/modeling_utils.py", line 4072, in from_pretrained
loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
File ".../transformers/modeling_utils.py", line 4191, in _load_pretrained_model
loading_info, disk_offload_index = convert_and_load_state_dict_in_model(
File ".../transformers/core_model_loading.py", line 1225, in convert_and_load_state_dict_in_model
set_param_for_module(
File ".../transformers/core_model_loading.py", line 900, in set_param_for_module
ref = getattr(module_obj, param_name)
File ".../torch/nn/modules/module.py", line 1965, in __getattr__
raise AttributeError(
AttributeError: 'MyModule' object has no attribute '_extra_state'. Did you mean: 'get_extra_state'?
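In Case 2 the save itself succeeds (the tensor is written to the checkpoint under mymodule._extra_state); only the new loading path fails. A possible stopgap, sketched here under the assumption of a single-shard safetensors checkpoint and with the Qwen3ForCausalLM subclass and Qwen3Config from the script above in scope, is to rebuild the model and restore the weights through PyTorch's own load_state_dict(), which routes "_extra_state" keys through set_extra_state():

from safetensors.torch import load_file

config = Qwen3Config(
    vocab_size=10, hidden_size=32, intermediate_size=32,
    num_attention_heads=1, num_key_value_heads=1, num_hidden_layers=1,
)
model = Qwen3ForCausalLM(config)
state_dict = load_file("./Qwen3ForCausalLM/model.safetensors")
# strict=False because save_pretrained may omit tied weights from the file
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing:", missing, "unexpected:", unexpected)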
Expected behavior
- [bug] Support extra_state of type torch.Tensor. I think this was meant to be supported by #38155, but it does not work in the latest transformers.
- [feature request] Support extra_state of arbitrary type by delegating to nn.Module.set_extra_state (a sketch follows below).
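To make the feature request concrete, here is a sketch of what the delegation could look like on the loading side. The helper below is hypothetical and is not the actual transformers code; the point is only that a state-dict entry whose final component is "_extra_state" is reserved by PyTorch for the get_extra_state()/set_extra_state() hooks and should never be fetched or assigned through getattr()/setattr():

import torch

def load_one_entry(module: torch.nn.Module, param_name: str, value) -> None:
    # Hypothetical helper, not the real transformers loader.
    if param_name == "_extra_state":
        # The value may be a tensor, a dict, or any other picklable object
        # produced by get_extra_state(), so hand it straight to the hook.
        module.set_extra_state(value)
        return
    # Simplified tensor path: copy into the existing parameter or buffer.
    target = getattr(module, param_name)
    with torch.no_grad():
        target.copy_(value)

The symmetric change for Case 1 would sit on the saving side: values stored under _extra_state keys are not necessarily tensors, so they would have to be set aside before split_torch_state_dict_into_shards (which only accepts tensors) and persisted through some other mechanism.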