Closed
Labels
backend tester (This bug was found by the backend test suite.)
module: coreml (Issues related to Apple's Core ML delegation and code under backends/apple/coreml/)
Description
🐛 Describe the bug
When running torch.diagonal through the Core ML delegate, it sometimes produces incorrect outputs and sometimes crashes the process. It looks like some sort of memory corruption is happening.
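For context (and consistent with the "aliases output" warnings in the log below), torch.diagonal returns a view that shares storage with its input rather than a copy. A minimal sketch of that aliasing, independent of the Core ML delegate:

```python
import torch

x = torch.arange(25).reshape(5, 5).to(torch.float)
d = torch.diagonal(x)

# diagonal() returns a view: d has no storage of its own
print(d._base is x)  # True

# mutations of the base tensor are visible through the view
x[0, 0] = 100.0
print(d[0])  # tensor(100.)
```

If the delegate's output buffer still aliases the input (or the runtime assumes it does not), sharing like this could plausibly produce the zeroed outputs shown below, though that is speculation about the root cause.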
Repro:
import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig, to_edge
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.diagonal(x)

model = Model()
inputs = (
    torch.arange(25).reshape((5, 5)).to(torch.float),
)
print(inputs)

eager_outputs = model(*inputs)
print(f"Eager: {eager_outputs.shape} {eager_outputs}")

ep = torch.export.export(model.eval(), inputs)
print(ep)
print(f"EP: {ep.module()(*inputs)}")

lowered = to_edge_transform_and_lower(
    ep,
    partitioner=[CoreMLPartitioner()],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
).to_executorch()
print(lowered.exported_program())

et_model = _load_for_executorch_from_buffer(lowered.buffer)
et_outputs = et_model([*inputs])[0]
print(et_outputs)

et_outputs - eager_outputs

Example outputs (note the last line showing the difference between the eager reference and the Core ML ET outputs):
WARNING:root:Op aten.diagonal.default was requested for preservation by partitioner. This request is ignored because it aliases output.
WARNING:root:Op aten.diagonal.default was requested for preservation by partitioner. This request is ignored because it aliases output.
(tensor([[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.],
[20., 21., 22., 23., 24.]]),)
Eager: torch.Size([5]) tensor([ 0., 6., 12., 18., 24.])
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[5, 5]"):
# File: /var/folders/90/5w9gk0fn4n3g7fw1bvq8r1_m0000gn/T/ipykernel_68275/3260929381.py:12 in forward, code: return torch.diagonal(x)
diagonal: "f32[5]" = torch.ops.aten.diagonal.default(x); x = None
return (diagonal,)
Graph signature:
# inputs
x: USER_INPUT
# outputs
diagonal: USER_OUTPUT
Range constraints: {}
EP: tensor([ 0., 6., 12., 18., 24.])
Converting PyTorch Frontend ==> MIL Ops:   0%|          | 0/1 [00:00<?, ? ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 18944.46 passes/s]
Running MIL default pipeline: 100%|██████████| 89/89 [00:00<00:00, 9404.27 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 13245.17 passes/s]
[program.cpp:135] InternalConsistency verification requested but not available
[ETCoreMLModelManager.mm:528] Cache Miss: Model with identifier=executorch_51e9ef29-cbbc-4d0f-923d-49194fa7535b_all was not found in the models cache.
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[5, 5]"):
# No stacktrace found for following nodes
lowered_module_0 = self.lowered_module_0
executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x); lowered_module_0 = x = None
getitem: "f32[5]" = executorch_call_delegate[0]; executorch_call_delegate = None
return (getitem,)
Graph signature:
# inputs
x: USER_INPUT
# outputs
getitem: USER_OUTPUT
Range constraints: {}
tensor([0., 0., 0., 0., 0.]) <-- ET outputs
tensor([ 0., -6., -12., -18., -24.]) <-- error
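A possible user-side workaround (untested against this exact commit, so treat it as a sketch) is to break the aliasing by cloning the view before it becomes a graph output:

```python
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        # .clone() copies the diagonal into fresh storage, so the
        # graph output no longer aliases the input tensor
        return torch.diagonal(x).clone()

x = torch.arange(25).reshape(5, 5).to(torch.float)
print(Model()(x))  # tensor([ 0.,  6., 12., 18., 24.])
```

This sidesteps the partitioner's aliasing warning in eager semantics, but the underlying delegate bug would still need a fix.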
Versions
coremltools version 8.3
executorch commit 67b6009 (Jun 14)