Commit a4dc961

Handling for transfer_n_pin and trivially_replicable

Signed-off-by: Alex Vasile <[email protected]>
1 parent bedc1ac

File tree: 4 files changed, +98 -1 lines

sharktank/sharktank/ops/_registry.py

Lines changed: 12 additions & 1 deletion
@@ -270,7 +270,18 @@ def __call__(self, *args, **kwargs):
         selected_override, *results = trampoline(self, *args, **kwargs)
         if _ENABLE_TEST_LAST_OP_DISPATCH:
             global _TEST_LAST_OP_DISPATCH
-            _TEST_LAST_OP_DISPATCH = selected_override
+
+            if hasattr(selected_override, "_trivially_replicable_wrapper"):
+                # For trivially replicable wrappers, don't set _TEST_LAST_OP_DISPATCH;
+                # the inner calls (which occurred already) will set it to the actual op.
+                # NOTE: This assumes that all shards called the same op.
+                pass
+            else:
+                # For wrappers such as `transfer_n_pin`, we set _TEST_LAST_OP_DISPATCH to the original op (not the wrapper).
+                _TEST_LAST_OP_DISPATCH = getattr(
+                    selected_override, "_unwrapped", selected_override
+                )
+
         arity = len(results)
         if arity == 1:
             return results[0]
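
For context, here is a standalone sketch of the tracking logic above. It is not sharktank's actual dispatcher; record_dispatch, base_op, and wrapped_op are hypothetical names used only to illustrate the _unwrapped / _trivially_replicable_wrapper handshake:

    _TEST_LAST_OP_DISPATCH = None

    def record_dispatch(selected_override):
        # Record which implementation a dispatch selected, unwrapping wrappers.
        global _TEST_LAST_OP_DISPATCH
        if hasattr(selected_override, "_trivially_replicable_wrapper"):
            # The per-shard inner calls already recorded the real op; keep their value.
            pass
        else:
            # Prefer the original op if the override exposes one via _unwrapped.
            _TEST_LAST_OP_DISPATCH = getattr(
                selected_override, "_unwrapped", selected_override
            )

    def base_op(x):
        return x + 1

    def wrapped_op(x):
        return base_op(x)

    wrapped_op._unwrapped = base_op  # the same bookkeeping transfer_n_pin gains in this commit

    record_dispatch(wrapped_op)
    assert _TEST_LAST_OP_DISPATCH is base_op  # the wrapper itself is never recorded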

sharktank/sharktank/ops/sharded_impls.py

Lines changed: 12 additions & 0 deletions
@@ -109,6 +109,18 @@ def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
         return res
 
     func_wrapper._impl_name = getattr(f, "_impl_name", None)  # For impl selection
+
+    if hasattr(f, "_trivially_replicable_wrapper"):
+        # If wrapping a trivially replicable function, we do not know which underlying op will be called on each shard,
+        # since we do not dispatch based on shards.
+        # Instead, label this wrapper as a trivially replicable wrapper so that
+        # _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
+        # _TEST_LAST_OP_DISPATCH will not update for this wrapper, but will instead allow the last inner call to set it.
+        func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper
+    else:
+        # We know the underlying op will be called; set _unwrapped to the original op
+        # so that _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
+        func_wrapper._unwrapped = f
     return func_wrapper
 
 
 def wrap_override(signature_dispatcher_override):
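
As a rough illustration of the bookkeeping added here, the following is a simplified stand-in for the wrapper factory; make_wrapper is a hypothetical name, and the real transfer_n_pin wrapper also moves tensors between devices before calling f:

    import functools

    def make_wrapper(f):
        @functools.wraps(f)
        def func_wrapper(*args, **kwargs):
            # The real wrapper transfers and pins tensor arguments here.
            return f(*args, **kwargs)

        func_wrapper._impl_name = getattr(f, "_impl_name", None)  # For impl selection
        if hasattr(f, "_trivially_replicable_wrapper"):
            # Propagate the marker: only the inner per-shard dispatch knows the real op.
            func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper
        else:
            # The wrapped op is known statically, so expose it for dispatch tracking.
            func_wrapper._unwrapped = f
        return func_wrapper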

sharktank/sharktank/ops/utils.py

Lines changed: 4 additions & 0 deletions
@@ -152,6 +152,10 @@ def fn(a: torch.Tensor) -> torch.Tensor:
     def wrapper(*args, **kwargs):
         return call_trivially_replicable(fn, args, kwargs)
 
+    # Mark the trivially replicable wrapper as such so that _TEST_LAST_OP_DISPATCH tracking handles it correctly.
+    # There is no way for us to know which op will be called on each shard, so we cannot set _unwrapped the way we can for `transfer_n_pin`.
+    wrapper._trivially_replicable_wrapper = True
+
     return wrapper
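
Reduced to a sketch, the decorator-side half of that handshake looks like this; the body is simplified (the real wrapper routes through call_trivially_replicable to invoke fn once per shard):

    def trivially_replicable(fn):
        def wrapper(*args, **kwargs):
            # Real code: call_trivially_replicable(fn, args, kwargs), once per shard.
            return fn(*args, **kwargs)

        # Dispatch tracking checks this marker and leaves _TEST_LAST_OP_DISPATCH
        # to whatever the inner per-shard calls recorded.
        wrapper._trivially_replicable_wrapper = True
        return wrapper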

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# Copyright 2025 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Callable
+import unittest
+import torch
+
+from parameterized import parameterized
+
+from sharktank import ops
+from sharktank.ops.sharded_impls import zeros_replicated
+from sharktank.types import DefaultPrimitiveTensor
+from sharktank.ops.default_impls import abs_default, cos_default, zeros_default
+from sharktank.types.tensors import InferenceTensor, ReplicatedTensor
+from sharktank.utils.testing import assert_tensor_close
+
+
+class DispatchTest(unittest.TestCase):
+    def setUp(self):
+        ops._registry._test_enable_last_op_dispatch(True)
+
+    def tearDown(self):
+        ops._registry._test_enable_last_op_dispatch(False)
+
+    def make_tensor(self, shards: int) -> InferenceTensor:
+        tensor = torch.tensor([1.0, 2.0, 3.0])
+        if shards == 1:
+            return DefaultPrimitiveTensor(data=tensor)
+        else:
+            return ReplicatedTensor(ts=tensor, shard_count=shards)
+
+    @parameterized.expand(
+        [
+            (None, zeros_default),
+            ((2, 3), zeros_replicated._unwrapped),
+        ]
+    )
+    def test_non_trivially_replicable_op(
+        self, devices: tuple[int, ...] | None, expected_dispatch: Callable
+    ):
+        ops.zeros(1, devices=devices)
+
+        last_dispatch_after_zeros = ops._registry._test_get_last_op_dispatch()
+        self.assertIs(last_dispatch_after_zeros, expected_dispatch)
+
+    @parameterized.expand([(1,), (2,)])
+    def test_trivially_replicable_op(self, shard_count: int):
+        self.make_tensor(shard_count).abs()
+
+        last_dispatch = ops._registry._test_get_last_op_dispatch()
+        self.assertIs(last_dispatch, abs_default)
+
+    @parameterized.expand([(1,), (2,)])
+    def test_multiple_dispatch(self, shard_count: int):
+        tensor = self.make_tensor(shard_count)
+
+        tensor.abs()
+        last_dispatch_after_abs = ops._registry._test_get_last_op_dispatch()
+        self.assertIs(last_dispatch_after_abs, abs_default)
+
+        tensor.cos()
+        last_dispatch_after_cos = ops._registry._test_get_last_op_dispatch()
+        self.assertIs(last_dispatch_after_cos, cos_default)
+
+
+if __name__ == "__main__":
+    unittest.main()
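
Since the new test file ends with unittest.main(), it can be run directly or collected by pytest; assuming a hypothetical path such as sharktank/tests/ops/dispatch_test.py:

    python sharktank/tests/ops/dispatch_test.py
    pytest sharktank/tests/ops/dispatch_test.py -v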
