save/from_pretrained fails to handle extra_state #44164

@quic-kyunggeu

Description

System Info

  • transformers version: 5.2.0
  • Platform: Linux-5.15.153.1-microsoft-standard-WSL2-aarch64-with-glibc2.35
  • Python version: 3.10.16
  • Huggingface_hub version: 1.4.1
  • Safetensors version: 0.7.0
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.10.0+cpu (NA)
  • Using distributed or parallel set-up in script?: no

Who can help?

model loading (from pretrained, etc): @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Case 1. extra_state: dict[str, torch.Tensor]

An extra_state of type dict[str, torch.Tensor] causes PreTrainedModel.save_pretrained to fail.

import random
import torch
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.state = {
            "a": random.randint(0, 100),
            "b": random.randint(0, 100),
            "c": random.randint(0, 100),
        }

    def get_extra_state(self) -> dict[str, torch.Tensor]:
        return {
            key: torch.tensor(val)
            for key, val in self.state.items()
        }

    def set_extra_state(self, state: dict[str, torch.Tensor]):
        self.state = {
            key: val.item()
            for key, val in state.items()
        }


class Qwen3ForCausalLM(Qwen3ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.mymodule = MyModule()


def main():
    config = Qwen3Config(
        vocab_size=10,
        hidden_size=32,
        intermediate_size=32,
        num_attention_heads=1,
        num_key_value_heads=1,
        num_hidden_layers=1,
    )
    model = Qwen3ForCausalLM(config)
    model.save_pretrained("./Qwen3ForCausalLM")
    model_ = Qwen3ForCausalLM.from_pretrained("./Qwen3ForCausalLM")
    assert model.mymodule.state == model_.mymodule.state


if __name__ == "__main__":
    main()

Traceback

Traceback (most recent call last):
  File ".../state_dict.py", line 51, in <module>
    main()
  File ".../state_dict.py", line 45, in main
    model.save_pretrained("./Qwen3QuantForCausalLM")
  File ".../transformers/modeling_utils.py", line 3321, in save_pretrained
    state_dict_split = split_torch_state_dict_into_shards(
  File ".../huggingface_hub/serialization/_torch.py", line 351, in split_torch_state_dict_into_shards
    return split_state_dict_into_shards_factory(
  File ".../huggingface_hub/serialization/_base.py", line 105, in split_state_dict_into_shards_factory
    storage_id = get_storage_id(tensor)  # type: ignore[invalid-argument-type]
  File ".../huggingface_hub/serialization/_torch.py", line 737, in get_torch_storage_id
    if tensor.device.type == "meta":
AttributeError: 'dict' object has no attribute 'device'
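The traceback shows the dict returned by get_extra_state reaches huggingface_hub's get_torch_storage_id, which assumes every state-dict value is a torch.Tensor. A possible workaround sketch (class name hypothetical, not an official API) is to pack the dict into a single byte tensor; note that per Case 2 below a tensor-valued extra_state still fails on the transformers load path, so this only demonstrates the tensor-only constraint in plain state_dict terms:

```python
import io
import torch


class PackedExtraState(torch.nn.Module):
    """Hypothetical sketch: pack a dict extra_state into a single tensor."""

    def __init__(self):
        super().__init__()
        self.state = {"a": 1, "b": 2}

    def get_extra_state(self) -> torch.Tensor:
        # Serialize the dict to bytes, then expose it as a uint8 tensor so
        # shard splitting only ever sees torch.Tensor values.
        buf = io.BytesIO()
        torch.save(self.state, buf)
        return torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8)

    def set_extra_state(self, state: torch.Tensor):
        # Reverse the packing: uint8 tensor -> bytes -> dict.
        self.state = torch.load(io.BytesIO(state.numpy().tobytes()))


m = PackedExtraState()
m2 = PackedExtraState()
m2.state = {}
m2.load_state_dict(m.state_dict())
assert m2.state == {"a": 1, "b": 2}
```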

Case 2. extra_state: torch.Tensor

An extra_state of type torch.Tensor causes PreTrainedModel.from_pretrained to fail, even though tensor-valued extra_state was expected to be supported by #38155.

import random
import torch
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.state = [random.randint(0, 100) for _ in range(3)]

    def get_extra_state(self) -> torch.Tensor:
        return torch.tensor(self.state)

    def set_extra_state(self, state: torch.Tensor):
        self.state = state.tolist()


class Qwen3ForCausalLM(Qwen3ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.mymodule = MyModule()


def main():
    config = Qwen3Config(
        vocab_size=10,
        hidden_size=32,
        intermediate_size=32,
        num_attention_heads=1,
        num_key_value_heads=1,
        num_hidden_layers=1,
    )
    model = Qwen3ForCausalLM(config)
    model.save_pretrained("./Qwen3ForCausalLM")
    model_ = Qwen3ForCausalLM.from_pretrained("./Qwen3ForCausalLM")
    assert model.mymodule.state == model_.mymodule.state


if __name__ == "__main__":
    main()

Traceback

Traceback (most recent call last):
  File ".../state_dict_2.py", line 41, in <module>
    main()
  File ".../state_dict_2.py", line 36, in main
    model_ = Qwen3ForCausalLM.from_pretrained("./Qwen3ForCausalLM")
  File ".../transformers/modeling_utils.py", line 4072, in from_pretrained
    loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
  File ".../transformers/modeling_utils.py", line 4191, in _load_pretrained_model
    loading_info, disk_offload_index = convert_and_load_state_dict_in_model(
  File ".../transformers/core_model_loading.py", line 1225, in convert_and_load_state_dict_in_model
    set_param_for_module(
  File ".../transformers/core_model_loading.py", line 900, in set_param_for_module
    ref = getattr(module_obj, param_name)
  File ".../torch/nn/modules/module.py", line 1965, in __getattr__
    raise AttributeError(
AttributeError: 'MyModule' object has no attribute '_extra_state'. Did you mean: 'get_extra_state'?

Expected behavior

  • [bug] Support extra_state of type torch.Tensor
    I think it was meant to be supported by #38155, but it doesn't work in the latest transformers.
  • [feature request] Support extra_state of arbitrary type by delegating to nn.Module.set_extra_state
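For reference, plain torch.nn.Module already round-trips arbitrary extra_state through the "_extra_state" state_dict key, which is the mechanism the feature request asks save_pretrained/from_pretrained to delegate to. A minimal sketch (class name hypothetical):

```python
import torch


class WithExtraState(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.meta = {"scale": 0.5, "version": 3}

    def get_extra_state(self):
        # nn.Module.state_dict() stores this under the "_extra_state" key;
        # the return value may be any picklable object.
        return self.meta

    def set_extra_state(self, state):
        # nn.Module.load_state_dict() pops "_extra_state" and passes it here.
        self.meta = state


src = WithExtraState()
sd = src.state_dict()
assert "_extra_state" in sd

dst = WithExtraState()
dst.meta = {}
dst.load_state_dict(sd)
assert dst.meta == {"scale": 0.5, "version": 3}
```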
