21 changes: 20 additions & 1 deletion sharktank/sharktank/ops/_registry.py
@@ -35,6 +35,7 @@
"SignatureDispatcher",
"BoolTypeExpr",
"get_all_registered_ops",
"unwrap_if_possible",
]

_TargetOverride = collections.namedtuple(
@@ -60,6 +61,15 @@ def get_all_registered_ops() -> dict[str, "SignatureDispatcher"]:
_TEST_LAST_OP_DISPATCH = None


def unwrap_if_possible(op: Callable) -> Callable:
"""
If the op is not wrapped, return it unchanged.
If the op wraps a specific override (e.g. abs_default), return the original op.
If the op is a trivially replicable wrapper, return the wrapper itself, since we cannot know which op will be called on each shard.
"""
return getattr(op, "_unwrapped", op)


def _test_enable_last_op_dispatch(en: bool = True):
global _TEST_LAST_OP_DISPATCH
global _ENABLE_TEST_LAST_OP_DISPATCH
@@ -270,7 +280,16 @@ def __call__(self, *args, **kwargs):
selected_override, *results = trampoline(self, *args, **kwargs)
if _ENABLE_TEST_LAST_OP_DISPATCH:
global _TEST_LAST_OP_DISPATCH
_TEST_LAST_OP_DISPATCH = selected_override

if hasattr(selected_override, "_trivially_replicable_wrapper"):
# For trivially replicable wrappers, don't set _TEST_LAST_OP_DISPATCH;
# the inner calls (which have already occurred) will set it to the actual op.
# NOTE: This assumes that all shards called the same op.
pass
else:
# For wrappers such as `transfer_n_pin`, we set _TEST_LAST_OP_DISPATCH to the original op (not the wrapper).
_TEST_LAST_OP_DISPATCH = unwrap_if_possible(selected_override)

arity = len(results)
if arity == 1:
return results[0]
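As a quick illustration of the contract above, here is a minimal sketch (not part of the diff; `original_op` and `wrapper` are hypothetical stand-ins, not real overrides):

from sharktank.ops._registry import unwrap_if_possible

def original_op(x):
    return abs(x)

def wrapper(x):
    return original_op(x)

# A wrapper that records its original op via `_unwrapped` resolves back to it.
wrapper._unwrapped = original_op
assert unwrap_if_possible(wrapper) is original_op

# A plain, unwrapped op is returned unchanged.
assert unwrap_if_possible(original_op) is original_op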
96 changes: 55 additions & 41 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -40,6 +40,7 @@
BoolTypeExpr,
IsOfType,
SignatureDispatcher,
unwrap_if_possible,
)
from .shape import (
broadcast_dims,
@@ -65,52 +66,65 @@ def assert_on_same_devices(*tensors: Tuple[ShardedTensor]) -> None:
raise ValueError("All tensors must be placed on the same devices.")


def sharded_wrap_override():
def transfer_n_pin(f):
def transfer_n_pin(f):
"""
Wrapper for each NON-TRANSFERRING op defined in this file.
"""

def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
"""
Wrapper for each NON-TRANSFERRING op defined in this file.
Wraps each NON-TRANSFERRING operation, f, to ensure that all incoming tensors are on the same device and that the result has the devices correctly labelled.

If no ShardedTensors are present in the input, then no changes are made to input/output.
"""
sharded_tensors = []
for value in itertools.chain(args, kwargs.values()):
if isinstance(value, ShardedTensor):
sharded_tensors.append(value)
continue
if isinstance(
value,
(
InferenceTensor,
Tensor,
),
):
continue
if isinstance(value, Iterable):
for val in value:
if isinstance(val, ShardedTensor):
sharded_tensors.append(val)

assert_on_same_devices(*sharded_tensors)
res = f(*args, **kwargs)
if len(sharded_tensors) > 0:
if isinstance(res, ShardedTensor):
res = res.clone(devices=sharded_tensors[0].devices)
elif isinstance(res, Iterable) and all(
isinstance(r, ShardedTensor) for r in res
):
res = type(res)(
r.clone(devices=sharded_tensors[0].devices) for r in res
)
return res

def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
"""
Wraps each NON-TRANSFERRING operation, f, to ensure that all incoming tensors are on the same device and that the result has the devices correctly labelled.

If no ShardedTensors are present in the input, then no changes are made to input/output.
"""
sharded_tensors = []
for value in itertools.chain(args, kwargs.values()):
if isinstance(value, ShardedTensor):
sharded_tensors.append(value)
continue
if isinstance(
value,
(
InferenceTensor,
Tensor,
),
):
continue
if isinstance(value, Iterable):
for val in value:
if isinstance(val, ShardedTensor):
sharded_tensors.append(val)

assert_on_same_devices(*sharded_tensors)
res = f(*args, **kwargs)
if len(sharded_tensors) > 0:
if isinstance(res, ShardedTensor):
res = res.clone(devices=sharded_tensors[0].devices)
elif isinstance(res, Iterable) and all(
isinstance(r, ShardedTensor) for r in res
):
res = type(res)(
r.clone(devices=sharded_tensors[0].devices) for r in res
)
return res
func_wrapper._impl_name = getattr(f, "_impl_name", None) # For impl selection

if hasattr(f, "_trivially_replicable_wrapper"):
# If wrapping a trivially replicable function, we do not know what underlying op will be called on each shard,
# since we don't dispatch based on shards.
# Instead, label this wrapper as a trivially replicable wrapper so that
# _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
# _TEST_LAST_OP_DISPATCH will not be updated for this wrapper; instead, the last inner call will set it.
func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper
else:
# We know which underlying op will be called, so set _unwrapped to the original op
# so that _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
func_wrapper._unwrapped = unwrap_if_possible(f)
return func_wrapper

func_wrapper._impl_name = getattr(f, "_impl_name", None) # For impl selection
return func_wrapper

def sharded_wrap_override():
def wrap_override(signature_dispatcher_override):
"""
Wrap [op].override's result so that transfer_n_pin(f) becomes the target in _TargetOverride rather than f itself.
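How the new attribute propagation is meant to be observed, as a minimal sketch mirroring the tests added below (`my_op` is a hypothetical stand-in, not a real override):

from sharktank.ops.sharded_impls import transfer_n_pin
from sharktank.ops._registry import unwrap_if_possible

def my_op(*args, **kwargs):
    ...

wrapped = transfer_n_pin(my_op)
assert wrapped is not my_op
# transfer_n_pin records the original op in `_unwrapped`, so dispatch
# tracking reports `my_op` rather than the wrapper.
assert unwrap_if_possible(wrapped) is my_op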
4 changes: 4 additions & 0 deletions sharktank/sharktank/ops/utils.py
@@ -152,6 +152,10 @@ def fn(a: torch.Tensor) -> torch.Tensor:
def wrapper(*args, **kwargs):
return call_trivially_replicable(fn, args, kwargs)

# Mark the trivially replicable wrapper as such so that _TEST_LAST_OP_DISPATCH tracking handles it correctly.
# There is no way for us to know which op will be called on each shard, so we cannot set _unwrapped the way we can for `transfer_n_pin`.
wrapper._trivially_replicable_wrapper = True

return wrapper


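And the counterpart for trivially replicable wrappers, again a minimal sketch mirroring the tests below (`my_op` is a hypothetical stand-in):

from sharktank.ops.utils import trivially_replicable
from sharktank.ops._registry import unwrap_if_possible

def my_op(*args, **kwargs):
    ...

wrapped = trivially_replicable(my_op)
# The wrapper is only marked; it cannot be unwrapped because the per-shard
# op is not known until the inner calls dispatch.
assert wrapped._trivially_replicable_wrapper
assert unwrap_if_possible(wrapped) is not my_op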
94 changes: 94 additions & 0 deletions sharktank/tests/ops/dispatch_test.py
@@ -0,0 +1,94 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable
import unittest
import torch

from parameterized import parameterized

from sharktank import ops
from sharktank.types import DefaultPrimitiveTensor, InferenceTensor, ReplicatedTensor
from sharktank.ops.sharded_impls import zeros_replicated, transfer_n_pin
from sharktank.ops.default_impls import abs_default, cos_default, zeros_default
from sharktank.ops._registry import (
unwrap_if_possible,
_test_enable_last_op_dispatch,
_test_get_last_op_dispatch,
)
from sharktank.ops.utils import trivially_replicable


class DispatchTest(unittest.TestCase):
def setUp(self):
_test_enable_last_op_dispatch(True)

def tearDown(self):
_test_enable_last_op_dispatch(False)

def make_tensor(self, shards: int) -> InferenceTensor:
tensor = torch.tensor([1.0, 2.0, 3.0])
if shards == 1:
return DefaultPrimitiveTensor(data=tensor)
else:
return ReplicatedTensor(ts=tensor, shard_count=shards)

@parameterized.expand(
[
(None, unwrap_if_possible(zeros_default)),
((2, 3), unwrap_if_possible(zeros_replicated)),
]
)
def test_non_trivially_replicable_op(
self, devices: tuple[int, ...] | None, expected_dispatch: Callable
):
ops.zeros(1, devices=devices)

last_dispatch_after_zeros = _test_get_last_op_dispatch()
self.assertIs(last_dispatch_after_zeros, expected_dispatch)

@parameterized.expand([(1,), (2,)])
def test_trivially_replicable_op(self, shard_count: int):
self.make_tensor(shard_count).abs()

last_dispatch = _test_get_last_op_dispatch()
self.assertIs(last_dispatch, unwrap_if_possible(abs_default))

@parameterized.expand([(1,), (2,)])
def test_multiple_dispatch(self, shard_count: int):
tensor = self.make_tensor(shard_count)

tensor.abs()
last_dispatch_after_abs = _test_get_last_op_dispatch()
self.assertIs(last_dispatch_after_abs, unwrap_if_possible(abs_default))

tensor.cos()
last_dispatch_after_cos = _test_get_last_op_dispatch()
self.assertIs(last_dispatch_after_cos, unwrap_if_possible(cos_default))


def f(*args, **kwargs) -> torch.Tensor:
...


class UnwrapIfPossibleTest(unittest.TestCase):
def test_unwrap_no_wrapper(self):
self.assertIs(unwrap_if_possible(f), f)

@parameterized.expand([transfer_n_pin])
def test_unwrap_with_wrapper(self, wrapping_fn: Callable):
f_wrapped = wrapping_fn(f)
self.assertIsNot(f_wrapped, f)
self.assertIs(unwrap_if_possible(f_wrapped), f)

def test_unwrap_with_trivially_replicable_wrapper(self):
f_wrapped = trivially_replicable(f)
self.assertIsNot(unwrap_if_possible(f_wrapped), f)
assert f_wrapped._trivially_replicable_wrapper


if __name__ == "__main__":
unittest.main()