Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions backends/apple/coreml/compiler/torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from coremltools.converters.mil.frontend import _utils
from coremltools.converters.mil.frontend.torch.ops import (
_get_inputs,
_get_kwinputs,
NUM_TO_NUMPY_DTYPE,
NUM_TO_TORCH_DTYPE,
split,
to,
transpose,
unbind,
)
Expand All @@ -24,6 +26,7 @@
register_torch_op,
)
from coremltools.converters.mil.mil import types
from executorch.exir.dim_order_utils import get_memory_format


# https://github.com/apple/coremltools/pull/2556
Expand All @@ -44,6 +47,26 @@ def split_copy(context, node):
split(context, node)


@register_torch_op(
    torch_alias=[
        "dim_order_ops::_to_dim_order_copy",
        "dim_order_ops._to_dim_order_copy",
    ],
    override=False,
)
def _to_dim_order_copy(context, node):
    """Lower ``dim_order_ops._to_dim_order_copy`` as a plain ``to`` cast.

    CoreML has no concept of dim_order, so we only accept a dim_order that
    corresponds to the contiguous memory format and then delegate to the
    stock coremltools ``to`` translation for the dtype conversion.
    """
    # The kwarg may be absent; default=[None] makes dim_order come back as
    # None in that case, which we treat as "no reordering requested".
    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
    # Remove the kwarg so the generic `to` translator does not see an
    # argument it doesn't understand. pop with a default: the key does not
    # exist when dim_order was defaulted above.
    node.kwinputs.pop("dim_order", None)

    if dim_order is not None:
        # In CoreML, dim_order.val will be an ndarray, so we convert it to a list
        dims = [int(d) for d in dim_order.val]
        memory_format = get_memory_format(dims)
        assert (
            memory_format == _torch.contiguous_format
        ), "Only contiguous memory format is supported in CoreML"

    to(context, node)


# https://github.com/apple/coremltools/pull/2558
@register_torch_op(
torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
Expand Down
6 changes: 1 addition & 5 deletions examples/apple/coreml/llama/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from executorch.exir import to_edge_transform_and_lower
from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
Expand Down Expand Up @@ -203,10 +203,6 @@ def main() -> None:
edge_manager = to_edge_transform_and_lower(
ep,
partitioner=[partitioner],
compile_config=EdgeCompileConfig(
# TODO: fix lowering when dim_order is enabled
_skip_dim_order=True,
),
)

print("Delegated program")
Expand Down
Loading