 }
 
 
+# NOTE: OpInfo may use `clang` or `ltorch` ops to be jitted with thunder.jit.
+# However, for the current DTensor implementation, we add a dispatch in the `torch` operation lookaside
+# to choose between a DTensor-supported symbol (from `dtensor_torch_and_prims.py`) and the usual `ltorch` symbol.
+# This is why the OpInfo must use a PyTorch native op as the `op` that is passed to thunder.jit.
+class DTensorOpInfo:
+    def __init__(self, *, name, op, torch_reference, supports_grad, sample_inputs):
+        self.name = name
+        assert "torch" in op.__module__, "OpInfo must use a PyTorch native op as the `op` passed to thunder.jit"
+        self.op = op
+        self.torch_reference = torch_reference
+        # NOTE: Not all DTensor ops support grad initially; use this to disable grad tests for them.
+        self.supports_grad = supports_grad
+        # NOTE: This should generally reuse the sample_inputs from the corresponding OpInfo.
+        self.sample_inputs = sample_inputs
+
+
 # DTensor supported ops
-dtensor_supported_ops = ("reshape",)
+dtensor_supported_opinfos = (
+    DTensorOpInfo(
+        name="reshape",
+        op=torch.reshape,
+        torch_reference=torch.reshape,
+        supports_grad=True,
+        sample_inputs=get_opinfo("reshape").sample_inputs,
+    ),
+    DTensorOpInfo(
+        name="linear",
+        op=torch.nn.functional.linear,
+        torch_reference=torch.nn.functional.linear,
+        supports_grad=False,
+        sample_inputs=get_opinfo("linear").sample_inputs,
+    ),
+)
 
-dtensor_supported_opinfos = [get_opinfo(op) for op in dtensor_supported_ops]
+skip_opinfos = (
+    # RuntimeError: Metadata (placement and mesh) has changed for cotangent between tracing and runtime; during tracing
+    # it was Spec(S(1) on (1, 2, 1)) but at runtime it is Spec(S(1) on (1, 2, 1)).
+    "reshape",
+)
 
 
 @unittest.skipUnless(
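For reference, here is a minimal sketch of how another operation could be registered with the DTensorOpInfo helper added above. The choice of torch.add and its OpInfo name are illustrative assumptions, not part of this change; DTensorOpInfo and get_opinfo are the names defined and imported in this file.

# Hypothetical example: registering another op follows the same pattern.
# The callable passed as `op` must be the PyTorch-native function (the
# `__module__` assert in DTensorOpInfo.__init__ enforces this), and
# sample_inputs reuses the existing OpInfo database via get_opinfo.
hypothetical_add_opinfo = DTensorOpInfo(
    name="add",
    op=torch.add,                # PyTorch-native op handed to thunder.jit
    torch_reference=torch.add,   # eager reference used for comparison
    supports_grad=False,         # keep grad tests disabled until backward is verified
    sample_inputs=get_opinfo("add").sample_inputs,
)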
@@ -189,15 +224,20 @@ def fn(x):
         lambda op, executor: op.name + "_" + executor,
     )
     def test_dtensor_opinfo(self, op: OpInfo, executor):
+        if op.name in skip_opinfos:
+            raise unittest.SkipTest(f"test_dtensor_opinfo: Skipping {op.name} as it is in skip_opinfos")
+
         # NOTE: This test only tests for dtype=torch.float32 and requires_grad=True,
         # not for all dtypes that are supported by the operation.
         num_devices = self.world_size
         mesh = DeviceMesh("cuda", list(range(num_devices)))
 
-        thunder_op = thunder.jit(op.op, executors=executors_map[executor].executors_list())
+        thunder_op = thunder.jit(op.op, executors=executors_map[executor].executors_list(), nv_enable_linear=True)
+        torch_op = op.torch_reference
 
         tested_sample_count = 0
-        for sample in op.sample_inputs("cpu", dtypes.float32, requires_grad=True):
+
+        for sample in op.sample_inputs("cpu", dtypes.float32, requires_grad=op.supports_grad):
             # DTensorConverter converts input tensors to DTensor and creates DTensors
             # with possible placements based on the input shapes.
             # See - https://github.com/pytorch/pytorch/blob/eaa5d9d3d3dc642832b269b184f0c3ab8c990274/torch/testing/_internal/distributed/_tensor/common_dtensor.py#L521
@@ -206,8 +246,6 @@ def test_dtensor_opinfo(self, op: OpInfo, executor):
             if not dtensor_converter.successful():
                 continue
 
-            torch_op = op.torch_reference
-
             # Computes PyTorch result
             try:
                 torch_result = torch_op(*dtensor_args, **dtensor_kwargs)
@@ -220,34 +258,38 @@ def test_dtensor_opinfo(self, op: OpInfo, executor):
             thunder_result = thunder_op(*dtensor_args, **dtensor_kwargs)
             torch.testing.assert_close(thunder_result, torch_result)
 
-            torch_flats, _ = tree_flatten((dtensor_args, dtensor_kwargs))
-            torch_result = filter_differentiable_outputs(torch_result)
-            if torch_result == []:
-                raise RuntimeError("test_dtensor_opinfo: Expected atleast 1 differentiable output.")
-
-            grads = []
-            assert isinstance(torch_result, torch.Tensor) or isinstance(torch_result, Sequence), (
-                "test_dtensor_opinfo:Expected a single torch tensor or a sequence of torch tensors"
-            )
-            if isinstance(torch_result, Sequence):
-                for x in torch_result:
-                    assert isinstance(x, torch.Tensor), (
-                        "test_dtensor_opinfo: Expected a single torch tensor or a sequence of torch tensors"
-                    )
-                    if is_output_differentiable(x):
-                        grads.append(torch.ones_like(x))
-            else:
-                if is_output_differentiable(torch_result):
-                    grads = [torch.ones_like(torch_result)]
-
-            torch_tensors_requiring_grad = tuple(
-                f for f in torch_flats if isinstance(f, torch.Tensor) and f.requires_grad
-            )
-            torch_grad_result = torch.autograd.grad(torch_result, torch_tensors_requiring_grad, grads)
-
-            thunder_result = filter_differentiable_outputs(thunder_result)
-            thunder_grad_result = torch.autograd.grad(thunder_result, torch_tensors_requiring_grad, grads)
-            torch.testing.assert_close(thunder_grad_result, torch_grad_result)
+            trace = thunder.last_traces(thunder_op)[0]
+            assert any("dtensor" in bsym.sym.name for bsym in trace.bound_symbols)
+
+            if op.supports_grad:
+                torch_flats, _ = tree_flatten((dtensor_args, dtensor_kwargs))
+                torch_result = filter_differentiable_outputs(torch_result)
+                if torch_result == []:
+                    raise RuntimeError("test_dtensor_opinfo: Expected at least 1 differentiable output.")
+
+                grads = []
+                assert isinstance(torch_result, torch.Tensor) or isinstance(torch_result, Sequence), (
+                    "test_dtensor_opinfo: Expected a single torch tensor or a sequence of torch tensors"
+                )
+                if isinstance(torch_result, Sequence):
+                    for x in torch_result:
+                        assert isinstance(x, torch.Tensor), (
+                            "test_dtensor_opinfo: Expected a single torch tensor or a sequence of torch tensors"
+                        )
+                        if is_output_differentiable(x):
+                            grads.append(torch.ones_like(x))
+                else:
+                    if is_output_differentiable(torch_result):
+                        grads = [torch.ones_like(torch_result)]
+
+                torch_tensors_requiring_grad = tuple(
+                    f for f in torch_flats if isinstance(f, torch.Tensor) and f.requires_grad
+                )
+                torch_grad_result = torch.autograd.grad(torch_result, torch_tensors_requiring_grad, grads)
+
+                thunder_result = filter_differentiable_outputs(thunder_result)
+                thunder_grad_result = torch.autograd.grad(thunder_result, torch_tensors_requiring_grad, grads)
+                torch.testing.assert_close(thunder_grad_result, torch_grad_result)
 
             # Increment tested sample count
             tested_sample_count += 1
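As background for the samples iterated in test_dtensor_opinfo above, here is a brief sketch of what a single DTensor input looks like. It assumes an initialized distributed process group, at least two CUDA devices, and the public torch.distributed.tensor API; it is an illustration, not code from this diff.

# Build the same kind of mesh the test constructs and shard one input over it;
# DTensorConverter enumerates placements like this automatically for every sample.
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = DeviceMesh("cuda", list(range(2)))
x = torch.randn(4, 8, requires_grad=True)
dx = distribute_tensor(x, mesh, placements=[Shard(0)])  # shard dim 0 across the mesh

# The jitted op and the eager reference are then called on inputs like `dx`
# and their outputs (and grads, when supported) compared with torch.testing.assert_close.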