ValueError occurs when parameter "layer_scale" is used in torch #2103

@nalnez13

Description

If I define a parameter named "layer_scale" in a PyTorch nn.Module, as shown in the following code, a ValueError occurs.

import torch
import torch.nn as nn
# DropPath import assumed (the report does not show it); e.g. from timm:
from timm.models.layers import DropPath


class ConvEncoder(nn.Module):
    """
    Implementation of ConvEncoder with 3*3 and 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    Output: tensor with shape [B, C, H, W]
    """

    def __init__(
        self, dim, hidden_dim=64, kernel_size=3, drop_path=0.0, use_layer_scale=True
    ):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim
        )
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True
            )
        self.apply(self._init_weights)  # weight-init helper omitted from the report
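
For context, the forward pass is omitted above; a sketch of the usual ConvEncoder pattern (assumed here, not copied from the report) shows that layer_scale is a learnable per-channel multiplier, unrelated to any quantization scale:

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.use_layer_scale:
            # layer_scale has shape (dim, 1, 1) and broadcasts against (B, dim, H, W)
            x = shortcut + self.drop_path(self.layer_scale * x)
        else:
            x = shortcut + self.drop_path(x)
        return x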

In "adaptor/pytorch" Line 4174

        for node in model.graph.nodes:
            if node.op == "get_attr":
                if prefix:
                    sub_name = prefix + "--" + node.target
                else:
                    sub_name = node.target
                if not hasattr(model, node.target):
                    continue
                if "scale" in node.target: #### This condition is not suitable
                    tune_cfg["get_attr"][sub_name] = float(getattr(model, node.target))
                elif "zero_point" in node.target:
                    tune_cfg["get_attr"][sub_name] = int(getattr(model, node.target))
                else:
                    pass
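
Why does layer_scale appear here at all? After FX quantization, parameters accessed in forward become get_attr nodes, so any parameter whose name contains "scale" is swept up by the substring check. A minimal sketch (hypothetical Block module, not from the report) illustrates this:

import torch
import torch.fx as fx

class Block(torch.nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.layer_scale = torch.nn.Parameter(torch.ones(dim, 1, 1))

    def forward(self, x):
        return self.layer_scale * x

traced = fx.symbolic_trace(Block())
for node in traced.graph.nodes:
    if node.op == "get_attr":
        print(node.target)  # prints "layer_scale" -> matches `"scale" in node.target`

The resulting traceback: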
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/utils/utility.py", line 347, in fi
    res = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 3658, in quantize
    self._get_scale_zeropoint(q_model._model, q_model.q_config)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4217, in _get_scale_zeropoint
    self._get_sub_module_scale_zeropoint(model, tune_cfg)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
    self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
    self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
    self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4197, in _get_sub_module_scale_zeropoint
    self._get_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4175, in _get_module_scale_zeropoint
    tune_cfg["get_attr"][sub_name] = float(getattr(model, node.target))
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: only one element tensors can be converted to Python scalars
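
The failure mechanism is straightforward: float() accepts only single-element tensors, and layer_scale holds dim elements. A minimal illustration:

import torch

float(torch.tensor(0.0355))  # OK: single-element tensor
float(torch.ones(48, 1, 1))  # ValueError: only one element tensors can be converted to Python scalars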

When I print node.target alongside each tensor value, layer_scale is being treated as if it were a quantization scale:

feature_extractor.patch_embed--0_input_scale_0 tensor(0.0355)
feature_extractor.network.0.0--dwconv_input_scale_0 tensor(0.0338)
feature_extractor.network.0.0--pwconv2_input_scale_0 tensor(0.2263)
feature_extractor.network.0.0--layer_scale Parameter containing:
tensor([[[ 0.0336]],

        [[ 0.0153]],

        [[ 0.0214]],

        [[ 0.0068]],

        [[ 0.0229]],

        [[ 0.0136]],

        [[ 0.0491]],

        [[ 0.0202]],

        [[ 0.0420]],

        [[ 0.0495]],

        [[ 0.0060]],

        [[ 0.0225]],

        [[ 0.0311]],

        [[ 0.0303]],

        [[ 0.0556]],

        [[ 0.0290]],

        [[ 0.0222]],

        [[ 0.0153]],

        [[ 0.0332]],

        [[ 0.0667]],

        [[ 0.0168]],

        [[ 0.0416]],

        [[ 0.0258]],

        [[ 0.0200]],

        [[ 0.0259]],

        [[ 0.0044]],

        [[ 0.0514]],

        [[ 0.0190]],

        [[ 0.0545]],

        [[ 0.0119]],

        [[ 0.0220]],

        [[ 0.0481]],

        [[ 0.0115]],

        [[ 0.0707]],

        [[ 0.0299]],

        [[ 0.0105]],

        [[ 0.0266]],

        [[ 0.0156]],

        [[ 0.0380]],

        [[ 0.0160]],

        [[ 0.0521]],

        [[ 0.0094]],

        [[-0.0133]],

        [[ 0.0585]],

        [[ 0.0216]],

        [[ 0.0102]],

        [[ 0.0297]],

        [[ 0.0104]]], requires_grad=True)

Modifying the conditional statement as below works around the problem, but it does not seem like a robust fix:

if "scale" in node.target and "layer_scale" not in node.target:
