
Commit b23aa27

DTensor: support linear (#2422)
1 parent 7693fd9 commit b23aa27

4 files changed: +106 additions, -35 deletions

thunder/executors/nvfuserex_impl.py

Lines changed: 1 addition & 0 deletions

@@ -2434,6 +2434,7 @@ def linear(
 
 
 register_supported(PrimIDs.LINEAR, linear, _linear_check)
+register_supported(DTensorPrimIDs.LINEAR, linear, _linear_check)
 
 
 def _matmul_check(
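
With DTensorPrimIDs.LINEAR registered as supported, the nvFuser executor can claim linear on DTensor inputs through the same `linear` translator and `_linear_check` used for ordinary tensors. Below is a minimal sketch of how this path gets exercised; it is not part of the commit, it assumes a CUDA machine with the nvFuser executor available, and `fn` and the shapes are illustrative only:

import torch
import thunder

def fn(x, w, b):
    return torch.nn.functional.linear(x, w, b)

# nv_enable_linear=True opts linear into nvFuser, mirroring the flag added to the test below.
jfn = thunder.jit(fn, nv_enable_linear=True)

x = torch.randn(4, 8, device="cuda")
w = torch.randn(16, 8, device="cuda")
b = torch.randn(16, device="cuda")
out = jfn(x, w, b)  # shape (4, 16)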

thunder/tests/distributed/test_dtensor.py

Lines changed: 76 additions & 34 deletions

@@ -34,10 +34,45 @@
 }
 
 
+# NOTE: OpInfo may use `clang` or `ltorch` ops to be jitted with thunder.jit.
+# However, for the current DTensor implementation, we add a dispatch in the `torch` operation lookaside
+# to choose between DTensor supported symbol (from `dtensor_torch_and_prims.py`) or the usual `ltorch` symbol.
+# This is why we need to make sure that the OpInfo uses PyTorch native op as `op` which is passed to thunder.jit.
+class DTensorOpInfo:
+    def __init__(self, *, name, op, torch_reference, supports_grad, sample_inputs):
+        self.name = name
+        assert "torch" in op.__module__, "OpInfo must use PyTorch native op as `op` which is passed to thunder.jit"
+        self.op = op
+        self.torch_reference = torch_reference
+        # NOTE: Not all DTensor ops support grad initially, use this to disable grad tests for them
+        self.supports_grad = supports_grad
+        # NOTE: This should generally reuse the sample_inputs from the OpInfo
+        self.sample_inputs = sample_inputs
+
+
 # DTensor supported ops
-dtensor_supported_ops = ("reshape",)
+dtensor_supported_opinfos = (
+    DTensorOpInfo(
+        name="reshape",
+        op=torch.reshape,
+        torch_reference=torch.reshape,
+        supports_grad=True,
+        sample_inputs=get_opinfo("reshape").sample_inputs,
+    ),
+    DTensorOpInfo(
+        name="linear",
+        op=torch.nn.functional.linear,
+        torch_reference=torch.nn.functional.linear,
+        supports_grad=False,
+        sample_inputs=get_opinfo("linear").sample_inputs,
+    ),
+)
 
-dtensor_supported_opinfos = [get_opinfo(op) for op in dtensor_supported_ops]
+skip_opinfos = (
+    # RuntimeError: Metadata (placement and mesh) has changed for cotangent between tracing and runtime: during tracing
+    # it was Spec(S(1) on (1, 2, 1)) but at runtime it is Spec(S(1) on (1, 2, 1)).
+    "reshape",
+)
 
 
 @unittest.skipUnless(

@@ -189,15 +224,20 @@ def fn(x):
         lambda op, executor: op.name + "_" + executor,
     )
     def test_dtensor_opinfo(self, op: OpInfo, executor):
+        if op.name in skip_opinfos:
+            raise unittest.SkipTest(f"test_dtensor_opinfo: Skipping {op.name} as it is in skip_opinfos")
+
         # NOTE: This test only tests for dtype=torch.float32 and requires_grad=True
         # not for all dtype which are supported by the operation.
         num_devices = self.world_size
         mesh = DeviceMesh("cuda", list(range(num_devices)))
 
-        thunder_op = thunder.jit(op.op, executors=executors_map[executor].executors_list())
+        thunder_op = thunder.jit(op.op, executors=executors_map[executor].executors_list(), nv_enable_linear=True)
+        torch_op = op.torch_reference
 
         tested_sample_count = 0
-        for sample in op.sample_inputs("cpu", dtypes.float32, requires_grad=True):
+
+        for sample in op.sample_inputs("cpu", dtypes.float32, requires_grad=op.supports_grad):
             # DTensorConverter converts inputs tensors to DTensor and creates DTensor
             # with possible placements based on the input shapes.
             # See - https://github.com/pytorch/pytorch/blob/eaa5d9d3d3dc642832b269b184f0c3ab8c990274/torch/testing/_internal/distributed/_tensor/common_dtensor.py#L521

@@ -206,8 +246,6 @@ def test_dtensor_opinfo(self, op: OpInfo, executor):
             if not dtensor_converter.successful():
                 continue
 
-            torch_op = op.torch_reference
-
             # Computes PyTorch result
             try:
                 torch_result = torch_op(*dtensor_args, **dtensor_kwargs)

@@ -220,34 +258,38 @@ def test_dtensor_opinfo(self, op: OpInfo, executor):
             thunder_result = thunder_op(*dtensor_args, **dtensor_kwargs)
             torch.testing.assert_close(thunder_result, torch_result)
 
-            torch_flats, _ = tree_flatten((dtensor_args, dtensor_kwargs))
-            torch_result = filter_differentiable_outputs(torch_result)
-            if torch_result == []:
-                raise RuntimeError("test_dtensor_opinfo: Expected atleast 1 differentiable output.")
-
-            grads = []
-            assert isinstance(torch_result, torch.Tensor) or isinstance(torch_result, Sequence), (
-                "test_dtensor_opinfo:Expected a single torch tensor or a sequence of torch tensors"
-            )
-            if isinstance(torch_result, Sequence):
-                for x in torch_result:
-                    assert isinstance(x, torch.Tensor), (
-                        "test_dtensor_opinfo: Expected a single torch tensor or a sequence of torch tensors"
-                    )
-                    if is_output_differentiable(x):
-                        grads.append(torch.ones_like(x))
-            else:
-                if is_output_differentiable(torch_result):
-                    grads = [torch.ones_like(torch_result)]
-
-            torch_tensors_requiring_grad = tuple(
-                f for f in torch_flats if isinstance(f, torch.Tensor) and f.requires_grad
-            )
-            torch_grad_result = torch.autograd.grad(torch_result, torch_tensors_requiring_grad, grads)
-
-            thunder_result = filter_differentiable_outputs(thunder_result)
-            thunder_grad_result = torch.autograd.grad(thunder_result, torch_tensors_requiring_grad, grads)
-            torch.testing.assert_close(thunder_grad_result, torch_grad_result)
+            trace = thunder.last_traces(thunder_op)[0]
+            assert any("dtensor" in bsym.sym.name for bsym in trace.bound_symbols)
+
+            if op.supports_grad:
+                torch_flats, _ = tree_flatten((dtensor_args, dtensor_kwargs))
+                torch_result = filter_differentiable_outputs(torch_result)
+                if torch_result == []:
+                    raise RuntimeError("test_dtensor_opinfo: Expected atleast 1 differentiable output.")
+
+                grads = []
+                assert isinstance(torch_result, torch.Tensor) or isinstance(torch_result, Sequence), (
+                    "test_dtensor_opinfo:Expected a single torch tensor or a sequence of torch tensors"
+                )
+                if isinstance(torch_result, Sequence):
+                    for x in torch_result:
+                        assert isinstance(x, torch.Tensor), (
+                            "test_dtensor_opinfo: Expected a single torch tensor or a sequence of torch tensors"
+                        )
+                        if is_output_differentiable(x):
+                            grads.append(torch.ones_like(x))
+                else:
+                    if is_output_differentiable(torch_result):
+                        grads = [torch.ones_like(torch_result)]
+
+                torch_tensors_requiring_grad = tuple(
+                    f for f in torch_flats if isinstance(f, torch.Tensor) and f.requires_grad
+                )
+                torch_grad_result = torch.autograd.grad(torch_result, torch_tensors_requiring_grad, grads)
+
+                thunder_result = filter_differentiable_outputs(thunder_result)
+                thunder_grad_result = torch.autograd.grad(thunder_result, torch_tensors_requiring_grad, grads)
+                torch.testing.assert_close(thunder_grad_result, torch_grad_result)
 
             # Increment tested sample count
             tested_sample_count += 1
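
For reference, here is a standalone sketch of what the new linear DTensorOpInfo exercises end to end. It is not part of the commit; it assumes at least two GPUs with torch.distributed already initialized (e.g. via torchrun) and the `torch.distributed.tensor` import path of recent PyTorch (older releases expose the same names under `torch.distributed._tensor`). The function `fn` and the shapes are illustrative:

import torch
import thunder
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

def fn(x, w):
    return torch.nn.functional.linear(x, w)

jfn = thunder.jit(fn, nv_enable_linear=True)

# Shard the batch dimension of the input, replicate the weight on every rank.
x = distribute_tensor(torch.randn(8, 16, device="cuda"), mesh, [Shard(0)])
w = distribute_tensor(torch.randn(32, 16, device="cuda"), mesh, [Replicate()])

expected = torch.nn.functional.linear(x, w)  # eager DTensor reference
actual = jfn(x, w)                           # thunder-jitted DTensor result
torch.testing.assert_close(actual, expected)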

thunder/tests/test_dynamo.py

Lines changed: 1 addition & 1 deletion

@@ -841,7 +841,7 @@ def find_target_module(model, target_module_name):
     assert submodule is not None
     for n in submodule.graph.nodes:
         if n.op == "call_function":
-            assert isinstance(n.target, Symbol)
+            assert isinstance(n.target, Symbol) or callable(n.target)
 
 
 @instantiate(

thunder/torch/experimental/dtensor_torch_and_prims.py

Lines changed: 28 additions & 0 deletions

@@ -35,6 +35,7 @@ class DTensorPrimIDs(Enum):
     RESHAPE = auto()
     CONVERT_ELEMENT_TYPE = auto()
     BROADCAST_IN_DIM = auto()
+    LINEAR = auto()
 
 
 dtensor_torchsymbol = partial(torchsymbol, allow_tensor_subclass_proxy=True)

@@ -241,6 +242,33 @@ def dtensor_broadcast_in_dim_meta(a, shape, broadcast_dimensions):
 pytorchex.register_implementation(dtensor_broadcast_in_dim_prim, dtensor_broadcast_in_dim_prim_impl)
 
 
+def dtensor_linear_meta(a, w, bias):
+    output = run_with_fake_tensor(torch.nn.functional.linear, a, w, bias)
+    local_tensor_proxy = TensorProxy(like=a.local_tensor)
+    local_tensor_proxy = TensorProxy(
+        like=a.local_tensor, shape=output._local_tensor.shape, dtype=dtypes.to_dtype(output._local_tensor.dtype)
+    )
+    spec = output._spec
+    spec_proxy = AnyProxy(spec, history=a.history)
+    return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)
+
+
+# TODO: Add grad rule once the prims used for linear grad-rule are available.
+dtensor_linear_prim = make_prim(DTensorPrimIDs.LINEAR, "dtensor_linear_prim", meta=dtensor_linear_meta)
+
+dtensor_linear_prim_impl = pytorchex.register_operator(
+    "dtensor_linear_prim", like=dtensor_linear_prim, fn=torch.nn.functional.linear
+)
+
+pytorchex.register_implementation(dtensor_linear_prim, dtensor_linear_prim_impl)
+
+
+@dtensor_torchsymbol(torch.nn.functional.linear, id="dtensor.torch.nn.functional.linear")
+def dtensor_linear(a: TensorLike, w: TensorLike, bias: None | TensorLike = None) -> TensorLike:
+    return dtensor_linear_prim(a, w, bias)
+
+
 def register_dtensor_torch_and_prims():
     register_function_for_dtensor(torch.mul, ltorch.mul, dtensor_mul, is_method=True)
     register_function_for_dtensor(torch.reshape, ltorch.reshape, dtensor_reshape, is_method=True)
+    register_function_for_dtensor(torch.nn.functional.linear, ltorch.linear, dtensor_linear, is_method=False)
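
The final registration ties the pieces together: for torch.nn.functional.linear, the torch-op lookaside now has both the regular `ltorch.linear` symbol and the DTensor-aware `dtensor_linear` symbol to pick from. Conceptually the dispatch behaves like the sketch below. This is an illustration only, not thunder's actual lookaside code; `pick_linear_symbol` is a hypothetical helper, and the check is shown against torch's DTensor class (found under `torch.distributed._tensor` on older PyTorch) rather than thunder's proxy types:

from torch.distributed.tensor import DTensor

def pick_linear_symbol(a, w, bias=None):
    # Route to the DTensor symbol registered above when any input is a DTensor,
    # otherwise fall back to the regular ltorch symbol.
    if any(isinstance(t, DTensor) for t in (a, w, bias) if t is not None):
        return dtensor_linear(a, w, bias)
    return ltorch.linear(a, w, bias)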
