Commit 19b9d67

Add helper function and cleanup tests

Signed-off-by: Alex Vasile <[email protected]>
1 parent 7e8cede

3 files changed: +106 -72 lines changed

sharktank/sharktank/ops/_registry.py

Lines changed: 11 additions & 3 deletions

@@ -35,6 +35,7 @@
     "SignatureDispatcher",
     "BoolTypeExpr",
     "get_all_registered_ops",
+    "unwrap_if_possible",
 ]
 
 _TargetOverride = collections.namedtuple(
@@ -60,6 +61,15 @@ def get_all_registered_ops() -> dict[str, "SignatureDispatcher"]:
 _TEST_LAST_OP_DISPATCH = None
 
 
+def unwrap_if_possible(op: Callable) -> Callable:
+    """
+    If the op is not wrapped, return it unchanged.
+    If the op wraps a specific override (e.g. abs_default), return the original op.
+    If the op is a trivially replicable wrapper, return the wrapper itself, since we cannot know which op will be called on each shard.
+    """
+    return getattr(op, "_unwrapped", op)
+
+
 def _test_enable_last_op_dispatch(en: bool = True):
     global _TEST_LAST_OP_DISPATCH
     global _ENABLE_TEST_LAST_OP_DISPATCH
@@ -278,9 +288,7 @@ def __call__(self, *args, **kwargs):
             pass
         else:
             # For wrappers such as `transfer_n_pin`, we set _TEST_LAST_OP_DISPATCH to the original op (not the wrapper).
-            _TEST_LAST_OP_DISPATCH = getattr(
-                selected_override, "_unwrapped", selected_override
-            )
+            _TEST_LAST_OP_DISPATCH = unwrap_if_possible(selected_override)
 
         arity = len(results)
         if arity == 1:
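The helper's contract in miniature (a sketch, not sharktank code; wrap and my_op below are hypothetical stand-ins for a recording wrapper like transfer_n_pin and a concrete override like abs_default):

    # Hypothetical wrapper that records the original op, as transfer_n_pin does.
    def wrap(f):
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)
        wrapper._unwrapped = f  # the attribute unwrap_if_possible looks for
        return wrapper

    def unwrap_if_possible(op):
        return getattr(op, "_unwrapped", op)

    def my_op():
        pass

    assert unwrap_if_possible(my_op) is my_op        # not wrapped: returned unchanged
    assert unwrap_if_possible(wrap(my_op)) is my_op  # wrapped: original op returned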

sharktank/sharktank/ops/sharded_impls.py

Lines changed: 56 additions & 54 deletions

@@ -40,6 +40,7 @@
     BoolTypeExpr,
     IsOfType,
     SignatureDispatcher,
+    unwrap_if_possible,
 )
 from .shape import (
     broadcast_dims,
@@ -65,64 +66,65 @@ def assert_on_same_devices(*tensors: Tuple[ShardedTensor]) -> None:
         raise ValueError("All tensors must be placed on the same devices.")
 
 
-def sharded_wrap_override():
-    def transfer_n_pin(f):
+def transfer_n_pin(f):
+    """
+    Wrapper for each NON-TRANSFERRING op defined in this file.
+    """
+
+    def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
         """
-        Wrapper for each NON-TRANSFERRING op defined in this file.
+        Wraps each NON-TRANSFERRING operation, f, to ensure that all incoming tensors are on the same device and that the result has the devices correctly labelled.
+
+        If no ShardedTensors are present in the input, then no changes are made to input/output.
         """
+        sharded_tensors = []
+        for value in itertools.chain(args, kwargs.values()):
+            if isinstance(value, ShardedTensor):
+                sharded_tensors.append(value)
+                continue
+            if isinstance(
+                value,
+                (
+                    InferenceTensor,
+                    Tensor,
+                ),
+            ):
+                continue
+            if isinstance(value, Iterable):
+                for val in value:
+                    if isinstance(val, ShardedTensor):
+                        sharded_tensors.append(val)
+
+        assert_on_same_devices(*sharded_tensors)
+        res = f(*args, **kwargs)
+        if len(sharded_tensors) > 0:
+            if isinstance(res, ShardedTensor):
+                res = res.clone(devices=sharded_tensors[0].devices)
+            elif isinstance(res, Iterable) and all(
+                isinstance(r, ShardedTensor) for r in res
+            ):
+                res = type(res)(
+                    r.clone(devices=sharded_tensors[0].devices) for r in res
+                )
+        return res
 
-        def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
-            """
-            Wraps each NON-TRANSFERRING operation, f, to ensure that all incoming tensors are on the same device and that the result has the devices correctly labelled.
-
-            If no ShardedTensors are present in the input, then no changes are made to input/output.
-            """
-            sharded_tensors = []
-            for value in itertools.chain(args, kwargs.values()):
-                if isinstance(value, ShardedTensor):
-                    sharded_tensors.append(value)
-                    continue
-                if isinstance(
-                    value,
-                    (
-                        InferenceTensor,
-                        torch.Tensor,
-                    ),
-                ):
-                    continue
-                if isinstance(value, Iterable):
-                    for val in value:
-                        if isinstance(val, ShardedTensor):
-                            sharded_tensors.append(val)
-
-            assert_on_same_devices(*sharded_tensors)
-            res = f(*args, **kwargs)
-            if len(sharded_tensors) > 0:
-                if isinstance(res, ShardedTensor):
-                    res = res.clone(devices=sharded_tensors[0].devices)
-                elif isinstance(res, Iterable) and all(
-                    isinstance(r, ShardedTensor) for r in res
-                ):
-                    res = type(res)(
-                        r.clone(devices=sharded_tensors[0].devices) for r in res
-                    )
-            return res
-
-        func_wrapper._impl_name = getattr(f, "_impl_name", None)  # For impl selection
-
-        if hasattr(f, "_trivially_replicable_wrapper"):
-            # If wrapping a trivially replicable function, we do not know what underlying op will be called on each shard,
-            # since we don't dispatch based on shards.
-            # Instead label this wrapper as a trivially replicable wrapper so that
-            # _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
-            # _TEST_LAST_OP_DISPATCH will not update for this wrapper, but instead allow the last inner call to set it.
-            func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper
-        else:
-            # We know the underlying op will be called, set _unwrapped to the original op
-            # so that _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
-            func_wrapper._unwrapped = f
-        return func_wrapper
+    func_wrapper._impl_name = getattr(f, "_impl_name", None)  # For impl selection
 
+    if hasattr(f, "_trivially_replicable_wrapper"):
+        # If wrapping a trivially replicable function, we do not know what underlying op will be called on each shard,
+        # since we don't dispatch based on shards.
+        # Instead label this wrapper as a trivially replicable wrapper so that
+        # _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
+        # _TEST_LAST_OP_DISPATCH will not update for this wrapper, but instead allow the last inner call to set it.
+        func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper
+    else:
+        # We know which underlying op will be called, set _unwrapped to the original op
+        # so that _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
+        func_wrapper._unwrapped = unwrap_if_possible(f)
+    return func_wrapper
+
+
+def sharded_wrap_override():
     def wrap_override(signature_dispatcher_override):
         """
         Wrap [op].override's result so that the transfer_n_pin(f) becomes the target in _TargetOverride rather than f itself.
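The pinning behaviour that func_wrapper implements, reduced to a self-contained sketch (Sharded and pin_devices are illustrative stand-ins for ShardedTensor and transfer_n_pin, which additionally handle kwargs, plain tensors, and iterables of results):

    from dataclasses import dataclass

    @dataclass
    class Sharded:
        shards: list
        devices: tuple

    def pin_devices(f):
        def wrapper(*args):
            sharded = [a for a in args if isinstance(a, Sharded)]
            # Mirrors assert_on_same_devices: all sharded inputs must agree.
            assert all(s.devices == sharded[0].devices for s in sharded)
            res = f(*args)
            if sharded and isinstance(res, Sharded):
                # Re-label the result with the inputs' device placement.
                res = Sharded(res.shards, sharded[0].devices)
            return res
        return wrapper

    @pin_devices
    def add(a, b):
        return Sharded([x + y for x, y in zip(a.shards, b.shards)], ())

    out = add(Sharded([1, 2], (0, 1)), Sharded([3, 4], (0, 1)))
    assert out.devices == (0, 1)  # result pinned to the inputs' devices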

sharktank/tests/ops/dispatch_test.py

Lines changed: 39 additions & 15 deletions

@@ -11,19 +11,23 @@
 from parameterized import parameterized
 
 from sharktank import ops
-from sharktank.ops.sharded_impls import zeros_replicated
-from sharktank.types import DefaultPrimitiveTensor
+from sharktank.types import DefaultPrimitiveTensor, InferenceTensor, ReplicatedTensor
+from sharktank.ops.sharded_impls import zeros_replicated, transfer_n_pin
 from sharktank.ops.default_impls import abs_default, cos_default, zeros_default
-from sharktank.types.tensors import InferenceTensor, ReplicatedTensor
-from sharktank.utils.testing import assert_tensor_close
+from sharktank.ops._registry import (
+    unwrap_if_possible,
+    _test_enable_last_op_dispatch,
+    _test_get_last_op_dispatch,
+)
+from sharktank.ops.utils import trivially_replicable
 
 
 class DispatchTest(unittest.TestCase):
     def setUp(self):
-        ops._registry._test_enable_last_op_dispatch(True)
+        _test_enable_last_op_dispatch(True)
 
     def tearDown(self):
-        ops._registry._test_enable_last_op_dispatch(False)
+        _test_enable_last_op_dispatch(False)
 
     def make_tensor(self, shards: int) -> InferenceTensor:
         tensor = torch.tensor([1.0, 2.0, 3.0])
@@ -34,36 +38,56 @@ def make_tensor(self, shards: int) -> InferenceTensor:
 
     @parameterized.expand(
         [
-            (None, zeros_default),
-            ((2, 3), zeros_replicated._unwrapped),
+            (None, unwrap_if_possible(zeros_default)),
+            ((2, 3), unwrap_if_possible(zeros_replicated)),
         ]
     )
     def test_non_trivially_replicable_op(
         self, devices: tuple[int, ...] | None, expected_dispatch: Callable
     ):
         ops.zeros(1, devices=devices)
 
-        last_dispatch_after_zeros = ops._registry._test_get_last_op_dispatch()
+        last_dispatch_after_zeros = _test_get_last_op_dispatch()
         self.assertIs(last_dispatch_after_zeros, expected_dispatch)
 
     @parameterized.expand([(1,), (2,)])
     def test_trivially_replicable_op(self, shard_count: int):
         self.make_tensor(shard_count).abs()
 
-        last_dispatch = ops._registry._test_get_last_op_dispatch()
-        self.assertIs(last_dispatch, abs_default)
+        last_dispatch = _test_get_last_op_dispatch()
+        self.assertIs(last_dispatch, unwrap_if_possible(abs_default))
 
     @parameterized.expand([(1,), (2,)])
     def test_multiple_dispatch(self, shard_count: int):
         tensor = self.make_tensor(shard_count)
 
         tensor.abs()
-        last_dispatch_after_abs = ops._registry._test_get_last_op_dispatch()
-        self.assertIs(last_dispatch_after_abs, abs_default)
+        last_dispatch_after_abs = _test_get_last_op_dispatch()
+        self.assertIs(last_dispatch_after_abs, unwrap_if_possible(abs_default))
 
         tensor.cos()
-        last_dispatch_after_cos = ops._registry._test_get_last_op_dispatch()
-        self.assertIs(last_dispatch_after_cos, cos_default)
+        last_dispatch_after_cos = _test_get_last_op_dispatch()
+        self.assertIs(last_dispatch_after_cos, unwrap_if_possible(cos_default))
+
+
+def f(*args, **kwargs) -> torch.Tensor:
+    ...
+
+
+class UnwrapIfPossibleTest(unittest.TestCase):
+    def test_unwrap_no_wrapper(self):
+        self.assertIs(unwrap_if_possible(f), f)
+
+    @parameterized.expand([transfer_n_pin])
+    def test_unwrap_with_wrapper(self, wrapping_fn: Callable):
+        f_wrapped = wrapping_fn(f)
+        self.assertIsNot(f_wrapped, f)
+        self.assertIs(unwrap_if_possible(f_wrapped), f)
+
+    def test_unwrap_with_trivially_replicable_wrapper(self):
+        f_wrapped = trivially_replicable(f)
+        self.assertIsNot(unwrap_if_possible(f_wrapped), f)
+        assert f_wrapped._trivially_replicable_wrapper
 
 
 if __name__ == "__main__":
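Why the trivially replicable case stays wrapped, as one more hedged sketch (mark_trivially_replicable is a hypothetical stand-in for sharktank.ops.utils.trivially_replicable): such a wrapper carries a flag instead of a single _unwrapped op, because a different override may be dispatched for each shard, so the getattr fallback in unwrap_if_possible returns the wrapper itself.

    def unwrap_if_possible(op):
        return getattr(op, "_unwrapped", op)

    def mark_trivially_replicable(f):
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)  # the real decorator invokes f per shard
        # Flag the wrapper rather than recording one original op.
        wrapper._trivially_replicable_wrapper = True
        return wrapper

    def g():
        pass

    g_wrapped = mark_trivially_replicable(g)
    assert unwrap_if_possible(g_wrapped) is g_wrapped  # returned as-is
    assert g_wrapped._trivially_replicable_wrapper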
