
Commit 5c5a20b

hsharma35 authored and facebook-github-bot committed
Add idma Fake operators for AoT. (pytorch#11867)
Summary:
Pull Request resolved: pytorch#11867

Add idma copy and wait op definitions + fake implementations. These can be used for asynchronous copy/wait.

- idma_load: used for `Any -> DTCM` data transfers. (Note: DTCM to DTCM is ok.)
- idma_store: used for `DTCM -> Any` data transfers (similar to idma_load).

Since the above ops explicitly specify the direction of the data transfer, we can use them to plan memory in specific DTCM banks.

Reviewed By: nitish2112

Differential Revision: D77164673
1 parent 6f44a79 commit 5c5a20b
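The registrations in this commit only attach fake (meta) kernels, so at AoT time these ops merely propagate shape and dtype. A minimal sketch (not part of this commit) of exercising them under FakeTensorMode:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Importing the module runs the lib.define()/register_fake() calls below.
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401

with FakeTensorMode():
    src = torch.empty(2, 7, dtype=torch.float32)
    # Each fake impl allocates a fresh tensor with src's shape and
    # dtype; no data is actually transferred ahead of time.
    loaded = torch.ops.cadence.idma_load(src)
    done = torch.ops.cadence.idma_wait(loaded)
    assert done.shape == src.shape and done.dtype == src.dtype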

File tree

3 files changed: +98, -3 lines


backends/cadence/aot/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -518,3 +518,16 @@ python_unittest(
         "//executorch/exir/tests:models",
     ],
 )
+
+python_unittest(
+    name = "test_idma_ops",
+    srcs = [
+        "tests/test_idma_ops.py",
+    ],
+    typing = True,
+    deps = [
+        "//executorch/backends/cadence/aot:graph_builder",
+        "//executorch/backends/cadence/aot:ops_registrations",
+        "//later:lib",
+    ],
+)
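With a Buck-based build (as this TARGETS file implies), the new unit test target could presumably be run directly:

buck2 test //executorch/backends/cadence/aot:test_idma_ops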

backends/cadence/aot/ops_registrations.py

Lines changed: 54 additions & 3 deletions
@@ -174,6 +174,21 @@
     "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+# Load/store with iDMA. These only exist before memory planning.
+# Post memory planning, we check that outputs/inputs for the load/store are in
+# DTCM and replace idma_load/idma_store with idma_copy.
+lib.define("idma_load(Tensor src, int task_num=0, int channel=0) -> Tensor")
+lib.define("idma_store(Tensor src, int task_num=0, int channel=0) -> Tensor")
+
+# Non-blocking iDMA copy.
+lib.define("idma_copy(Tensor src, int task_num=0, int channel=0) -> Tensor")
+lib.define(
+    "idma_copy.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
+)
+# iDMA wait.
+lib.define("idma_wait(Tensor src, int task_num=0) -> Tensor")
+lib.define("idma_wait.out(Tensor src, int task_num=0, *, Tensor(a!) out) -> Tensor(a!)")
+
 # ------------------------------------ #
 #    Migrated from custom_ops.yaml     #
 # ------------------------------------ #

@@ -983,7 +998,43 @@ def rope_meta(
     assert (
         len(sin_shape) == 2 and sin_shape[-1] == hd // 2
     ), f"{sin_shape=} must be [seq, hd/2]"
-    assert (
-        pos is None or len(pos.shape) == 1 and pos.shape[0] == seq
-    ), f"{pos.shape} must be [{seq}]"
+    if pos is not None:
+        assert (
+            len(pos.shape) == 1 and pos.shape[0] == seq
+        ), f"{pos.shape} must be [{seq}]"
     return input.new_empty(input.shape, dtype=input.dtype)
+
+
+@register_fake("cadence::idma_copy")
+def copy_idma_copy_impl(
+    src: torch.Tensor,
+    task_num: int = 0,
+    channel: int = 0,
+) -> torch.Tensor:
+    return src.new_empty(*src.shape, dtype=src.dtype)
+
+
+@register_fake("cadence::idma_wait")
+def copy_idma_wait_impl(
+    src: torch.Tensor,
+    task_num: int = 0,
+) -> torch.Tensor:
+    return src.new_empty(*src.shape, dtype=src.dtype)
+
+
+@register_fake("cadence::idma_load")
+def idma_load_impl(
+    src: torch.Tensor,
+    task_num: int = 0,
+    channel: int = 0,
+) -> torch.Tensor:
+    return copy_idma_copy_impl(src, task_num, channel)
+
+
+@register_fake("cadence::idma_store")
+def idma_store_impl(
+    src: torch.Tensor,
+    task_num: int = 0,
+    channel: int = 0,
+) -> torch.Tensor:
+    return copy_idma_copy_impl(src, task_num, channel)
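Per the comment in the diff above, idma_load/idma_store are meant to be replaced with idma_copy once memory planning has pinned the DTCM-side tensors. That replacement pass is not part of this diff; a rough sketch of the rewrite (with the placement check left as a hypothetical node_is_in_dtcm helper) could look like:

import torch
from torch.fx import GraphModule

def lower_idma_to_copy(gm: GraphModule) -> None:
    # Sketch only: swap load/store targets for idma_copy. All three ops
    # share the (Tensor, int, int) -> Tensor schema, so an in-place
    # target swap is well-formed.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in (
            torch.ops.cadence.idma_load.default,
            torch.ops.cadence.idma_store.default,
        ):
            # Hypothetical: memory planning must already have placed the
            # DTCM-side tensor in a DTCM bank.
            # assert node_is_in_dtcm(node)
            node.target = torch.ops.cadence.idma_copy.default
    gm.graph.lint()
    gm.recompile()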
backends/cadence/aot/tests/test_idma_ops.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# pyre-strict
+
+import executorch.backends.cadence.aot.ops_registrations  # noqa
+import torch
+
+from executorch.backends.cadence.aot.graph_builder import GraphBuilder
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from later.unittest import TestCase
+
+
+class TestIdmaOps(TestCase):
+    async def test_idma_load_store_wait(self) -> None:
+        """Check that the idma load/store/wait ops are registered correctly."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(2, 7, dtype=torch.float32))
+        load = builder.call_operator(
+            op=exir_ops.edge.cadence.idma_load.default, args=(x,)
+        )
+        wait = builder.call_operator(
+            op=exir_ops.edge.cadence.idma_wait.default, args=(load,)
+        )
+        store = builder.call_operator(
+            op=exir_ops.edge.cadence.idma_store.default, args=(wait,)
+        )
+        wait2 = builder.call_operator(
+            op=exir_ops.edge.cadence.idma_wait.default, args=(store,)
+        )
+        builder.output([wait2])
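A natural follow-up assertion (a sketch; it assumes GraphBuilder's get_graph_module(), defined in graph_builder.py) would materialize the graph built by the test above and count the idma nodes:

# Continues the test above: four idma ops were added to the graph.
gm = builder.get_graph_module()
idma_nodes = [
    n
    for n in gm.graph.nodes
    if n.op == "call_function" and "idma" in str(n.target)
]
assert len(idma_nodes) == 4  # load, wait, store, wait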
