Commit f4c47c9

[DTensor] Add prim and torch symbol for add (#2581)
1 parent c8ef0c4 commit f4c47c9

File tree (5 files changed: +163 -76 lines)

  thunder/clang/__init__.py
  thunder/clang/utils.py
  thunder/executors/nvfuserex_impl.py
  thunder/tests/distributed/test_dtensor.py
  thunder/torch/experimental/dtensor_torch_and_prims.py


thunder/clang/__init__.py

Lines changed: 9 additions & 74 deletions
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from collections.abc import Callable, Sequence
-from functools import partial, reduce
+from functools import partial, reduce, update_wrapper
 from numbers import Number
 from types import EllipsisType, NoneType
 from typing import Any, Union
@@ -10,6 +10,7 @@

 from thunder.clang.langctx import register_method
 from thunder.clang.utils import create_maybe_convert_to_dtype_with_prim, _elementwise_unary_wrapper
+import thunder.clang.utils as clang_utils
 from thunder.core import utils
 from thunder.core.baseutils import run_once
 from thunder.core.langctxs import langctx, Languages
@@ -368,33 +369,13 @@ def diagonal(a: TensorLike, offset: int = 0, dim1: int = 0, dim2: int = 1) -> Te

 # Expands a to the specified shape, possibly adding new dimensions and expanding
 # dimensions of length 1 to any length
-@clangop()
-def expand(a: TensorLike, *shape: int) -> TensorLike:
-    shape = utils.extract_shape_from_varargs(shape)
-
-    # TODO: improve this error message with error context
-    utils.check(
-        len(shape) >= len(a.shape),
-        lambda: "expand: the requested shape has too few dimensions!",
-    )
-
-    offset = len(shape) - len(a.shape)
-    shape_ = list(shape)
-    for idx, x in enumerate(a.shape):
-        offset_idx = idx + offset
-        requested_length = shape[offset_idx]
-        utils.check(
-            requested_length == x or x == 1 or requested_length == -1,
-            lambda: f"expand: attempting to expand a dimension of length {x}!",
-        )
-
-        shape_[offset_idx] = requested_length if requested_length != -1 else x
-
-    # At this point shape must be valid
-    # utils.check_valid_shape(shape_)
+expand = clangop()(partial(clang_utils.expand_impl, broadcast_prim=prims.broadcast_in_dim))
+# To preserve the docstring
+update_wrapper(expand, clang_utils.expand_impl)

-    # NOTE: Converting shape_ to tuple makes it possible to apply CSE
-    return prims.broadcast_in_dim(a, tuple(shape_), tuple(range(offset, len(a.shape) + offset)))
+maybe_broadcast = clangop()(partial(clang_utils.maybe_broadcast_impl, expand_fn=expand))
+# To preserve the docstring
+update_wrapper(maybe_broadcast, clang_utils.maybe_broadcast_impl)


 # TODO Resolve the start & end vs. start & stop inconsistencies with our operators (this one is start & end)
@@ -1085,31 +1066,7 @@ def stack(tensors: list[TensorProxy], dim: int):
     return cat(tensors_, dim)


-@clangop()
-def compute_broadcast_shape(*_shapes):
-    """Computes the common shape with the fewest dimensions that all input shapes can be broadcast to."""
-    shapes = tuple(x for x in filter(lambda x: x is not None, _shapes))
-
-    # Short-circuits if there are no inputs shapes
-    # This might happen in calls like add(2, 3)
-    if len(shapes) == 0:
-        return None
-
-    common_shape = [
-        1,
-    ] * reduce(max, (len(shape) for shape in shapes))
-
-    for shape in shapes:
-        for idx in range(-1, -1 - len(shape), -1):
-            if common_shape[idx] == 1:
-                common_shape[idx] = shape[idx]
-
-            utils.check(
-                (shape[idx] == 1) or (common_shape[idx] == shape[idx]),
-                lambda: f"Attempting to broadcast a dimension of length {shape[idx]}!",
-            )
-
-    return tuple(common_shape)
+compute_broadcast_shape = clangop()(clang_utils.compute_broadcast_shape)


 @run_once
@@ -1155,28 +1112,6 @@ def matrix_transpose(a: TensorProxy) -> TensorProxy:
     return transpose(a, permutation)


-# TODO: add scalar support
-# TODO: review hasattr pattern
-# NOTE: the tensor is not broadcasted if it is a CPU scalar tensor and treat_cpu_scalar_tensors_as_numbers=True
-@clangop()
-def maybe_broadcast(*args, treat_cpu_scalar_tensors_as_numbers=True):
-    """Returns tensors with the same shape, possibly broadcasting inputs to the result shape."""
-
-    # Computes common shape
-    common_shape = compute_broadcast_shape(*map(lambda t: t.shape if hasattr(t, "shape") else None, args))
-
-    def _maybe_broadcast(x, shape):
-        if treat_cpu_scalar_tensors_as_numbers and utils.is_cpu_scalar_tensor(x):
-            return x
-        if hasattr(x, "shape"):
-            if not utils.same_shape(x.shape, common_shape):
-                return expand(x, common_shape)
-
-        return x
-
-    return tuple(_maybe_broadcast(x, common_shape) for x in args)
-
-
 #
 # Elementwise unary operations
 #
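
The net effect of this refactor: the clang-level `expand` and `maybe_broadcast` become thin `clangop` wrappers around shared `*_impl` helpers that receive the broadcast primitive (or expand function) as a keyword argument, so the same logic can later be reused with DTensor primitives. A minimal standalone sketch of the `partial` + `update_wrapper` pattern used above (the names below are illustrative, not Thunder code):

from functools import partial, update_wrapper

def expand_impl_sketch(a, shape, *, broadcast_prim):
    """Expand `a` to `shape` via the supplied broadcast primitive."""
    return broadcast_prim(a, shape)

def fake_broadcast_in_dim(a, shape):
    # Hypothetical stand-in for prims.broadcast_in_dim.
    return f"broadcast_in_dim({a!r}, {shape})"

# Bind the shared implementation to a concrete primitive, then copy the
# name/docstring from the impl so the public callable stays introspectable.
expand_sketch = partial(expand_impl_sketch, broadcast_prim=fake_broadcast_in_dim)
update_wrapper(expand_sketch, expand_impl_sketch)

print(expand_sketch("a", (2, 3)))  # broadcast_in_dim('a', (2, 3))
print(expand_sketch.__doc__)       # docstring copied from expand_impl_sketch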

thunder/clang/utils.py

Lines changed: 78 additions & 0 deletions
@@ -1,6 +1,7 @@
 from numbers import Number
 from collections.abc import Sequence
 from collections.abc import Callable
+from functools import reduce

 from thunder.core import utils
 import thunder.core.dtypes as dtypes
@@ -11,6 +12,8 @@
     TensorProxy,
 )

+TensorLike = TensorProxy
+

 def create_maybe_convert_to_dtype_with_prim(conversion_prim: Symbol):
     assert isinstance(conversion_prim, Symbol)
@@ -66,3 +69,78 @@ def _elementwise_unary_wrapper(
     result = dtype_conversion_fn(result, result_dtype)

     return result
+
+
+def compute_broadcast_shape(*_shapes):
+    """Computes the common shape with the fewest dimensions that all input shapes can be broadcast to."""
+    shapes = tuple(x for x in filter(lambda x: x is not None, _shapes))
+
+    # Short-circuits if there are no inputs shapes
+    # This might happen in calls like add(2, 3)
+    if len(shapes) == 0:
+        return None
+
+    common_shape = [
+        1,
+    ] * reduce(max, (len(shape) for shape in shapes))
+
+    for shape in shapes:
+        for idx in range(-1, -1 - len(shape), -1):
+            if common_shape[idx] == 1:
+                common_shape[idx] = shape[idx]
+
+            utils.check(
+                (shape[idx] == 1) or (common_shape[idx] == shape[idx]),
+                lambda: f"Attempting to broadcast a dimension of length {shape[idx]}!",
+            )
+
+    return tuple(common_shape)
+
+
+def expand_impl(a: TensorLike, *shape: int, broadcast_prim: Symbol) -> TensorLike:
+    shape = utils.extract_shape_from_varargs(shape)
+
+    # TODO: improve this error message with error context
+    utils.check(
+        len(shape) >= len(a.shape),
+        lambda: "expand: the requested shape has too few dimensions!",
+    )
+
+    offset = len(shape) - len(a.shape)
+    shape_ = list(shape)
+    for idx, x in enumerate(a.shape):
+        offset_idx = idx + offset
+        requested_length = shape[offset_idx]
+        utils.check(
+            requested_length == x or x == 1 or requested_length == -1,
+            lambda: f"expand: attempting to expand a dimension of length {x}!",
+        )
+
+        shape_[offset_idx] = requested_length if requested_length != -1 else x
+
+    # At this point shape must be valid
+    # utils.check_valid_shape(shape_)
+
+    # NOTE: Converting shape_ to tuple makes it possible to apply CSE
+    return broadcast_prim(a, tuple(shape_), tuple(range(offset, len(a.shape) + offset)))
+
+
+# TODO: add scalar support
+# TODO: review hasattr pattern
+# NOTE: the tensor is not broadcasted if it is a CPU scalar tensor and treat_cpu_scalar_tensors_as_numbers=True
+def maybe_broadcast_impl(*args, treat_cpu_scalar_tensors_as_numbers=True, expand_fn: Callable):
+    """Returns tensors with the same shape, possibly broadcasting inputs to the result shape."""
+
+    # Computes common shape
+    common_shape = compute_broadcast_shape(*map(lambda t: t.shape if hasattr(t, "shape") else None, args))
+
+    def _maybe_broadcast(x, shape):
+        if treat_cpu_scalar_tensors_as_numbers and utils.is_cpu_scalar_tensor(x):
+            return x
+        if hasattr(x, "shape"):
+            if not utils.same_shape(x.shape, common_shape):
+                return expand_fn(x, common_shape)
+
+        return x
+
+    return tuple(_maybe_broadcast(x, common_shape) for x in args)
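
As a point of reference, the broadcasting helper that now lives in `thunder.clang.utils` can be exercised with plain tuples, since it only uses `len()` and indexing. A small sanity check, assuming Thunder is installed; the expected values follow standard NumPy/PyTorch broadcasting rules:

from thunder.clang.utils import compute_broadcast_shape

assert compute_broadcast_shape((3, 1, 5), (4, 5)) == (3, 4, 5)
assert compute_broadcast_shape((2, 3), None) == (2, 3)  # None stands in for a Number operand
assert compute_broadcast_shape() is None                # e.g. add(2, 3) with two Python scalars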

thunder/executors/nvfuserex_impl.py

Lines changed: 1 addition & 0 deletions
@@ -1772,6 +1772,7 @@ def _add(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionDefiniti


 register_supported(PrimIDs.ADD, _add, _elementwise_binary_check)
+register_dtensor_supported(DTensorPrimIDs.ADD, _add, _elementwise_binary_check)


 def atan2(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
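
The one-line change reuses the existing `_add` translation and `_elementwise_binary_check` for the new DTensor ADD primitive. For intuition only, a registry of this shape (prim ID mapped to a translator plus an applicability check) could look like the sketch below; this is an illustrative pattern, not the nvFuser executor's actual data structure:

from typing import Callable

# prim ID -> (translation function, applicability check); names are hypothetical.
_translation_registry: dict[object, tuple[Callable, Callable]] = {}

def register_supported_sketch(prim_id, translator: Callable, checker: Callable) -> None:
    # Registering the DTensor ADD prim with the same translator as the regular
    # ADD prim means no new fusion translation logic is needed for DTensor add.
    _translation_registry[prim_id] = (translator, checker)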

thunder/tests/distributed/test_dtensor.py

Lines changed: 26 additions & 1 deletion
@@ -40,7 +40,17 @@
 # to choose between DTensor supported symbol (from `dtensor_torch_and_prims.py`) or the usual `ltorch` symbol.
 # This is why we need to make sure that the OpInfo uses PyTorch native op as `op` which is passed to thunder.jit.
 class DTensorOpInfo:
-    def __init__(self, *, name, op, torch_reference, supports_grad, sample_inputs, skip_noncontiguous_for_executor=()):
+    def __init__(
+        self,
+        *,
+        name,
+        op,
+        torch_reference,
+        supports_grad,
+        sample_inputs,
+        skip_noncontiguous_for_executor=(),
+        skip_for_executor=(),
+    ):
         self.name = name
         assert "torch" in op.__module__, "OpInfo must use PyTorch native op as `op` which is passed to thunder.jit"
         self.op = op
@@ -54,6 +64,9 @@ def __init__(self, *, name, op, torch_reference, supports_grad, sample_inputs, s
         assert isinstance(skip_noncontiguous_for_executor, tuple), "skip_noncontiguous_for_executor must be a tuple"
         self.skip_noncontiguous_for_executor = skip_noncontiguous_for_executor

+        assert isinstance(skip_for_executor, tuple), "skip_for_executor must be a tuple"
+        self.skip_for_executor = skip_for_executor
+

 # DTensor supported ops
 dtensor_supported_opinfos = (
@@ -98,6 +111,15 @@ def __init__(self, *, name, op, torch_reference, supports_grad, sample_inputs, s
         # Ref:https://github.com/NVIDIA/Fuser/pull/5124
         skip_noncontiguous_for_executor=("nvfuser",),
     ),
+    DTensorOpInfo(
+        name="add",
+        op=torch.add,
+        torch_reference=torch.add,
+        supports_grad=False,
+        sample_inputs=get_opinfo("add").sample_inputs,
+        # Ref:https://github.com/NVIDIA/Fuser/issues/5314
+        skip_for_executor=("nvfuser",),
+    ),
 )

 skip_opinfos = (
@@ -309,6 +331,9 @@ def test_dtensor_opinfo(self, op: OpInfo, executor):
         if op.name in skip_opinfos:
             raise unittest.SkipTest(f"test_dtensor_opinfo: Skipping {op.name} as it is in skip_opinfos")

+        if executor in op.skip_for_executor:
+            raise unittest.SkipTest(f"test_dtensor_opinfo: Skipping {op.name} as it is in skip_for_executor")
+
         # NOTE: This test only tests for dtype=torch.float32 and requires_grad=True
         # not for all dtype which are supported by the operation.
         num_devices = self.world_size
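
Roughly speaking, the new `add` entry makes the opinfo test trace `torch.add` over DTensor inputs through `thunder.jit` (the nvFuser executor is skipped because of the linked Fuser issue). A hedged sketch of that flow outside the test harness, assuming a 2-rank process group on CPU (mesh size, shapes, and backend are illustrative):

import torch
import thunder
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Assumes torch.distributed is already initialized with 2 ranks (e.g. via torchrun).
mesh = init_device_mesh("cpu", (2,))
a = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])
b = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])

# The OpInfo requires the PyTorch-native op to be what gets passed to thunder.jit.
jitted_add = thunder.jit(torch.add)
out = jitted_add(a, b)
torch.testing.assert_close(out.to_local(), torch.add(a, b).to_local())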

thunder/torch/experimental/dtensor_torch_and_prims.py

Lines changed: 49 additions & 1 deletion
@@ -7,7 +7,12 @@
 import thunder.torch as ltorch
 from thunder.core.pytree import tree_flatten
 from thunder import clang
-from thunder.clang.utils import create_maybe_convert_to_dtype_with_prim, _elementwise_unary_wrapper
+from thunder.clang.utils import (
+    create_maybe_convert_to_dtype_with_prim,
+    _elementwise_unary_wrapper,
+    maybe_broadcast_impl,
+    expand_impl,
+)
 from thunder.torch.experimental.dtensor_utils import run_with_fake_tensor
 from thunder.torch.experimental.dtensor_proxy import DTensorProxy, create_dtensor_proxy_from_proxies
 from thunder.torch.langctx import register_method
@@ -33,6 +38,7 @@
 class DTensorPrimIDs(Enum):
     # DTensor-specific primitives
     CHECK_DTENSOR_SPEC_REPR = auto()
+    ADD = auto()
     MUL = auto()
     RESHAPE = auto()
     CONVERT_ELEMENT_TYPE = auto()
@@ -365,6 +371,47 @@ def dtensor_reciprocal(a: TensorLike) -> TensorLike:
     )


+expand = partial(expand_impl, broadcast_prim=dtensor_broadcast_in_dim_prim)
+maybe_broadcast = partial(maybe_broadcast_impl, expand_fn=expand)
+
+
+def _elementwise_binary_wrapper(a, b, *, prim, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT):
+    computation_dtype, result_dtype = utils.elementwise_type_promotion(a, b, type_promotion_kind=type_promotion_kind)
+
+    a, b = maybe_broadcast(a, b)
+    a, b = maybe_convert_to_dtype(a, computation_dtype), maybe_convert_to_dtype(b, computation_dtype)
+
+    result = prim(a, b)
+    result = maybe_convert_to_dtype(result, result_dtype)
+
+    return result
+
+
+def dtensor_add_meta(a, b):
+    output = run_with_fake_tensor(torch.add, a, b)
+    local_tensor_proxy = TensorProxy(like=a.local_tensor)
+    spec = output._spec
+    spec_proxy = AnyProxy(spec, history=a.history)
+    return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)
+
+
+dtensor_add_prim = make_prim(DTensorPrimIDs.ADD, "dtensor_add_prim", meta=dtensor_add_meta)
+
+dtensor_add_prim_impl = pytorchex.register_operator("dtensor_add_prim", like=dtensor_add_prim, fn=torch.add)
+
+pytorchex.register_implementation(dtensor_add_prim, dtensor_add_prim_impl)
+
+
+@dtensor_torchsymbol(torch.add, id="dtensor.torch.add")
+def dtensor_add(a: TensorLike, b: TensorLike) -> TensorLike:
+    return _elementwise_binary_wrapper(
+        a,
+        b,
+        prim=dtensor_add_prim,
+        type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    )
+
+
 if LooseVersion(torch.__version__) >= "2.8":

     def dtensor_grouped_mm_meta(a, b, offsets):
@@ -394,6 +441,7 @@ def dtensor_grouped_mm(a: TensorLike, b: TensorLike, offsets: TensorLike, *, bia


 def register_dtensor_torch_and_prims():
+    register_function_for_dtensor(torch.add, ltorch.add, dtensor_add, is_method=True)
     register_function_for_dtensor(torch.mul, ltorch.mul, dtensor_mul, is_method=True)
     register_function_for_dtensor(torch.reshape, ltorch.reshape, dtensor_reshape, is_method=True)
     register_function_for_dtensor(torch.nn.functional.linear, ltorch.linear, dtensor_linear, is_method=False)
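
The add support above follows a four-step recipe: a meta function that runs the op on fake tensors to derive the output DTensor spec, a prim built from it, a PyTorch-executor implementation, and a torchsymbol hooked up in `register_dtensor_torch_and_prims`. A hedged sketch of how a follow-up elementwise op could reuse `_elementwise_binary_wrapper`, using `torch.sub` purely as a hypothetical example (including `DTensorPrimIDs.SUB`, which this commit does not define) and assuming the same module-level imports as `dtensor_torch_and_prims.py`:

def dtensor_sub_meta(a, b):
    # Same spec inference as dtensor_add_meta, just for torch.sub.
    output = run_with_fake_tensor(torch.sub, a, b)
    local_tensor_proxy = TensorProxy(like=a.local_tensor)
    spec_proxy = AnyProxy(output._spec, history=a.history)
    return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)

# DTensorPrimIDs.SUB would have to be added to the enum first.
dtensor_sub_prim = make_prim(DTensorPrimIDs.SUB, "dtensor_sub_prim", meta=dtensor_sub_meta)
dtensor_sub_prim_impl = pytorchex.register_operator("dtensor_sub_prim", like=dtensor_sub_prim, fn=torch.sub)
pytorchex.register_implementation(dtensor_sub_prim, dtensor_sub_prim_impl)

@dtensor_torchsymbol(torch.sub, id="dtensor.torch.sub")
def dtensor_sub(a: TensorLike, b: TensorLike) -> TensorLike:
    return _elementwise_binary_wrapper(a, b, prim=dtensor_sub_prim)

# And inside register_dtensor_torch_and_prims():
#     register_function_for_dtensor(torch.sub, ltorch.sub, dtensor_sub, is_method=True)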
