
Commit 406cd2f

Handled DTensor placements as a positional argument for redistribute (#2797)
1 parent c6caa17 commit 406cd2f
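
For context: per the PyTorch docs linked in the change, DTensor.redistribute accepts the mesh and placements either positionally or by keyword (redistribute(device_mesh=None, placements=None, *, async_op=False)). Before this commit, thunder's dynamo path only recovered placements from the keyword form. A minimal sketch of the two call forms the commit makes equivalent; mesh and x are illustrative names, and an initialized process group with one GPU per rank is assumed:

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Shard, Replicate

# Illustrative setup; assumes torch.distributed is already initialized.
mesh = DeviceMesh("cuda", list(range(torch.cuda.device_count())))
x = torch.randn(16, 16, device="cuda")
dt = DTensor.from_local(x, mesh, [Shard(0)])

dt.redistribute(mesh, [Replicate()])       # positional form: previously unhandled
dt.redistribute(placements=[Replicate()])  # keyword form: already handled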

File tree: 2 files changed (+44 −2)

  thunder/dynamo/utils.py
  thunder/tests/distributed/test_dtensor.py

thunder/dynamo/utils.py

Lines changed: 12 additions & 2 deletions
@@ -1219,8 +1219,18 @@ def dtensor_from_local_prim_wrapper(x, mesh=mesh, placements=placements):
         dtensor_from_local_prim_wrapper.thunder_supported = True
         node.target = dtensor_from_local_prim_wrapper
     if "redistribute" in node.target.__name__:
-        kwargs = closure_vars.nonlocals["kwargs_as_value"]
-        placements = kwargs["placements"]
+        args = closure_vars.nonlocals.get("args_as_value", ())
+        kwargs = closure_vars.nonlocals.get("kwargs_as_value", {})
+
+        # Handle positional args: redistribute(device_mesh, placements)
+        # and keyword args: redistribute(placements=...).
+        # The PyTorch docs say placements may also be None or omitted entirely, but that makes Dynamo raise an error:
+        # https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.DTensor.redistribute
+        # To stay consistent with the PyTorch docs, we fall back to None when placements is not provided.
+        if len(args) >= 2:
+            placements = args[1]
+        else:
+            placements = kwargs.get("placements", None)

         def dtensor_redistribute_prim_wrapper(x, placements=placements):
             return dtensor_redistribute_prim(x, placements=placements)
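
The heart of the fix is the fallback order when recovering placements from the captured call. A standalone sketch of that logic, assuming args and kwargs hold the captured positional and keyword arguments as in the diff above (extract_placements is a hypothetical helper for illustration, not the actual thunder code):

from torch.distributed.tensor import Shard, Replicate

def extract_placements(args, kwargs):
    # Positional form: redistribute(device_mesh, placements) -> placements is args[1].
    if len(args) >= 2:
        return args[1]
    # Keyword form, or None when omitted, matching the PyTorch docs' default.
    return kwargs.get("placements", None)

# extract_placements((mesh, [Replicate()]), {})       -> [Replicate()]
# extract_placements((), {"placements": [Shard(0)]})  -> [Shard(0)]
# extract_placements((), {})                          -> None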

thunder/tests/distributed/test_dtensor.py

Lines changed: 32 additions & 0 deletions
@@ -324,6 +324,38 @@ def test_dtensor_columnwise_parallel(self, jit_fn):
         assert len(tmodel._backend.subgraph_infos[0].thunder_compiled_fns) == 1
         assert len(tmodel._backend.subgraph_infos[0].split_reasons) == 0

+    def test_dtensor_redistribute_with_positional_args(self):
+        num_devices = self.world_size
+        mesh = DeviceMesh("cuda", list(range(num_devices)))
+        dim_size = 16
+
+        # Test redistribute with positional args: redistribute(mesh, placements, async_op=True)
+        def fn_positional(x):
+            dt = DTensor.from_local(x, mesh, [Shard(0)])
+            return dt.redistribute(mesh, [Replicate()], async_op=True)
+
+        # Test redistribute with keyword args: redistribute(placements=..., async_op=True)
+        def fn_keyword(x):
+            dt = DTensor.from_local(x, mesh, [Shard(0)])
+            return dt.redistribute(placements=[Replicate()], async_op=True)
+
+        local_tensor = torch.randn(dim_size, dim_size, device="cuda")
+
+        # Both should work and produce the same result
+        tmodel_positional = thunderfx(fn_positional)
+        tmodel_keyword = thunderfx(fn_keyword)
+
+        result_positional = tmodel_positional(local_tensor)
+        result_keyword = tmodel_keyword(local_tensor)
+
+        torch.testing.assert_close(result_positional, result_keyword)
+
+        # Verify no graph splits occurred (redistribute is supported)
+        assert len(tmodel_positional._backend.subgraph_infos) == 1
+        assert len(tmodel_positional._backend.subgraph_infos[0].split_reasons) == 0
+        assert len(tmodel_keyword._backend.subgraph_infos) == 1
+        assert len(tmodel_keyword._backend.subgraph_infos[0].split_reasons) == 0
+
     @common_utils.parametrize("executor", tuple(executors_map.keys()))
     @common_utils.parametrize(
         "input_shardings",
