6
6
import thunder .torch as ltorch
7
7
from thunder .core .pytree import tree_flatten
8
8
from thunder import clang
9
+ from thunder .clang .utils import create_maybe_convert_to_dtype_with_prim , _elementwise_unary_wrapper
9
10
from thunder .torch .experimental .dtensor_utils import run_with_fake_tensor
10
11
from thunder .torch .experimental .dtensor_proxy import DTensorProxy , create_dtensor_proxy_from_proxies
11
12
from thunder .torch .langctx import register_method
@@ -35,6 +36,7 @@ class DTensorPrimIDs(Enum):
35
36
RESHAPE = auto ()
36
37
CONVERT_ELEMENT_TYPE = auto ()
37
38
BROADCAST_IN_DIM = auto ()
39
+ EXP = auto ()
38
40
LINEAR = auto ()
39
41
40
42
@@ -242,6 +244,10 @@ def dtensor_broadcast_in_dim_meta(a, shape, broadcast_dimensions):
242
244
pytorchex .register_implementation (dtensor_broadcast_in_dim_prim , dtensor_broadcast_in_dim_prim_impl )
243
245
244
246
247
+ maybe_convert_to_dtype = create_maybe_convert_to_dtype_with_prim (dtensor_convert_element_type_prim )
248
+ _elementwise_unary_wrapper = partial (_elementwise_unary_wrapper , dtype_conversion_fn = maybe_convert_to_dtype )
249
+
250
+
245
251
def dtensor_linear_meta (a , w , bias ):
246
252
output = run_with_fake_tensor (torch .nn .functional .linear , a , w , bias )
247
253
local_tensor_proxy = TensorProxy (like = a .local_tensor )
@@ -268,7 +274,45 @@ def dtensor_linear(a: TensorLike, w: TensorLike, bias: None | TensorLike = None)
268
274
return dtensor_linear_prim (a , w , bias )
269
275
270
276
277
+ def dtensor_exp_meta (a ):
278
+ output = run_with_fake_tensor (torch .exp , a )
279
+ local_tensor_proxy = TensorProxy (like = a .local_tensor )
280
+ spec = output ._spec
281
+ spec_proxy = AnyProxy (spec , history = a .history )
282
+ return create_dtensor_proxy_from_proxies (local_tensor_proxy , spec_proxy , False )
283
+
284
+
285
+ dtensor_exp_prim = make_prim (DTensorPrimIDs .EXP , "dtensor_exp_prim" , meta = dtensor_exp_meta )
286
+
287
+ dtensor_exp_prim_impl = pytorchex .register_operator ("dtensor_exp_prim" , like = dtensor_exp_prim , fn = torch .exp )
288
+
289
+ pytorchex .register_implementation (dtensor_exp_prim , dtensor_exp_prim_impl )
290
+
291
+
292
+ def _dtensor_exp_prim_grad (a : TensorLike ) -> TensorLike :
293
+ fwd = dtensor_exp_prim (a )
294
+
295
+ g = get_grad (fwd )
296
+ a_grad = g * fwd
297
+ put_grad (a , a_grad )
298
+
299
+ return fwd
300
+
301
+
302
+ register_grad (dtensor_exp_prim , _dtensor_exp_prim_grad )
303
+
304
+
305
+ @dtensor_torchsymbol (torch .exp , id = "dtensor.torch.exp" )
306
+ def dtensor_exp (a : TensorLike ) -> TensorLike :
307
+ return _elementwise_unary_wrapper (
308
+ a ,
309
+ prim = dtensor_exp_prim ,
310
+ type_promotion_kind = utils .ELEMENTWISE_TYPE_PROMOTION_KIND .INT_TO_FLOAT ,
311
+ )
312
+
313
+
271
314
def register_dtensor_torch_and_prims ():
272
315
register_function_for_dtensor (torch .mul , ltorch .mul , dtensor_mul , is_method = True )
273
316
register_function_for_dtensor (torch .reshape , ltorch .reshape , dtensor_reshape , is_method = True )
274
317
register_function_for_dtensor (torch .nn .functional .linear , ltorch .linear , dtensor_linear , is_method = False )
318
+ register_function_for_dtensor (torch .exp , ltorch .exp , dtensor_exp , is_method = True )
0 commit comments