RuntimeError during load_state #3101

@tshmak

System Info

- `Accelerate` version: 0.34.2
- Platform: Linux-5.15.0-92-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /mnt/ttsnas/tshmak/WORK/Projects/TTS/FanoOpenVoice/conda_kcgpu1_openvoice/bin/accelerate
- Python version: 3.9.19
- Numpy version: 1.22.0
- PyTorch version (GPU?): 2.3.1+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 251.52 GB
- GPU type: NVIDIA GeForce RTX 2080 Ti
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

#!/usr/bin/env python3

# Model taken from https://github.com/myshell-ai/OpenVoice/blob/f3cf835540572ade1460c8952f39d53e4f7952df/openvoice/models.py#L399

from accelerate import Accelerator
import torch.nn as nn
from torch.nn.utils import weight_norm

class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        spec_channels=513,
        gin_channels=256,
    ):
        super().__init__()
        self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)

class ReferenceEncoder(nn.Module):
    def __init__(self, spec_channels, gin_channels=0, layernorm=True):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)

        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)
        if layernorm:
            self.layernorm = nn.LayerNorm(self.spec_channels)
        else:
            self.layernorm = None

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L

# Test
model = SynthesizerTrn()

accelerator = Accelerator()
model = accelerator.prepare(model)
accelerator.save_state('bug')
accelerator.load_state('bug')

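Saving prints the following warning (this transcript is from an interactive session that used 'test' rather than 'bug' as the output directory):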

>>> accelerator.save_state('test')
Removed shared tensor {'ref_enc.gru.bias_ih_l0', 'ref_enc.gru.bias_hh_l0', 'ref_enc.gru.weight_hh_l0'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading

And the following when loading:

>>> accelerator.load_state('test')
Traceback (most recent call last):
... 
RuntimeError: Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {'ref_enc.gru.weight_ih_l0'}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue.
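
Both messages refer to tensors that share one underlying storage. A possible way to check which entries of a state dict are affected is sketched below. This is a diagnostic sketch only: the helpers `find_shared_storage` and `covers_entire_storage` are hypothetical names written for this report, not part of Accelerate or safetensors, and they only approximate the check that safetensors performs. The `fake_state` dict is a stand-in that imitates parameters carved out of one flat buffer. Whether the GRU weights end up in that layout after `accelerator.prepare` on GPU is an assumption that this report does not confirm.

```python
from collections import defaultdict

import torch


def find_shared_storage(state_dict):
    """Group tensor names by underlying storage; return groups with more than one name."""
    groups = defaultdict(list)
    for name, tensor in state_dict.items():
        # Views into the same buffer report the same storage base pointer.
        groups[tensor.untyped_storage().data_ptr()].append(name)
    return [sorted(names) for names in groups.values() if len(names) > 1]


def covers_entire_storage(tensor):
    """True if the tensor's elements account for every byte of its storage."""
    return tensor.numel() * tensor.element_size() == tensor.untyped_storage().nbytes()


# Hypothetical stand-in for a flattened weight buffer: one allocation,
# with two "parameters" that are views into different parts of it.
flat = torch.zeros(10)
fake_state = {
    "weight_ih": flat[:6].view(2, 3),
    "weight_hh": flat[6:].view(2, 2),
    "independent": torch.zeros(4),
}

print(find_shared_storage(fake_state))
print({name: covers_entire_storage(t) for name, t in fake_state.items()})
```

To inspect the real model, the same helper could be applied to `accelerator.unwrap_model(model).state_dict()` after `accelerator.prepare(model)`. A group in which no tensor covers the entire storage would match the situation described in the error message.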

Expected behavior

The model should save and load without errors or shared-tensor warnings.
