If I define a parameter named `layer_scale` in a PyTorch `nn.Module`, as in the following code, a `ValueError` occurs during quantization.
```python
import torch
import torch.nn as nn
from timm.models.layers import DropPath  # assumed source of DropPath


class ConvEncoder(nn.Module):
    """
    Implementation of ConvEncoder with 3*3 and 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    Output: tensor with shape [B, C, H, W]
    """

    def __init__(
        self, dim, hidden_dim=64, kernel_size=3, drop_path=0.0, use_layer_scale=True
    ):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim
        )
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            # Per-channel learnable scale with shape [dim, 1, 1]
            self.layer_scale = nn.Parameter(
                torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True
            )
        self.apply(self._init_weights)  # _init_weights (and forward) defined elsewhere
```
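The error surfaces during post-training quantization with Intel Neural Compressor. A minimal, self-contained sketch of the kind of setup that triggers it (the `TinyScaled` module and the dummy dataset here are illustrative placeholders, not my exact script; I am assuming the standard `PostTrainingQuantConfig` / `quantization.fit` flow):

```python
import torch
import torch.nn as nn
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.data import DataLoader, Datasets


class TinyScaled(nn.Module):
    """Hypothetical minimal module: any multi-element parameter whose
    name contains "scale" trips the same code path as layer_scale."""

    def __init__(self, dim=8):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=1)
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1))

    def forward(self, x):
        return self.layer_scale * self.conv(x)


# Dummy calibration data shaped [B, C, H, W].
dataset = Datasets("pytorch")["dummy"](shape=(4, 8, 16, 16))
dataloader = DataLoader(framework="pytorch", dataset=dataset)

# quantization.fit eventually reaches _get_scale_zeropoint and raises
# the ValueError below when it visits the layer_scale get_attr node.
q_model = quantization.fit(
    TinyScaled().eval(), PostTrainingQuantConfig(), calib_dataloader=dataloader
)
```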
In "adaptor/pytorch" Line 4174
```python
for node in model.graph.nodes:
    if node.op == "get_attr":
        if prefix:
            sub_name = prefix + "--" + node.target
        else:
            sub_name = node.target
        if not hasattr(model, node.target):
            continue
        if "scale" in node.target:  #### This condition is not suitable
            tune_cfg["get_attr"][sub_name] = float(getattr(model, node.target))
        elif "zero_point" in node.target:
            tune_cfg["get_attr"][sub_name] = int(getattr(model, node.target))
        else:
            pass
```
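The substring check matches any attribute whose name contains "scale", and `float()` only accepts single-element tensors, which is easy to confirm in isolation:

```python
import torch

float(torch.tensor(0.0355))   # OK: single element -> 0.0355
float(torch.ones(48, 1, 1))   # ValueError: only one element tensors can be
                              # converted to Python scalars
```

So any user-defined parameter whose name contains "scale" and holds more than one element fails on this line. The traceback: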
File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/utils/utility.py", line 347, in fi
res = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 3658, in quantize
self._get_scale_zeropoint(q_model._model, q_model.q_config)
File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4217, in _get_scale_zeropoint
self._get_sub_module_scale_zeropoint(model, tune_cfg)
File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4197, in _get_sub_module_scale_zeropoint
self._get_module_scale_zeropoint(module, tune_cfg, op_name)
File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4175, in _get_module_scale_zeropoint
tune_cfg["get_attr"][sub_name] = float(getattr(model, node.target))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: only one element tensors can be converted to Python scalars
When I print `node.target` and the corresponding tensor values, `layer_scale` is collected as if it were a quantization scale:
```
feature_extractor.patch_embed--0_input_scale_0 tensor(0.0355)
feature_extractor.network.0.0--dwconv_input_scale_0 tensor(0.0338)
feature_extractor.network.0.0--pwconv2_input_scale_0 tensor(0.2263)
feature_extractor.network.0.0--layer_scale Parameter containing:
tensor([[[ 0.0336]],
[[ 0.0153]],
[[ 0.0214]],
[[ 0.0068]],
[[ 0.0229]],
[[ 0.0136]],
[[ 0.0491]],
[[ 0.0202]],
[[ 0.0420]],
[[ 0.0495]],
[[ 0.0060]],
[[ 0.0225]],
[[ 0.0311]],
[[ 0.0303]],
[[ 0.0556]],
[[ 0.0290]],
[[ 0.0222]],
[[ 0.0153]],
[[ 0.0332]],
[[ 0.0667]],
[[ 0.0168]],
[[ 0.0416]],
[[ 0.0258]],
[[ 0.0200]],
[[ 0.0259]],
[[ 0.0044]],
[[ 0.0514]],
[[ 0.0190]],
[[ 0.0545]],
[[ 0.0119]],
[[ 0.0220]],
[[ 0.0481]],
[[ 0.0115]],
[[ 0.0707]],
[[ 0.0299]],
[[ 0.0105]],
[[ 0.0266]],
[[ 0.0156]],
[[ 0.0380]],
[[ 0.0160]],
[[ 0.0521]],
[[ 0.0094]],
[[-0.0133]],
[[ 0.0585]],
[[ 0.0216]],
[[ 0.0102]],
[[ 0.0297]],
[[ 0.0104]]], requires_grad=True)
```
Modifying the conditional as below works around the problem, but it does not seem like a robust fix:

```python
if "scale" in node.target and "layer_scale" not in node.target:
```