Commit cda32a4

[DTensor] Add test with parallelize_module (#2598)
1 parent f3d8a42 commit cda32a4

5 files changed: +257 -3 lines changed

5 files changed

+257
-3
lines changed

thunder/dynamo/splitter.py

Lines changed: 10 additions & 3 deletions
@@ -19,6 +19,7 @@
     checkpoint_converter,
     _get_example_inputs_from_placeholder,
     _ThunderSplitGraphModule,
+    translate_dtensor_ops,
 )

 if TYPE_CHECKING:
@@ -98,6 +99,7 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"):
     split_reasons: list[SplitReason] = []

     nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
+    translate_dtensor_ops(gm)

     def callback(node) -> int:
         nonlocal prev_value, partition_cnt, split_reasons, supported_partitions
@@ -119,9 +121,14 @@ def callback(node) -> int:
             )
             split_reasons.append(split_reason)
         else:
-            is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
-            if split_reason is not None:
-                split_reasons.append(split_reason)
+            # To support dynamo generated prims for `parallelize_module`.
+            # `translate_dtensor_ops` will mark the target as thunder supported if it is a DTensor operation.
+            if hasattr(node.target, "thunder_supported") and node.target.thunder_supported:
+                is_thunder_supported, split_reason = True, None
+            else:
+                is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+                if split_reason is not None:
+                    split_reasons.append(split_reason)

         if prev_value == is_thunder_supported:  # We are in the same region.
             return partition_cnt
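
The new branch in `callback` only inspects an attribute that `translate_dtensor_ops` attaches to the rewritten node targets. A minimal sketch of that check, where the helper name `_is_marked_thunder_supported` is hypothetical and `node` is any `torch.fx.Node` from the captured graph:

import torch.fx

def _is_marked_thunder_supported(node: torch.fx.Node) -> bool:
    # Wrappers installed by translate_dtensor_ops carry `thunder_supported = True`;
    # every other target falls back to is_node_supported_by_thunder.
    return getattr(node.target, "thunder_supported", False)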

thunder/dynamo/utils.py

Lines changed: 75 additions & 0 deletions
@@ -1050,3 +1050,78 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn):
         err_msg = ", ".join([f"{x.name} raised exception: {x.compiled_fn}" for x in sorted_compiled_gm_to_measurement])
         raise RuntimeError(f"No compiler was able to compile the graph module, {err_msg}")
     return sorted_compiled_gm_to_measurement[0].compiled_fn
+
+
+def translate_dtensor_ops(gm: torch.fx.GraphModule) -> None:
+    # We need this function because:
+    #
+    # For a program like:
+    # ```
+    # model = nn.Linear(hidden_size, hidden_size, bias=False)
+    # parallel_model = parallelize_module(model, mesh, {"fc1": ColwiseParallel()})
+    # model.fc1.weight.requires_grad = False
+
+    # # parallelize_module will handle the conversion to DTensor
+    # i = torch.randn(hidden_size, hidden_size)
+    # ```
+    #
+    # Dynamo captures an FX graph like:
+    # ```
+    # def forward(self, L_x_: "f32[16, 16]", L_self_modules_fc1_parameters_weight_: "f32[16, 16]"):
+    #     l_x_ = L_x_
+    #     l_self_modules_fc1_parameters_weight_ = L_self_modules_fc1_parameters_weight_
+    #
+    #     input_tensor: "f32[16, 16]" = torch__dynamo_variables_torch_prim_from_local(l_x_); l_x_ = None
+    #
+    #     linear: "f32[16, 16]" = torch._C._nn.linear(input_tensor, l_self_modules_fc1_parameters_weight_, None); input_tensor = l_self_modules_fc1_parameters_weight_ = None
+    #
+    #     outputs: "f32[16, 16]" = torch__dynamo_variables_tensor_prim_redistribute(linear); linear = None
+    #
+    #     hook_result: "f32[16, 8]" = torch__dynamo_variables_tensor_prim_to_local(outputs); outputs = None
+    #     return (hook_result,)
+    # ```
+    # where:
+    # 1. In the FX graph, the tensor-parallel computation is decomposed into primitive operations such as `torch__dynamo_variables_torch_prim_from_local`, `torch__dynamo_variables_tensor_prim_redistribute`, and others.
+    # 2. It is important to note that these decompositions actually capture (close over) values such as `placements` and other metadata.
+    #    For example, to understand the placements to which the output will be redistributed using `torch__dynamo_variables_tensor_prim_redistribute`,
+    #    we need to use `inspect.getclosurevars(node.target)` to examine the values (like placements) that are captured and used during execution.
+    #    The reference for where this closure is created can be found at:
+    #    https://github.com/pytorch/pytorch/blob/0ab075a69e4577a60c4dcbff7bcc2ecd0a15ce46/torch/_dynamo/variables/tensor.py#L1186-L1210

+    from thunder.torch.experimental.dtensor_torch_and_prims import (
+        dtensor_from_local_prim,
+        dtensor_redistribute_prim,
+        dtensor_to_local_prim,
+    )
+
+    for node in gm.graph.nodes:
+        try:
+            closure_vars = inspect.getclosurevars(node.target)
+
+            if "from_local" in node.target.__name__:
+                mesh = closure_vars.nonlocals["args_as_value"][0]
+                placements = closure_vars.nonlocals["args_as_value"][1]
+
+                def dtensor_from_local_prim_wrapper(x, mesh=mesh, placements=placements):
+                    return dtensor_from_local_prim(x, mesh, placements)
+
+                dtensor_from_local_prim_wrapper.thunder_supported = True
+                node.target = dtensor_from_local_prim_wrapper
+            if "redistribute" in node.target.__name__:
+                kwargs = closure_vars.nonlocals["kwargs_as_value"]
+                placements = kwargs["placements"]
+
+                def dtensor_redistribute_prim_wrapper(x, placements=placements):
+                    return dtensor_redistribute_prim(x, placements=placements)
+
+                dtensor_redistribute_prim_wrapper.thunder_supported = True
+                node.target = dtensor_redistribute_prim_wrapper
+            if "to_local" in node.target.__name__:
+
+                def dtensor_to_local_prim_wrapper(x):
+                    return dtensor_to_local_prim(x)
+
+                dtensor_to_local_prim_wrapper.thunder_supported = True
+                node.target = dtensor_to_local_prim_wrapper
+        except Exception:
+            pass
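
The closure inspection in `translate_dtensor_ops` relies only on the standard library. A standalone sketch of the idea, with a hypothetical `make_captured_fn` standing in for how dynamo's generated callables close over `args_as_value` / `kwargs_as_value`:

import inspect

def make_captured_fn(args_as_value, kwargs_as_value):
    # Mimics a dynamo-generated callable that closes over the DTensor call's arguments.
    def captured_fn(x):
        return x, args_as_value, kwargs_as_value

    return captured_fn

fn = make_captured_fn(("mesh", ["Shard(0)"]), {"placements": ["Replicate()"]})
closure_vars = inspect.getclosurevars(fn)
# The captured values are visible without ever calling the function:
assert closure_vars.nonlocals["args_as_value"][1] == ["Shard(0)"]
assert closure_vars.nonlocals["kwargs_as_value"]["placements"] == ["Replicate()"]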

thunder/tests/distributed/test_dtensor.py

Lines changed: 42 additions & 0 deletions
@@ -15,6 +15,10 @@
 from torch.distributed._tensor import DeviceMesh, distribute_tensor
 from torch.distributed.tensor.placement_types import Shard, Replicate
 from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter
+from torch.distributed.tensor.parallel import (
+    parallelize_module,
+    ColwiseParallel,
+)

 from torch.testing._internal import common_utils

@@ -23,6 +27,7 @@
 from thunder.tests.utils import is_output_differentiable, filter_differentiable_outputs
 import thunder.core.dtypes as dtypes
 from thunder.core.pytree import tree_flatten
+from thunder.dynamo import thunderfx


 # NOTE: We run all these similar functions separately
@@ -272,6 +277,43 @@ def fn(x):

         torch.testing.assert_close(actual, expected)

+    @common_utils.parametrize("jit_fn", (thunder.jit, thunderfx), name_fn=lambda jit_fn: jit_fn.__name__)
+    def test_dtensor_columnwise_parallel(self, jit_fn):
+        num_devices = self.world_size
+        mesh = DeviceMesh("cuda", list(range(num_devices)))
+        dim_size = 16
+        in_dtensor = torch.randn(dim_size, dim_size, requires_grad=False)
+        m = torch.nn.Linear(dim_size, dim_size)
+        m.requires_grad_(False)
+
+        parallelized_model = parallelize_module(m, mesh, ColwiseParallel())
+
+        # `parallelize_module` sets `requires_grad` to True, so set it to False again.
+        parallelized_model.requires_grad_(False)
+
+        actual = parallelized_model(in_dtensor)
+        expected = m(in_dtensor)
+        torch.testing.assert_close(actual, expected)
+
+        tmodel = jit_fn(parallelized_model, nv_enable_linear=True)
+
+        if jit_fn == thunder.jit:
+            # Original error caught by the interpreter:
+            #   File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 444, in _general_jit_getattr_lookaside
+            #     obj.original_value.__dict__,
+            #     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+            #   AttributeError: 'object' object has no attribute '__dict__'. Did you mean: '__dir__'?
+            with self.assertRaises(thunder.core.interpreter.InterpreterError):
+                actual = tmodel(in_dtensor)
+        else:
+            actual = tmodel(in_dtensor)
+            torch.testing.assert_close(actual, expected)
+
+        if jit_fn == thunderfx:
+            assert len(tmodel._backend.subgraph_infos) == 1
+            assert len(tmodel._backend.subgraph_infos[0].thunder_compiled_fns) == 1
+            assert len(tmodel._backend.subgraph_infos[0].split_reasons) == 0
+
     @common_utils.parametrize("executor", tuple(executors_map.keys()))
     @common_utils.parametrize(
         "input_shardings",

thunder/torch/experimental/dtensor_proxy.py

Lines changed: 35 additions & 0 deletions
@@ -69,6 +69,41 @@ def placements(self):
     def device_mesh(self):
         return self.spec._o.device_mesh

+    @staticmethod
+    def from_local(
+        x,
+        mesh,
+        placements,
+        *,
+        run_check: bool = False,
+        shape: torch.Size | None = None,
+        stride: tuple[int, ...] | None = None,
+    ):
+        import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims
+
+        res = dtensor_torch_and_prims.dtensor_from_local_prim(
+            x, mesh, placements, run_check=run_check, shape=shape, stride=stride
+        )
+        return res
+
+    def redistribute(
+        self,
+        device_mesh: "Optional[DeviceMesh]" = None,
+        placements: "Optional[Sequence[Placement]]" = None,
+        *,
+        async_op: bool = False,
+    ) -> "DTensor":
+        import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims
+
+        res = dtensor_torch_and_prims.dtensor_redistribute_prim(self, device_mesh, placements, async_op=async_op)
+        return res
+
+    def to_local(self, *, grad_placements: "Optional[Sequence[Placement]]" = None):
+        import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims
+
+        res = dtensor_torch_and_prims.dtensor_to_local_prim(self, grad_placements=grad_placements)
+        return res
+
     def replace(self, **changes):
         r"""Return a copy of the TensorProxy object with new values for the specified fields as given to the constructor as arguments.
         Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``.
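
The three methods added to `DTensorProxy` mirror the eager `DTensor` API one-to-one. For orientation, a minimal eager-mode sketch of the calls each proxy method stands in for, assuming an initialized process group and a 1-D CUDA mesh:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Shard, Replicate

mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
local = torch.randn(4, 4, device="cuda")

dt = DTensor.from_local(local, mesh, [Shard(0)])  # mirrored by DTensorProxy.from_local
dt = dt.redistribute(mesh, [Replicate()])         # mirrored by DTensorProxy.redistribute
out = dt.to_local()                               # mirrored by DTensorProxy.to_local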

thunder/torch/experimental/dtensor_torch_and_prims.py

Lines changed: 95 additions & 0 deletions
@@ -1,6 +1,7 @@
 from functools import partial
 from collections.abc import Callable
 from enum import auto, Enum
+from collections.abc import Sequence
 from looseversion import LooseVersion

 from thunder.torch import torchsymbol, TensorLike, register_function
@@ -371,6 +372,100 @@ def dtensor_reciprocal(a: TensorLike) -> TensorLike:
     )


+if torch.distributed.is_available():
+    from torch.distributed.tensor import DTensor
+    from torch.distributed.tensor.placement_types import Placement, DeviceMesh
+
+    def dtensor_from_local_meta(
+        x,
+        mesh,
+        placements,
+        *,
+        run_check: bool = False,
+        shape: torch.Size | None = None,
+        stride: tuple[int, ...] | None = None,
+    ):
+        res = run_with_fake_tensor(
+            DTensor.from_local, x, mesh, placements, run_check=run_check, shape=shape, stride=stride
+        )
+        from thunder.torch.experimental.dtensor_proxy import proxify_dtensor
+
+        res = proxify_dtensor(res)
+        return res
+
+    dtensor_from_local_prim = make_prim("dtensor_from_local", "dtensor_from_local", meta=dtensor_from_local_meta)
+
+    dtensor_from_local_prim_impl = pytorchex.register_operator(
+        "dtensor_from_local", like=dtensor_from_local_prim, fn=DTensor.from_local
+    )
+
+    pytorchex.register_implementation(dtensor_from_local_prim, dtensor_from_local_prim_impl)
+
+    @dtensor_torchsymbol(DTensor.from_local, id="dtensor.torch.from_local")
+    def dtensor_from_local(
+        x,
+        mesh,
+        placements,
+        *,
+        run_check: bool = False,
+        shape: torch.Size | None = None,
+        stride: tuple[int, ...] | None = None,
+    ) -> DTensorProxy | None:
+        return dtensor_from_local_prim(x, mesh, placements, run_check=run_check, shape=shape, stride=stride)
+
+    def dtensor_redistribute_meta(
+        dtensor,
+        device_mesh: DeviceMesh | None = None,
+        placements: Sequence[Placement] | None = None,
+        *,
+        async_op: bool = False,
+    ) -> DTensorProxy | None:
+        res = run_with_fake_tensor(DTensor.redistribute, dtensor, device_mesh, placements, async_op=async_op)
+        from thunder.torch.experimental.dtensor_proxy import proxify_dtensor
+
+        res = proxify_dtensor(res)
+        return res
+
+    dtensor_redistribute_prim = make_prim(
+        "dtensor_redistribute", "dtensor_redistribute", meta=dtensor_redistribute_meta
+    )
+
+    dtensor_redistribute_prim_impl = pytorchex.register_operator(
+        "dtensor_redistribute", like=dtensor_redistribute_prim, fn=DTensor.redistribute
+    )
+
+    @dtensor_torchsymbol(DTensor.redistribute, id="dtensor.torch.redistribute")
+    def dtensor_redistribute(
+        dtensor,
+        device_mesh: DeviceMesh | None = None,
+        placements: Sequence[Placement] | None = None,
+        *,
+        async_op: bool = False,
+    ) -> DTensorProxy | None:
+        return dtensor_redistribute_prim(dtensor, device_mesh, placements, async_op=async_op)
+
+    pytorchex.register_implementation(dtensor_redistribute_prim, dtensor_redistribute_prim_impl)
+
+    def dtensor_to_local_meta(dtensor, *, grad_placements: Sequence[Placement] | None = None):
+        res = run_with_fake_tensor(DTensor.to_local, dtensor, grad_placements=grad_placements)
+        from thunder.core.proxies import proxy
+
+        res = proxy(res)
+        return res
+
+    dtensor_to_local_prim = make_prim("dtensor_to_local", "dtensor_to_local", meta=dtensor_to_local_meta)
+
+    dtensor_to_local_prim_impl = pytorchex.register_operator(
+        "dtensor_to_local", like=dtensor_to_local_prim, fn=DTensor.to_local
+    )
+
+    pytorchex.register_implementation(dtensor_to_local_prim, dtensor_to_local_prim_impl)
+
+    @dtensor_torchsymbol(DTensor.to_local, id="dtensor.torch.to_local")
+    def dtensor_to_local(dtensor, *, grad_placements: Sequence[Placement] | None = None) -> DTensorProxy | None:
+        return dtensor_to_local_prim(dtensor, grad_placements=grad_placements)
+
+
 expand = partial(expand_impl, broadcast_prim=dtensor_broadcast_in_dim_prim)
 maybe_broadcast = partial(maybe_broadcast_impl, expand_fn=expand)
