Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,16 @@ python_unittest(
"//executorch/exir/tests:models",
],
)

# Unit tests for the iDMA op registrations (idma_load/idma_store/idma_copy/idma_wait).
python_unittest(
    name = "test_idma_ops",
    srcs = [
        "tests/test_idma_ops.py",
    ],
    typing = True,
    deps = [
        "//executorch/backends/cadence/aot:graph_builder",
        # Importing ops_registrations registers the cadence::idma_* ops under test.
        "//executorch/backends/cadence/aot:ops_registrations",
        "//later:lib",
    ],
)
57 changes: 54 additions & 3 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,21 @@
"rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
)

# Load/store with iDMA. These only exist before memory planning.
# Post memory planning, we check that outputs/inputs for the load/store are in
# DTCM and replace idma_load/idma_store with idma_copy.
lib.define("idma_load(Tensor src, int task_num=0, int channel=0) -> Tensor")
lib.define("idma_store(Tensor src, int task_num=0, int channel=0) -> Tensor")

# Non-blocking iDMA copy; pair with idma_wait on the same task_num to
# synchronize on completion.
lib.define("idma_copy(Tensor src, int task_num=0, int channel=0) -> Tensor")
lib.define(
    "idma_copy.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
)
# iDMA wait: synchronizes on the transfer identified by task_num.
lib.define("idma_wait(Tensor src, int task_num=0) -> Tensor")
lib.define("idma_wait.out(Tensor src, int task_num=0, *, Tensor(a!) out) -> Tensor(a!)")

# ------------------------------------ #
# Migrated from custom_ops.yaml #
# ------------------------------------ #
Expand Down Expand Up @@ -983,7 +998,43 @@ def rope_meta(
assert (
len(sin_shape) == 2 and sin_shape[-1] == hd // 2
), f"{sin_shape=} must be [seq, hd/2]"
assert (
pos is None or len(pos.shape) == 1 and pos.shape[0] == seq
), f"{pos.shape} must be [{seq}]"
if pos is not None:
assert (
len(pos.shape) == 1 and pos.shape[0] == seq
), f"{pos.shape} must be [{seq}]"
return input.new_empty(input.shape, dtype=input.dtype)


@register_fake("cadence::idma_copy")
def copy_idma_copy_impl(
    src: torch.Tensor,
    task_num: int = 0,
    channel: int = 0,
) -> torch.Tensor:
    """Fake (meta) kernel for idma_copy.

    The output mirrors src's shape and dtype; task_num and channel select the
    DMA task/channel at runtime and do not influence the output's metadata.
    """
    return src.new_empty(src.shape, dtype=src.dtype)


@register_fake("cadence::idma_wait")
def copy_idma_wait_impl(
    src: torch.Tensor,
    task_num: int = 0,
) -> torch.Tensor:
    """Fake (meta) kernel for idma_wait.

    Output metadata matches src; task_num only identifies the transfer being
    waited on and has no effect on shape or dtype.
    """
    result_shape = tuple(src.shape)
    return src.new_empty(result_shape, dtype=src.dtype)


@register_fake("cadence::idma_load")
def idma_load_impl(
    src: torch.Tensor,
    task_num: int = 0,
    channel: int = 0,
) -> torch.Tensor:
    """Fake (meta) kernel for idma_load: same output metadata as idma_copy."""
    # Identical meta behavior to idma_copy: an empty tensor shaped like src.
    return src.new_empty(src.shape, dtype=src.dtype)


@register_fake("cadence::idma_store")
def idma_store_impl(
    src: torch.Tensor,
    task_num: int = 0,
    channel: int = 0,
) -> torch.Tensor:
    """Fake (meta) kernel for idma_store: same output metadata as idma_copy."""
    # Identical meta behavior to idma_copy: an empty tensor shaped like src.
    return src.new_empty(src.shape, dtype=src.dtype)
31 changes: 31 additions & 0 deletions backends/cadence/aot/tests/test_idma_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# pyre-strict

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch

from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.exir.dialects._ops import ops as exir_ops

from later.unittest import TestCase


class TestIdmaOps(TestCase):
    def test_idma_load_store_wait(self) -> None:
        """Check that the idma load/store/wait ops are registered correctly."""
        builder = GraphBuilder()
        node = builder.placeholder("x", torch.ones(2, 7, dtype=torch.float32))
        # Chain load -> wait -> store -> wait, mirroring a DMA round trip;
        # each call_operator raises if the op failed to register.
        pipeline = (
            exir_ops.edge.cadence.idma_load.default,
            exir_ops.edge.cadence.idma_wait.default,
            exir_ops.edge.cadence.idma_store.default,
            exir_ops.edge.cadence.idma_wait.default,
        )
        for op in pipeline:
            node = builder.call_operator(op=op, args=(node,))
        builder.output([node])
Loading