From c409a0d56c734b841237f9f9921ffb0146ef1666 Mon Sep 17 00:00:00 2001 From: Alex Vasile <48962821+Alex-Vasile@users.noreply.github.com> Date: Mon, 20 Oct 2025 22:09:59 +0000 Subject: [PATCH 1/3] Handling for transfer_n_pin and trivially_replicable Signed-off-by: Alex Vasile <48962821+Alex-Vasile@users.noreply.github.com> --- sharktank/sharktank/ops/_registry.py | 13 ++++- sharktank/sharktank/ops/sharded_impls.py | 12 ++++ sharktank/sharktank/ops/utils.py | 4 ++ sharktank/tests/ops/dispatch_test.py | 70 ++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 sharktank/tests/ops/dispatch_test.py diff --git a/sharktank/sharktank/ops/_registry.py b/sharktank/sharktank/ops/_registry.py index aacbbf3ceaf..948ae198854 100644 --- a/sharktank/sharktank/ops/_registry.py +++ b/sharktank/sharktank/ops/_registry.py @@ -270,7 +270,18 @@ def __call__(self, *args, **kwargs): selected_override, *results = trampoline(self, *args, **kwargs) if _ENABLE_TEST_LAST_OP_DISPATCH: global _TEST_LAST_OP_DISPATCH - _TEST_LAST_OP_DISPATCH = selected_override + + if hasattr(selected_override, "_trivially_replicable_wrapper"): + # For trivially replicable wrappers, don't set _TEST_LAST_OP_DISPATCH + # the inner calls (which occured already)will set it to the actual op. + # NOTE: This assumes that all shards called the same op. + pass + else: + # For wrappers such as `transfer_n_pin`, we set _TEST_LAST_OP_DISPATCH to the original op (not the wrapper). + _TEST_LAST_OP_DISPATCH = getattr( + selected_override, "_unwrapped", selected_override + ) + arity = len(results) if arity == 1: return results[0] diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 4ac5c950098..5361dedc4e3 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -109,6 +109,18 @@ def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]): return res func_wrapper._impl_name = getattr(f, "_impl_name", None) # For impl selection + + if hasattr(f, "_trivially_replicable_wrapper"): + # If wrapping a trivially replicable function, we do not know what underlying op will be called on each shard, + # since we don't dispatch based on shards. + # Instead label this wrapper as a trivially replicable wrapper so that + # _TEST_LAST_OP_DISPATCH tracking can handle it correctly. + # _TEST_LAST_OP_DISPATCH will not update for this wrapper, but instead allow the last inner call to set it. + func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper + else: + # We know the underlying op will be called, set __unwrapped__ to the original op + # so that _TEST_LAST_OP_DISPATCH tracking can handle it correctly. + func_wrapper._unwrapped = f return func_wrapper def wrap_override(signature_dispatcher_override): diff --git a/sharktank/sharktank/ops/utils.py b/sharktank/sharktank/ops/utils.py index bb75305dd9c..1b7d58a823c 100644 --- a/sharktank/sharktank/ops/utils.py +++ b/sharktank/sharktank/ops/utils.py @@ -152,6 +152,10 @@ def fn(a: torch.Tensor) -> torch.Tensor: def wrapper(*args, **kwargs): return call_trivially_replicable(fn, args, kwargs) + # Mark the trivially replicable wrapper as such so that _TEST_LAST_OP_DISPATCH tracking handles it correctly. + # There is no way for us to know which op will be called on each shard, so we cannot set _unwrapped the way we can for `transfer_n_pin`. 
+ wrapper._trivially_replicable_wrapper = True + return wrapper diff --git a/sharktank/tests/ops/dispatch_test.py b/sharktank/tests/ops/dispatch_test.py new file mode 100644 index 00000000000..881e79f3d0f --- /dev/null +++ b/sharktank/tests/ops/dispatch_test.py @@ -0,0 +1,70 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Callable +import unittest +import torch + +from parameterized import parameterized + +from sharktank import ops +from sharktank.ops.sharded_impls import zeros_replicated +from sharktank.types import DefaultPrimitiveTensor +from sharktank.ops.default_impls import abs_default, cos_default, zeros_default +from sharktank.types.tensors import InferenceTensor, ReplicatedTensor +from sharktank.utils.testing import assert_tensor_close + + +class DispatchTest(unittest.TestCase): + def setUp(self): + ops._registry._test_enable_last_op_dispatch(True) + + def tearDown(self): + ops._registry._test_enable_last_op_dispatch(False) + + def make_tensor(self, shards: int) -> InferenceTensor: + tensor = torch.tensor([1.0, 2.0, 3.0]) + if shards == 1: + return DefaultPrimitiveTensor(data=tensor) + else: + return ReplicatedTensor(ts=tensor, shard_count=shards) + + @parameterized.expand( + [ + (None, zeros_default), + ((2, 3), zeros_replicated._unwrapped), + ] + ) + def test_non_trivially_replicable_op( + self, devices: tuple[int, ...] | None, expected_dispatch: Callable + ): + ops.zeros(1, devices=devices) + + last_dispatch_after_zeros = ops._registry._test_get_last_op_dispatch() + self.assertIs(last_dispatch_after_zeros, expected_dispatch) + + @parameterized.expand([(1,), (2,)]) + def test_trivially_replicable_op(self, shard_count: int): + self.make_tensor(shard_count).abs() + + last_dispatch = ops._registry._test_get_last_op_dispatch() + self.assertIs(last_dispatch, abs_default) + + @parameterized.expand([(1,), (2,)]) + def test_multiple_dispatch(self, shard_count: int): + tensor = self.make_tensor(shard_count) + + tensor.abs() + last_dispatch_after_abs = ops._registry._test_get_last_op_dispatch() + self.assertIs(last_dispatch_after_abs, abs_default) + + tensor.cos() + last_dispatch_after_cos = ops._registry._test_get_last_op_dispatch() + self.assertIs(last_dispatch_after_cos, cos_default) + + +if __name__ == "__main__": + unittest.main() From bab0d835abbfb6a303f2cc034ec4e740f7b2cf56 Mon Sep 17 00:00:00 2001 From: Alex Vasile <48962821+Alex-Vasile@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:33:38 +0000 Subject: [PATCH 2/3] Fix leftover name Signed-off-by: Alex Vasile <48962821+Alex-Vasile@users.noreply.github.com> --- sharktank/sharktank/ops/sharded_impls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 5361dedc4e3..c51ad07d216 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -118,7 +118,7 @@ def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]): # _TEST_LAST_OP_DISPATCH will not update for this wrapper, but instead allow the last inner call to set it. 
func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper else: - # We know the underlying op will be called, set __unwrapped__ to the original op + # We know the underlying op will be called, set _unwrapped to the original op # so that _TEST_LAST_OP_DISPATCH tracking can handle it correctly. func_wrapper._unwrapped = f return func_wrapper From 1c3c8b21a255e19efb964b89b191aa7907c08b4a Mon Sep 17 00:00:00 2001 From: Alex Vasile <48962821+Alex-Vasile@users.noreply.github.com> Date: Tue, 21 Oct 2025 16:19:47 +0000 Subject: [PATCH 3/3] Add helper function and cleanup tests Signed-off-by: Alex Vasile <48962821+Alex-Vasile@users.noreply.github.com> --- sharktank/sharktank/ops/_registry.py | 14 ++- sharktank/sharktank/ops/sharded_impls.py | 110 ++++++++++++----------- sharktank/tests/ops/dispatch_test.py | 54 +++++++---- 3 files changed, 106 insertions(+), 72 deletions(-) diff --git a/sharktank/sharktank/ops/_registry.py b/sharktank/sharktank/ops/_registry.py index 948ae198854..b19344fc031 100644 --- a/sharktank/sharktank/ops/_registry.py +++ b/sharktank/sharktank/ops/_registry.py @@ -35,6 +35,7 @@ "SignatureDispatcher", "BoolTypeExpr", "get_all_registered_ops", + "unwrap_if_possible", ] _TargetOverride = collections.namedtuple( @@ -60,6 +61,15 @@ def get_all_registered_ops() -> dict[str, "SignatureDispatcher"]: _TEST_LAST_OP_DISPATCH = None +def unwrap_if_possible(op: Callable) -> Callable: + """ + If the op is unwrapped, return it unchanges. + If the wrapped op is a specific override (e.g. abs_default) and is wrapped, then return the original op. + If the wrapped op is a trivially replicable wrapper, then the wrapper is returned since we cannot know which op will be called on each shard. + """ + return getattr(op, "_unwrapped", op) + + def _test_enable_last_op_dispatch(en: bool = True): global _TEST_LAST_OP_DISPATCH global _ENABLE_TEST_LAST_OP_DISPATCH @@ -278,9 +288,7 @@ def __call__(self, *args, **kwargs): pass else: # For wrappers such as `transfer_n_pin`, we set _TEST_LAST_OP_DISPATCH to the original op (not the wrapper). - _TEST_LAST_OP_DISPATCH = getattr( - selected_override, "_unwrapped", selected_override - ) + _TEST_LAST_OP_DISPATCH = unwrap_if_possible(selected_override) arity = len(results) if arity == 1: diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index c51ad07d216..058f39f9812 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -40,6 +40,7 @@ BoolTypeExpr, IsOfType, SignatureDispatcher, + unwrap_if_possible, ) from .shape import ( broadcast_dims, @@ -65,64 +66,65 @@ def assert_on_same_devices(*tensors: Tuple[ShardedTensor]) -> None: raise ValueError("All tensors must be placed on the same devices.") -def sharded_wrap_override(): - def transfer_n_pin(f): +def transfer_n_pin(f): + """ + Wrapper for each NON-TRANSFERRING op defined in this file. + """ + + def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]): """ - Wrapper for each NON-TRANSFERRING op defined in this file. + Wraps each NON-TRANSFERRING operation, f, to ensure that all incoming tensors are on the same device and that the result has the devices correctly labelled. + + If no ShardedTensors are present in the input, then no changes are made to input/output. 
""" + sharded_tensors = [] + for value in itertools.chain(args, kwargs.values()): + if isinstance(value, ShardedTensor): + sharded_tensors.append(value) + continue + if isinstance( + value, + ( + InferenceTensor, + Tensor, + ), + ): + continue + if isinstance(value, Iterable): + for val in value: + if isinstance(val, ShardedTensor): + sharded_tensors.append(val) + + assert_on_same_devices(*sharded_tensors) + res = f(*args, **kwargs) + if len(sharded_tensors) > 0: + if isinstance(res, ShardedTensor): + res = res.clone(devices=sharded_tensors[0].devices) + elif isinstance(res, Iterable) and all( + isinstance(r, ShardedTensor) for r in res + ): + res = type(res)( + r.clone(devices=sharded_tensors[0].devices) for r in res + ) + return res - def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]): - """ - Wraps each NON-TRANSFERRING operation, f, to ensure that all incoming tensors are on the same device and that the result has the devices correctly labelled. - - If no ShardedTensors are present in the input, then no changes are made to input/output. - """ - sharded_tensors = [] - for value in itertools.chain(args, kwargs.values()): - if isinstance(value, ShardedTensor): - sharded_tensors.append(value) - continue - if isinstance( - value, - ( - InferenceTensor, - Tensor, - ), - ): - continue - if isinstance(value, Iterable): - for val in value: - if isinstance(val, ShardedTensor): - sharded_tensors.append(val) - - assert_on_same_devices(*sharded_tensors) - res = f(*args, **kwargs) - if len(sharded_tensors) > 0: - if isinstance(res, ShardedTensor): - res = res.clone(devices=sharded_tensors[0].devices) - elif isinstance(res, Iterable) and all( - isinstance(r, ShardedTensor) for r in res - ): - res = type(res)( - r.clone(devices=sharded_tensors[0].devices) for r in res - ) - return res - - func_wrapper._impl_name = getattr(f, "_impl_name", None) # For impl selection - - if hasattr(f, "_trivially_replicable_wrapper"): - # If wrapping a trivially replicable function, we do not know what underlying op will be called on each shard, - # since we don't dispatch based on shards. - # Instead label this wrapper as a trivially replicable wrapper so that - # _TEST_LAST_OP_DISPATCH tracking can handle it correctly. - # _TEST_LAST_OP_DISPATCH will not update for this wrapper, but instead allow the last inner call to set it. - func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper - else: - # We know the underlying op will be called, set _unwrapped to the original op - # so that _TEST_LAST_OP_DISPATCH tracking can handle it correctly. - func_wrapper._unwrapped = f - return func_wrapper + func_wrapper._impl_name = getattr(f, "_impl_name", None) # For impl selection + if hasattr(f, "_trivially_replicable_wrapper"): + # If wrapping a trivially replicable function, we do not know what underlying op will be called on each shard, + # since we don't dispatch based on shards. + # Instead label this wrapper as a trivially replicable wrapper so that + # _TEST_LAST_OP_DISPATCH tracking can handle it correctly. + # _TEST_LAST_OP_DISPATCH will not update for this wrapper, but instead allow the last inner call to set it. + func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper + else: + # We know which underlying op will be called, set _unwrapped to the original op + # so that _TEST_LAST_OP_DISPATCH tracking can handle it correctly. 
+ func_wrapper._unwrapped = unwrap_if_possible(f) + return func_wrapper + + +def sharded_wrap_override(): def wrap_override(signature_dispatcher_override): """ Wrap [op].override's result so that the transfer_n_pin(f) becomes the target in _TargetOverride rather than f itself. diff --git a/sharktank/tests/ops/dispatch_test.py b/sharktank/tests/ops/dispatch_test.py index 881e79f3d0f..cc63c6f27b3 100644 --- a/sharktank/tests/ops/dispatch_test.py +++ b/sharktank/tests/ops/dispatch_test.py @@ -11,19 +11,23 @@ from parameterized import parameterized from sharktank import ops -from sharktank.ops.sharded_impls import zeros_replicated -from sharktank.types import DefaultPrimitiveTensor +from sharktank.types import DefaultPrimitiveTensor, InferenceTensor, ReplicatedTensor +from sharktank.ops.sharded_impls import zeros_replicated, transfer_n_pin from sharktank.ops.default_impls import abs_default, cos_default, zeros_default -from sharktank.types.tensors import InferenceTensor, ReplicatedTensor -from sharktank.utils.testing import assert_tensor_close +from sharktank.ops._registry import ( + unwrap_if_possible, + _test_enable_last_op_dispatch, + _test_get_last_op_dispatch, +) +from sharktank.ops.utils import trivially_replicable class DispatchTest(unittest.TestCase): def setUp(self): - ops._registry._test_enable_last_op_dispatch(True) + _test_enable_last_op_dispatch(True) def tearDown(self): - ops._registry._test_enable_last_op_dispatch(False) + _test_enable_last_op_dispatch(False) def make_tensor(self, shards: int) -> InferenceTensor: tensor = torch.tensor([1.0, 2.0, 3.0]) @@ -34,8 +38,8 @@ def make_tensor(self, shards: int) -> InferenceTensor: @parameterized.expand( [ - (None, zeros_default), - ((2, 3), zeros_replicated._unwrapped), + (None, unwrap_if_possible(zeros_default)), + ((2, 3), unwrap_if_possible(zeros_replicated)), ] ) def test_non_trivially_replicable_op( @@ -43,27 +47,47 @@ def test_non_trivially_replicable_op( ): ops.zeros(1, devices=devices) - last_dispatch_after_zeros = ops._registry._test_get_last_op_dispatch() + last_dispatch_after_zeros = _test_get_last_op_dispatch() self.assertIs(last_dispatch_after_zeros, expected_dispatch) @parameterized.expand([(1,), (2,)]) def test_trivially_replicable_op(self, shard_count: int): self.make_tensor(shard_count).abs() - last_dispatch = ops._registry._test_get_last_op_dispatch() - self.assertIs(last_dispatch, abs_default) + last_dispatch = _test_get_last_op_dispatch() + self.assertIs(last_dispatch, unwrap_if_possible(abs_default)) @parameterized.expand([(1,), (2,)]) def test_multiple_dispatch(self, shard_count: int): tensor = self.make_tensor(shard_count) tensor.abs() - last_dispatch_after_abs = ops._registry._test_get_last_op_dispatch() - self.assertIs(last_dispatch_after_abs, abs_default) + last_dispatch_after_abs = _test_get_last_op_dispatch() + self.assertIs(last_dispatch_after_abs, unwrap_if_possible(abs_default)) tensor.cos() - last_dispatch_after_cos = ops._registry._test_get_last_op_dispatch() - self.assertIs(last_dispatch_after_cos, cos_default) + last_dispatch_after_cos = _test_get_last_op_dispatch() + self.assertIs(last_dispatch_after_cos, unwrap_if_possible(cos_default)) + + +def f(*args, **kwargs) -> torch.Tensor: + ... 
+ + +class UnwrapIfPossibleTest(unittest.TestCase): + def test_unwrap_no_wrapper(self): + self.assertIs(unwrap_if_possible(f), f) + + @parameterized.expand([transfer_n_pin]) + def test_unwrap_with_wrapper(self, wrapping_fn: Callable): + f_wrapped = wrapping_fn(f) + self.assertIsNot(f_wrapped, f) + self.assertIs(unwrap_if_possible(f_wrapped), f) + + def test_unwrap_with_trivially_replicable_wrapper(self): + f_wrapped = trivially_replicable(f) + self.assertIsNot(unwrap_if_possible(f_wrapped), f) + assert f_wrapped._trivially_replicable_wrapper if __name__ == "__main__":
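Note (illustrative only, not part of the patch series): the sketch below uses hypothetical stand-in wrappers to show the bookkeeping contract that `unwrap_if_possible`, `_unwrapped`, and `_trivially_replicable_wrapper` are assumed to follow in the diffs above. The real `transfer_n_pin` and `trivially_replicable` implementations are the ones in the patches; the names ending in `_like` here are invented for the example.

# Minimal sketch of the dispatch-tracking contract (stand-in functions only).
def unwrap_if_possible(op):
    # A wrapper that knows its underlying op records it in `_unwrapped`;
    # anything else (including trivially replicable wrappers) is returned as-is.
    return getattr(op, "_unwrapped", op)


def abs_default(x):  # hypothetical stand-in for a concrete override
    return abs(x)


def transfer_n_pin_like(f):  # mimics transfer_n_pin's bookkeeping
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    wrapper._unwrapped = getattr(f, "_unwrapped", f)
    return wrapper


def trivially_replicable_like(f):  # mimics trivially_replicable's bookkeeping
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    wrapper._trivially_replicable_wrapper = True
    return wrapper


# An unwrapped override is returned unchanged.
assert unwrap_if_possible(abs_default) is abs_default
# A transfer_n_pin-style wrapper unwraps back to the original op.
assert unwrap_if_possible(transfer_n_pin_like(abs_default)) is abs_default
# The per-shard op behind a trivially replicable wrapper is unknown,
# so the wrapper itself is what unwrap_if_possible returns; the inner
# per-shard calls are what set _TEST_LAST_OP_DISPATCH in that case.
wrapped = trivially_replicable_like(abs_default)
assert unwrap_if_possible(wrapped) is wrapped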