Skip to content

Commit 67f0962

Browse files
committed
Validate CoreML inputs against EP
1 parent 574e109 commit 67f0962

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

backends/apple/coreml/compiler/coreml_preprocess.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import coremltools as ct
1818
import coremltools.optimize as cto
19+
import torch
20+
from coremltools.proto import FeatureTypes_pb2
1921

2022
from executorch.backends.apple.coreml import executorchcoreml
2123
from executorch.exir.backend.backend_details import (
@@ -31,6 +33,74 @@
3133
from executorch.backends.apple.coreml.compiler.torch_ops import * # noqa: F401, F403
3234

3335

36+
def _ct_dtype_to_torch_dtype(ct_data_type: int) -> torch.dtype:
    """Translate a CoreML multi-array data type into the matching torch dtype.

    Args:
        ct_data_type: A ``FeatureTypes_pb2.ArrayFeatureType`` data-type value.

    Returns:
        The ``torch.dtype`` paired with ``ct_data_type`` in the table below.

    Raises:
        KeyError: If ``ct_data_type`` is not one of the four entries in the
            table (INT32, FLOAT16, FLOAT32, DOUBLE).
    """
    array_type = FeatureTypes_pb2.ArrayFeatureType
    torch_dtype_by_ct_type = {
        array_type.INT32: torch.int32,
        array_type.FLOAT16: torch.float16,
        array_type.FLOAT32: torch.float32,
        array_type.DOUBLE: torch.float64,
    }
    return torch_dtype_by_ct_type[ct_data_type]
44+
45+
46+
def _validate_io(ep_io, ct_io, io_kind):
47+
assert len(ep_io) == len(
48+
ct_io
49+
), f"The ExportedProgram and the CoreML program have different {io_kind}. This is not expected. Please file an issue on ExecuTorch."
50+
51+
for i, (ep_io_val, ct_io_val) in enumerate(zip(ep_io, ct_io)):
52+
ct_name = ct_io_val.name
53+
ep_name = ep_io_val[0]
54+
assert (
55+
ct_name == ep_name
56+
), f"Mismatched {io_kind} at pos {i}: ExportedProgram {io_kind} is {ep_name}, but CoreML input is {ct_name}"
57+
assert (
58+
ct_io_val.type.WhichOneof("Type") == "multiArrayType"
59+
), f"Expected multiArrayType for {ct_name}"
60+
61+
ct_dtype = _ct_dtype_to_torch_dtype(ct_io_val.type.multiArrayType.dataType)
62+
ep_dtype = ep_io_val[1].dtype
63+
assert ct_dtype == ep_dtype, (
64+
f"Mismatched dtype for {io_kind} {ep_name} (pos {i}):\n"
65+
f" ExportedProgram has dtype {ep_dtype}, but CoreML has dtype {ct_dtype}.\n"
66+
f" Sometimes CoreML autocasts inputs/outputs (e.g., int64 → int32).\n"
67+
f" In some cases, this can be fixed by changing the input type on the ExportedProgram before lowering."
68+
)
69+
70+
71+
def _validate_inputs(ep, mlmodel):
    """Validate the ExportedProgram's user inputs against the CoreML inputs.

    Collects ``meta["val"]`` for every placeholder node, orders the values by
    ``ep.graph_signature.user_inputs``, and hands them to ``_validate_io``
    together with the input descriptions from the CoreML model spec.

    Args:
        ep: The ExportedProgram; its ``graph.nodes`` and
            ``graph_signature.user_inputs`` are read.
        mlmodel: The CoreML model; ``get_spec().description.input`` is read.

    Raises:
        AssertionError: If a placeholder is not listed as a user input, or if
            ``_validate_io`` finds a mismatch.
    """
    placeholder_vals = {}
    for node in ep.graph.nodes:
        if node.op != "placeholder":
            continue
        # Every placeholder is expected to be a user input at this point.
        assert node.name in ep.graph_signature.user_inputs
        placeholder_vals[node.name] = node.meta["val"]

    # Pair names with values in signature order so positions line up with CoreML.
    ep_user_inputs = [
        (name, placeholder_vals[name]) for name in ep.graph_signature.user_inputs
    ]

    ct_inputs = mlmodel.get_spec().description.input
    _validate_io(ep_user_inputs, ct_inputs, "user_input")
84+
85+
86+
def _validate_outputs(ep, mlmodel):
87+
for output_node in ep.graph.nodes:
88+
if output_node.op == "output":
89+
break
90+
91+
user_output_to_val = {}
92+
for node in output_node.args[0]:
93+
assert node.name in ep.graph_signature.user_outputs
94+
user_output_to_val[node.name] = node.meta["val"]
95+
96+
ep_user_outputs = []
97+
for uout in ep.graph_signature.user_outputs:
98+
ep_user_outputs.append((uout, user_output_to_val[uout]))
99+
100+
ct_outputs = mlmodel.get_spec().description.output
101+
_validate_io(ep_user_outputs, ct_outputs, "user_output")
102+
103+
34104
class COMPILE_SPEC_KEYS(Enum):
35105
COMPUTE_UNITS = "compute_units"
36106
MODEL_TYPE = "model_type"
@@ -440,6 +510,8 @@ def preprocess(
440510
minimum_deployment_target=minimum_deployment_target,
441511
compute_units=compute_units,
442512
)
513+
_validate_inputs(edge_program, mlmodel)
514+
_validate_outputs(edge_program, mlmodel)
443515

444516
if op_linear_quantizer_config is not None:
445517
logger.warning(

0 commit comments

Comments
 (0)