
Incorrect Observer Sharing/Derivation at Conv-ReLU + Residual with Arm Ethos Quantizer #12959

@Ninja91

Description


🐛 Describe the bug

There's a bug in PyTorch 2.0 quantization with the Arm quantizer: observers are incorrectly shared/derived across different operations (add, permute, relu). The most likely cause is that add requires its inputs to share an observer. When that requirement is merged with the observer sharing already set up for relu and permute, the ReLU output ends up sharing an observer with a tensor that can be negative. Because the shared ReLU observer sees negative values, the computed quantization parameters (qparams) are wrong: the zero point is not -128 (zp != -128). This breaks the ReLU fusion, as discussed here, which in turn affects delegation of the model to the Arm Ethos-U55 backend.
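The arithmetic behind the zp != -128 symptom can be sketched as follows. This is a minimal sketch assuming a per-tensor affine int8 activation scheme (qmin = -128, qmax = 127) and the usual min/max-observer formula. `affine_qparams` and `relu_is_foldable` are hypothetical helpers, and the ranges [0, 6] and [-3, 6] are made-up values, not measurements from this model.

```python
# Sketch only: assumes per-tensor affine int8 activations and a min/max-style
# observer formula; this is not code from the Arm quantizer.
QMIN, QMAX = -128, 127


def affine_qparams(min_val: float, max_val: float) -> tuple[float, int]:
    # The quantized range must contain 0.0 so that real zero maps to an integer.
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / (QMAX - QMIN)
    zero_point = QMIN - round(min_val / scale)
    return scale, max(QMIN, min(QMAX, zero_point))


def relu_is_foldable(zero_point: int) -> bool:
    # With shared qparams, integer ReLU is max(q, zero_point). Since every int8
    # value is already >= QMIN, that max is a no-op exactly when zero_point == QMIN,
    # so the ReLU can be folded into the producing op.
    return zero_point == QMIN


# ReLU output observed on its own: only non-negative values.
_, zp_relu = affine_qparams(0.0, 6.0)
# Same tensor, but the shared observer also saw negative values from another tensor.
_, zp_shared = affine_qparams(-3.0, 6.0)

print(zp_relu, relu_is_foldable(zp_relu))      # -128 True
print(zp_shared, relu_is_foldable(zp_shared))  # -43 False
```

When the ReLU range is observed alone, the zero point lands on qmin and the integer ReLU is a no-op. Once the shared observer also sees negative values, the zero point moves up (to -43 in this made-up example), and the ReLU can no longer be dropped.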

To reproduce, quantize each of the two models below and compare the qparams of the ReLU output.

Bug triggered (residual connection):

import torch
import torch.nn as nn

class SampleModel(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, input_dim):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size)
        self.relu = torch.nn.ReLU()
        self.linear = nn.Linear(out_channels, out_channels)

    def forward(self, x):
        # Permute: (N, T, C) -> (N, C, T)
        x = x.permute(0, 2, 1)
        # Conv1D
        x = self.conv1(x)
        # ReLU
        x = self.relu(x)
        # Permute back: (N, C, T) -> (N, T, C)
        x = x.permute(0, 2, 1)
        # Residual connection
        x = x + self.linear(x)
        return x

# Input shape: (1, 10, 3)
model = SampleModel(in_channels=3, out_channels=16, kernel_size=3, input_dim=10)

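from typing import List, Union

from executorch.exir.backend.compile_spec_schema import CompileSpec

# Quantizer and TorchQuantizer are type names from the reporter's environment (not shown).
def _get_ethos_quantizer(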
    compile_spec: List[CompileSpec],
) -> Union[Quantizer, TorchQuantizer]:
    from executorch.backends.arm.quantizer import (
        EthosUQuantizer,
        get_symmetric_quantization_config,
    )

    quantizer = EthosUQuantizer(compile_spec=compile_spec)
    quantization_config = get_symmetric_quantization_config(is_per_channel=False)
    quantizer.set_global(quantization_config)
    return quantizer

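# Apply post-training quantization with the Arm U55 recipe.
# post_train_quantize, calibration_data and get_u55_compile_spec come from the reporter's setup (not shown).
quantized_model = post_train_quantize(
    model,
    calibration_data,
    quantizer=_get_ethos_quantizer(get_u55_compile_spec()),
)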

Bug not triggered when the residual add is replaced by a plain linear connection:

class SampleModel(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, input_dim):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size)
        self.relu = torch.nn.ReLU()
        self.linear = nn.Linear(out_channels, out_channels)

    def forward(self, x):
        # Permute: (N, T, C) -> (N, C, T)
        x = x.permute(0, 2, 1)
        # Conv1D
        x = self.conv1(x)
        # ReLU
        x = self.relu(x)
        # Permute back: (N, C, T) -> (N, T, C)
        x = x.permute(0, 2, 1)
        # Linear connection instead of residual connection
        x = self.linear(x)
        return x

# Input shape: (1, 10, 3)
model = SampleModel(in_channels=3, out_channels=16, kernel_size=3, input_dim=10)
# Apply post-training quantization with ARM U55 recipe
quantized_model = post_train_quantize(model, calibration_data, quantizer=_get_ethos_quantizer(get_u55_compile_spec()))
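The suspected sharing chain can be illustrated with a toy union-find over observer groups. This sketch assumes two sharing rules: a permute output shares its input's observer, and both inputs of add share one observer. The edge names such as `relu_out` and `linear_out` are hypothetical. This is not the Arm quantizer's actual annotation code.

```python
# Toy model of observer sharing: each tensor edge has an observer, and every
# "share" constraint merges two observers into one group (union-find).
# Edge names and sharing rules are illustrative assumptions only.


class ObserverGroups:
    def __init__(self) -> None:
        self.parent: dict[str, str] = {}

    def find(self, x: str) -> str:
        self.parent.setdefault(x, x)
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def share(self, a: str, b: str) -> None:
        self.parent[self.find(a)] = self.find(b)

    def same(self, a: str, b: str) -> bool:
        return self.find(a) == self.find(b)


def build(residual: bool) -> ObserverGroups:
    g = ObserverGroups()
    # Assumed rule: permute only moves data, so its output shares its input's observer.
    g.share("permute2_out", "relu_out")
    if residual:
        # Assumed rule: add requires both of its inputs to share one observer.
        # The second add input is linear(permute2_out), which can be negative.
        g.share("permute2_out", "linear_out")
    return g


for residual in (True, False):
    groups = build(residual)
    print(f"residual={residual}: relu_out shares with linear_out -> "
          f"{groups.same('relu_out', 'linear_out')}")
```

Under these assumed rules, only the residual variant places the ReLU output in the same observer group as the linear output, which can hold negative values. This matches the difference between the two repros.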

It is strange that this hasn't surfaced with other residual networks, but IMO we need to reconsider some of the Shared and Derived quantization specs and related optimizations to fix this bug. This is a blocker; if there's no immediate resolution, a discussion of short-term fixes would also be helpful.
cc: @digantdesai , @3l1, @Tessil

Versions

Python version: 3.12.11+meta (3.12:55fee9c, Jun 03 2025, 15:41:33) [Clang 19.1.2]
Python platform: Linux-5.19.0-0_fbk12_hardened_11583_g0bef9520ca2b-x86_64-with-glibc2.34
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

Metadata

Assignees: No one assigned
Labels: module: arm (Issues related to arm backend), partner: arm (For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm)
Status: Done
Relationships: None yet
Development: No branches or pull requests