Skip to content

Commit 65193b0

Browse files
authored
Add DTensor prim and torch symbol for exp (#2496)
1 parent 157ac80 commit 65193b0

File tree

6 files changed

+140
-53
lines changed

6 files changed

+140
-53
lines changed

thunder/clang/__init__.py

Lines changed: 3 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import warnings
1010

1111
from thunder.clang.langctx import register_method
12+
from thunder.clang.utils import create_maybe_convert_to_dtype_with_prim, _elementwise_unary_wrapper
1213
from thunder.core import utils
1314
from thunder.core.baseutils import run_once
1415
from thunder.core.langctxs import langctx, Languages
@@ -140,39 +141,7 @@ def construct_tuple(tup: tuple, /) -> tuple:
140141

141142

142143
# TODO Review revising enforce_safe_casting to be more like NumPy's
143-
@clangop()
144-
def maybe_convert_to_dtype(a, dtype, *, enforce_safe_casting=False):
145-
"""If a has the same dtype as the given dtype, returns a unmodified.
146-
147-
Otherwise returns a converted to the given dtype.
148-
"""
149-
150-
utils.check(utils.is_dtype(dtype), lambda: f"Unknown dtype {dtype}!")
151-
152-
if isinstance(a, Sequence):
153-
return tuple(maybe_convert_to_dtype(x, dtype) for x in a)
154-
if isinstance(a, TensorProxy):
155-
# Translates numbertypes to dtypes
156-
if dtypes.is_numbertype(dtype):
157-
dtype = dtypes.numbertype_to_dtype(dtype)
158-
elif isinstance(a, (Number, NumberProxy)):
159-
# NOTE This allows conversions like (5, float32) -> 5., which is a little odd
160-
dtype = utils.dtype_to_numbertype(dtype)
161-
else:
162-
raise ValueError(
163-
f"Trying to convert the type of the data of an unknown object {a} of {type(a)} that is neither a tensor, number, or sequence!"
164-
)
165-
166-
if not utils.are_same_dtypes(a, dtype):
167-
if enforce_safe_casting:
168-
utils.check(
169-
utils.can_safe_cast_to(cast_from=utils.to_dtype(a), cast_to=dtype),
170-
lambda: f"Can't safe case from a={a} with dtype {utils.to_dtype(a)} to {dtype}!",
171-
)
172-
173-
return prims.convert_element_type(a, dtype)
174-
175-
return a
144+
maybe_convert_to_dtype = clangop()(create_maybe_convert_to_dtype_with_prim(prims.convert_element_type))
176145

177146

178147
# TODO Consider maybe_device_put analogous to maybe_convert_to_dtype above
@@ -1212,22 +1181,7 @@ def _maybe_broadcast(x, shape):
12121181
# Elementwise unary operations
12131182
#
12141183
# TODO Consider annotating these operators with kind and type promotion information
1215-
1216-
1217-
# TODO Add supported dtypes
1218-
def _elementwise_unary_wrapper(
1219-
a,
1220-
*,
1221-
prim,
1222-
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1223-
):
1224-
computation_dtype, result_dtype = utils.elementwise_type_promotion(a, type_promotion_kind=type_promotion_kind)
1225-
1226-
a = maybe_convert_to_dtype(a, computation_dtype)
1227-
result = prim(a)
1228-
result = maybe_convert_to_dtype(result, result_dtype)
1229-
1230-
return result
1184+
_elementwise_unary_wrapper = partial(_elementwise_unary_wrapper, dtype_conversion_fn=maybe_convert_to_dtype)
12311185

12321186

12331187
# TODO Return self for bool and uint datatypes?

thunder/clang/utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from numbers import Number
2+
from collections.abc import Sequence
3+
from collections.abc import Callable
4+
5+
from thunder.core import utils
6+
import thunder.core.dtypes as dtypes
7+
from thunder.core.symbol import Symbol
8+
9+
from thunder.core.proxies import (
10+
NumberProxy,
11+
TensorProxy,
12+
)
13+
14+
15+
def create_maybe_convert_to_dtype_with_prim(conversion_prim: Symbol):
16+
assert isinstance(conversion_prim, Symbol)
17+
18+
def maybe_convert_to_dtype(a, dtype, *, enforce_safe_casting=False):
19+
"""If a has the same dtype as the given dtype, returns a unmodified.
20+
21+
Otherwise returns a converted to the given dtype.
22+
"""
23+
24+
utils.check(utils.is_dtype(dtype), lambda: f"Unknown dtype {dtype}!")
25+
26+
if isinstance(a, Sequence):
27+
return tuple(maybe_convert_to_dtype(x, dtype) for x in a)
28+
if isinstance(a, TensorProxy):
29+
# Translates numbertypes to dtypes
30+
if dtypes.is_numbertype(dtype):
31+
dtype = dtypes.numbertype_to_dtype(dtype)
32+
elif isinstance(a, (Number, NumberProxy)):
33+
# NOTE This allows conversions like (5, float32) -> 5., which is a little odd
34+
dtype = utils.dtype_to_numbertype(dtype)
35+
else:
36+
raise ValueError(
37+
f"Trying to convert the type of the data of an unknown object {a} of {type(a)} that is neither a tensor, number, or sequence!"
38+
)
39+
40+
if not utils.are_same_dtypes(a, dtype):
41+
if enforce_safe_casting:
42+
utils.check(
43+
utils.can_safe_cast_to(cast_from=utils.to_dtype(a), cast_to=dtype),
44+
lambda: f"Can't safe case from a={a} with dtype {utils.to_dtype(a)} to {dtype}!",
45+
)
46+
47+
return conversion_prim(a, dtype)
48+
49+
return a
50+
51+
return maybe_convert_to_dtype
52+
53+
54+
# TODO Add supported dtypes
55+
def _elementwise_unary_wrapper(
56+
a,
57+
*,
58+
prim,
59+
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
60+
dtype_conversion_fn: Callable[[TensorProxy | NumberProxy, dtypes.dtype], TensorProxy | NumberProxy],
61+
):
62+
computation_dtype, result_dtype = utils.elementwise_type_promotion(a, type_promotion_kind=type_promotion_kind)
63+
64+
a = dtype_conversion_fn(a, computation_dtype)
65+
result = prim(a)
66+
result = dtype_conversion_fn(result, result_dtype)
67+
68+
return result

thunder/executors/nvfuserex_impl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,6 +1541,7 @@ def exp(a: TensorProxy | Number, *, fd: FusionDefinition, lc_to_nv_map: dict) ->
15411541

15421542

15431543
register_supported(PrimIDs.EXP, exp, _elementwise_unary_check)
1544+
register_supported(DTensorPrimIDs.EXP, exp, _elementwise_unary_check)
15441545

15451546

15461547
def exp2(a: TensorProxy | Number, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:

thunder/tests/distributed/test_dtensor.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
# to choose between DTensor supported symbol (from `dtensor_torch_and_prims.py`) or the usual `ltorch` symbol.
4040
# This is why we need to make sure that the OpInfo uses PyTorch native op as `op` which is passed to thunder.jit.
4141
class DTensorOpInfo:
42-
def __init__(self, *, name, op, torch_reference, supports_grad, sample_inputs):
42+
def __init__(self, *, name, op, torch_reference, supports_grad, sample_inputs, skip_noncontiguous_for_executor=()):
4343
self.name = name
4444
assert "torch" in op.__module__, "OpInfo must use PyTorch native op as `op` which is passed to thunder.jit"
4545
self.op = op
@@ -49,6 +49,10 @@ def __init__(self, *, name, op, torch_reference, supports_grad, sample_inputs):
4949
# NOTE: This should generally reuse the sample_inputs from the OpInfo
5050
self.sample_inputs = sample_inputs
5151

52+
# In some cases, non-contiguous inputs are not supported by the executor.
53+
assert isinstance(skip_noncontiguous_for_executor, tuple), "skip_noncontiguous_for_executor must be a tuple"
54+
self.skip_noncontiguous_for_executor = skip_noncontiguous_for_executor
55+
5256

5357
# DTensor supported ops
5458
dtensor_supported_opinfos = (
@@ -66,6 +70,15 @@ def __init__(self, *, name, op, torch_reference, supports_grad, sample_inputs):
6670
supports_grad=False,
6771
sample_inputs=get_opinfo("linear").sample_inputs,
6872
),
73+
DTensorOpInfo(
74+
name="exp",
75+
op=torch.exp,
76+
torch_reference=torch.exp,
77+
supports_grad=True,
78+
sample_inputs=get_opinfo("exp").sample_inputs,
79+
# Ref:https://github.com/NVIDIA/Fuser/pull/5124
80+
skip_noncontiguous_for_executor=("nvfuser",),
81+
),
6982
)
7083

7184
skip_opinfos = (
@@ -238,6 +251,10 @@ def test_dtensor_opinfo(self, op: OpInfo, executor):
238251
tested_sample_count = 0
239252

240253
for sample in op.sample_inputs("cpu", dtypes.float32, requires_grad=op.supports_grad):
254+
# Skip if non-contiguous inputs are not supported by the executor.
255+
if executor in op.skip_noncontiguous_for_executor and not sample.args[0].is_contiguous():
256+
continue
257+
241258
# DTensorConverter converts inputs tensors to DTensor and creates DTensor
242259
# with possible placements based on the input shapes.
243260
# See - https://github.com/pytorch/pytorch/blob/eaa5d9d3d3dc642832b269b184f0c3ab8c990274/torch/testing/_internal/distributed/_tensor/common_dtensor.py#L521

thunder/torch/experimental/dtensor_torch_and_prims.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import thunder.torch as ltorch
77
from thunder.core.pytree import tree_flatten
88
from thunder import clang
9+
from thunder.clang.utils import create_maybe_convert_to_dtype_with_prim, _elementwise_unary_wrapper
910
from thunder.torch.experimental.dtensor_utils import run_with_fake_tensor
1011
from thunder.torch.experimental.dtensor_proxy import DTensorProxy, create_dtensor_proxy_from_proxies
1112
from thunder.torch.langctx import register_method
@@ -35,6 +36,7 @@ class DTensorPrimIDs(Enum):
3536
RESHAPE = auto()
3637
CONVERT_ELEMENT_TYPE = auto()
3738
BROADCAST_IN_DIM = auto()
39+
EXP = auto()
3840
LINEAR = auto()
3941

4042

@@ -242,6 +244,10 @@ def dtensor_broadcast_in_dim_meta(a, shape, broadcast_dimensions):
242244
pytorchex.register_implementation(dtensor_broadcast_in_dim_prim, dtensor_broadcast_in_dim_prim_impl)
243245

244246

247+
maybe_convert_to_dtype = create_maybe_convert_to_dtype_with_prim(dtensor_convert_element_type_prim)
248+
_elementwise_unary_wrapper = partial(_elementwise_unary_wrapper, dtype_conversion_fn=maybe_convert_to_dtype)
249+
250+
245251
def dtensor_linear_meta(a, w, bias):
246252
output = run_with_fake_tensor(torch.nn.functional.linear, a, w, bias)
247253
local_tensor_proxy = TensorProxy(like=a.local_tensor)
@@ -268,7 +274,45 @@ def dtensor_linear(a: TensorLike, w: TensorLike, bias: None | TensorLike = None)
268274
return dtensor_linear_prim(a, w, bias)
269275

270276

277+
def dtensor_exp_meta(a):
278+
output = run_with_fake_tensor(torch.exp, a)
279+
local_tensor_proxy = TensorProxy(like=a.local_tensor)
280+
spec = output._spec
281+
spec_proxy = AnyProxy(spec, history=a.history)
282+
return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)
283+
284+
285+
dtensor_exp_prim = make_prim(DTensorPrimIDs.EXP, "dtensor_exp_prim", meta=dtensor_exp_meta)
286+
287+
dtensor_exp_prim_impl = pytorchex.register_operator("dtensor_exp_prim", like=dtensor_exp_prim, fn=torch.exp)
288+
289+
pytorchex.register_implementation(dtensor_exp_prim, dtensor_exp_prim_impl)
290+
291+
292+
def _dtensor_exp_prim_grad(a: TensorLike) -> TensorLike:
293+
fwd = dtensor_exp_prim(a)
294+
295+
g = get_grad(fwd)
296+
a_grad = g * fwd
297+
put_grad(a, a_grad)
298+
299+
return fwd
300+
301+
302+
register_grad(dtensor_exp_prim, _dtensor_exp_prim_grad)
303+
304+
305+
@dtensor_torchsymbol(torch.exp, id="dtensor.torch.exp")
306+
def dtensor_exp(a: TensorLike) -> TensorLike:
307+
return _elementwise_unary_wrapper(
308+
a,
309+
prim=dtensor_exp_prim,
310+
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
311+
)
312+
313+
271314
def register_dtensor_torch_and_prims():
272315
register_function_for_dtensor(torch.mul, ltorch.mul, dtensor_mul, is_method=True)
273316
register_function_for_dtensor(torch.reshape, ltorch.reshape, dtensor_reshape, is_method=True)
274317
register_function_for_dtensor(torch.nn.functional.linear, ltorch.linear, dtensor_linear, is_method=False)
318+
register_function_for_dtensor(torch.exp, ltorch.exp, dtensor_exp, is_method=True)

thunder/torch/experimental/dtensor_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,17 @@ def materialize_fake_tensors(t):
5858
return t
5959

6060
if isinstance(t, DTensorProxy):
61-
i_t = torch.randn(
61+
i_t = torch.ones(
6262
t.local_tensor.shape,
6363
device=to_torch_device(t.local_tensor.device),
6464
dtype=to_torch_dtype(t.local_tensor.dtype),
6565
)
66-
return DTensor.from_local(i_t, t.spec._o.device_mesh, t.spec._o.placements)
6766

68-
return torch.randn(t.shape, device=to_torch_device(t.device), dtype=to_torch_dtype(t.dtype))
67+
shape = t.spec._o.tensor_meta.shape if t.spec._o.tensor_meta is not None else None
68+
stride = t.spec._o.tensor_meta.stride if t.spec._o.tensor_meta is not None else None
69+
return DTensor.from_local(i_t, t.spec._o.device_mesh, t.spec._o.placements, shape=shape, stride=stride)
70+
71+
return torch.ones(t.shape, device=to_torch_device(t.device), dtype=to_torch_dtype(t.dtype))
6972

7073
args, kwargs = tree_map(materialize_fake_tensors, (args, kwargs))
7174

0 commit comments

Comments
 (0)