diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index a85cc0ca925..d7bf5f51690 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -518,3 +518,16 @@ python_unittest( "//executorch/exir/tests:models", ], ) + +python_unittest( + name = "test_idma_ops", + srcs = [ + "tests/test_idma_ops.py", + ], + typing = True, + deps = [ + "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/cadence/aot:ops_registrations", + "//later:lib", + ], +) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 4a6edf03c0e..ff7e921741f 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -174,6 +174,21 @@ "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)" ) +# Load/store with iDMA. These only exist before memory planning. +# Post memory planning, we check that outputs/inputs for the load/store are in +# DTCM and replace idma_load/idma_store with idma_copy. +lib.define("idma_load(Tensor src, int task_num=0, int channel=0) -> Tensor") +lib.define("idma_store(Tensor src, int task_num=0, int channel=0) -> Tensor") + +# Non-blocking iDMA copy. +lib.define("idma_copy(Tensor src, int task_num=0, int channel=0) -> Tensor") +lib.define( + "idma_copy.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)" +) +# iDMA wait. +lib.define("idma_wait(Tensor src, int task_num=0) -> Tensor") +lib.define("idma_wait.out(Tensor src, int task_num=0, *, Tensor(a!) out) -> Tensor(a!)") + # ------------------------------------ # # Migrated from custom_ops.yaml # # ------------------------------------ # @@ -983,7 +998,43 @@ def rope_meta( assert ( len(sin_shape) == 2 and sin_shape[-1] == hd // 2 ), f"{sin_shape=} must be [seq, hd/2]" - assert ( - pos is None or len(pos.shape) == 1 and pos.shape[0] == seq - ), f"{pos.shape} must be [{seq}]" + if pos is not None: + assert ( + len(pos.shape) == 1 and pos.shape[0] == seq + ), f"{pos.shape} must be [{seq}]" return input.new_empty(input.shape, dtype=input.dtype) + + +@register_fake("cadence::idma_copy") +def copy_idma_copy_impl( + src: torch.Tensor, + task_num: int = 0, + channel: int = 0, +) -> torch.Tensor: + return src.new_empty(*src.shape, dtype=src.dtype) + + +@register_fake("cadence::idma_wait") +def copy_idma_wait_impl( + src: torch.Tensor, + task_num: int = 0, +) -> torch.Tensor: + return src.new_empty(*src.shape, dtype=src.dtype) + + +@register_fake("cadence::idma_load") +def idma_load_impl( + src: torch.Tensor, + task_num: int = 0, + channel: int = 0, +) -> torch.Tensor: + return copy_idma_copy_impl(src, task_num, channel) + + +@register_fake("cadence::idma_store") +def idma_store_impl( + src: torch.Tensor, + task_num: int = 0, + channel: int = 0, +) -> torch.Tensor: + return copy_idma_copy_impl(src, task_num, channel) diff --git a/backends/cadence/aot/tests/test_idma_ops.py b/backends/cadence/aot/tests/test_idma_ops.py new file mode 100644 index 00000000000..6320bfc482b --- /dev/null +++ b/backends/cadence/aot/tests/test_idma_ops.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# pyre-strict + +import executorch.backends.cadence.aot.ops_registrations # noqa +import torch + +from executorch.backends.cadence.aot.graph_builder import GraphBuilder +from executorch.exir.dialects._ops import ops as exir_ops + +from later.unittest import TestCase + + +class TestIdmaOps(TestCase): + def test_idma_load_store_wait(self) -> None: + """Check that the idma load/store/wait ops are registered correctly.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(2, 7, dtype=torch.float32)) + load = builder.call_operator( + op=exir_ops.edge.cadence.idma_load.default, args=(x,) + ) + wait = builder.call_operator( + op=exir_ops.edge.cadence.idma_wait.default, args=(load,) + ) + store = builder.call_operator( + op=exir_ops.edge.cadence.idma_store.default, args=(wait,) + ) + wait2 = builder.call_operator( + op=exir_ops.edge.cadence.idma_wait.default, args=(store,) + ) + builder.output([wait2])