
CoreML ignores add/sub alpha parameter #11687

@GregoryComer

🐛 Describe the bug

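The CoreML backend ignores the `alpha` argument of `aten::add.Tensor` and `aten::sub.Tensor`, which compute `self + alpha * other` and `self - alpha * other` respectively. Whenever `alpha != 1`, the lowered model therefore produces different results from eager mode.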
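For reference, here is a small eager-only sketch of the semantics the delegate should match. It uses only public `torch` calls; the variable names are illustrative. It also shows the error pattern you would expect if `alpha` were silently dropped: the mismatch equals `(alpha - 1) * other`.

```python
import torch

# Reference semantics of alpha in eager PyTorch:
#   aten::add.Tensor(self, other, alpha) -> self + alpha * other
#   aten::sub.Tensor(self, other, alpha) -> self - alpha * other
torch.manual_seed(0)
x = torch.randn(10)
y = torch.randn(10)
alpha = 5

add_out = torch.add(x, y, alpha=alpha)
sub_out = torch.sub(x, y, alpha=alpha)

assert torch.allclose(add_out, x + alpha * y, atol=1e-6)
assert torch.allclose(sub_out, x - alpha * y, atol=1e-6)

# A backend that ignores alpha effectively computes x - y instead,
# so its error relative to eager is (alpha - 1) * y.
alpha_dropped = x - y
assert torch.allclose(alpha_dropped - sub_out, (alpha - 1) * y, atol=1e-5)
```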

Repro:

```python
import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer


class Model(torch.nn.Module):
    def forward(self, x, y):
        # Eager semantics: x - 5 * y
        return torch.sub(x, y, alpha=5)


model = Model()
inputs = (
    torch.randn(10),
    torch.randn(10),
)

# Export, delegate supported ops to CoreML, and serialize to an ExecuTorch program.
lowered = to_edge_transform_and_lower(
    torch.export.export(model, inputs),
    partitioner=[CoreMLPartitioner()],
).to_executorch()

# Run the lowered program and the eager model on the same inputs.
et_model = _load_for_executorch_from_buffer(lowered.buffer)
et_outputs = et_model([*inputs])[0]
eager_outputs = model(*inputs)

# Expected: all zeros (up to float tolerance). Non-zero values show the mismatch.
print(et_outputs - eager_outputs)
```
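One possible workaround until the backend is fixed is to fold `alpha` into an explicit multiply before export. The resulting graph then contains only `mul` and a `sub` with the default `alpha`. This is a sketch under an assumption not verified here: that the CoreML delegate handles `aten::mul.Tensor` and default-alpha `aten::sub.Tensor` correctly. `ModelExplicitAlpha` is a hypothetical name. The block checks only eager equivalence, not the lowered output.

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x, y):
        return torch.sub(x, y, alpha=5)


class ModelExplicitAlpha(torch.nn.Module):
    """Hypothetical workaround: same math, but alpha is applied as an explicit
    multiply, so the exported sub uses its default alpha of 1."""

    def __init__(self, alpha=5):
        super().__init__()
        self.alpha = alpha

    def forward(self, x, y):
        return x - y * self.alpha


torch.manual_seed(0)
inputs = (torch.randn(10), torch.randn(10))
expected = Model()(*inputs)
rewritten = ModelExplicitAlpha()(*inputs)

# Both formulations agree in eager mode.
assert torch.allclose(expected, rewritten, atol=1e-6)
```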

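Example output. This is the element-wise difference between the ExecuTorch and eager results. If `alpha` is dropped, the delegate computes `x - y` instead of `x - 5 * y`, so the difference should equal `4 * y`:

```
tensor([ 1.4324, -0.9128, -0.8488,  0.9937, -0.5696, -0.7926, -4.8031, -0.0652,
         2.5334,  3.6470])
```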

Versions

Nightly

cc @kimishpatel @YifanShenSZ @cymbalrush @metascroy

Labels:
- backend tester: This bug was found by the backend test suite.
- module: coreml: Issues related to Apple's Core ML delegation and code under backends/apple/coreml/

Status: Done