diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index 7949937ca5..b7f3d9fb9b 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -19,6 +19,7 @@
     checkpoint_converter,
     _get_example_inputs_from_placeholder,
     _ThunderSplitGraphModule,
+    translate_dtensor_ops,
 )
 
 if TYPE_CHECKING:
@@ -98,6 +99,7 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"):
     split_reasons: list[SplitReason] = []
 
     nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
+    translate_dtensor_ops(gm)
 
     def callback(node) -> int:
         nonlocal prev_value, partition_cnt, split_reasons, supported_partitions
@@ -119,9 +121,14 @@ def callback(node) -> int:
             )
             split_reasons.append(split_reason)
         else:
-            is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
-            if split_reason is not None:
-                split_reasons.append(split_reason)
+            # To support dynamo-generated prims for `parallelize_module`:
+            # `translate_dtensor_ops` marks a node's target as thunder-supported if it is a DTensor operation.
+            if hasattr(node.target, "thunder_supported") and node.target.thunder_supported:
+                is_thunder_supported, split_reason = True, None
+            else:
+                is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+                if split_reason is not None:
+                    split_reasons.append(split_reason)
 
         if prev_value == is_thunder_supported:  # We are in the same region.
             return partition_cnt
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 3c58bc2ded..89db40836f 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -1050,3 +1050,78 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn):
         err_msg = ", ".join([f"{x.name} raised exception: {x.compiled_fn}" for x in sorted_compiled_gm_to_measurement])
         raise RuntimeError(f"No compiler was able to compile the graph module, {err_msg}")
     return sorted_compiled_gm_to_measurement[0].compiled_fn
+
+
+def translate_dtensor_ops(gm: torch.fx.GraphModule) -> None:
+    # We need this function because:
+    #
+    # For a program like:
+    # ```
+    # model = nn.Sequential()
+    # model.fc1 = nn.Linear(hidden_size, hidden_size, bias=False)
+    #
+    # # `parallelize_module` handles the conversion of the parameters to DTensor.
+    # parallel_model = parallelize_module(model, mesh, {"fc1": ColwiseParallel()})
+    # model.fc1.weight.requires_grad = False
+    #
+    # x = torch.randn(hidden_size, hidden_size)
+    # out = parallel_model(x)
+    # ```
+    #
+    # Dynamo captures an FX Graph like:
+    # ```
+    # def forward(self, L_x_: "f32[16, 16]", L_self_modules_fc1_parameters_weight_: "f32[16, 16]"):
+    #     l_x_ = L_x_
+    #     l_self_modules_fc1_parameters_weight_ = L_self_modules_fc1_parameters_weight_
+    #
+    #     input_tensor: "f32[16, 16]" = torch__dynamo_variables_torch_prim_from_local(l_x_); l_x_ = None
+    #
+    #     linear: "f32[16, 16]" = torch._C._nn.linear(input_tensor, l_self_modules_fc1_parameters_weight_, None); input_tensor = l_self_modules_fc1_parameters_weight_ = None
+    #
+    #     outputs: "f32[16, 16]" = torch__dynamo_variables_tensor_prim_redistribute(linear); linear = None
+    #
+    #     hook_result: "f32[16, 8]" = torch__dynamo_variables_tensor_prim_to_local(outputs); outputs = None
+    #     return (hook_result,)
+    # ```
+    # where:
+    # 1. In the FX Graph, the Tensor Parallel computation is decomposed into primitive operations such as `torch__dynamo_variables_torch_prim_from_local`, `torch__dynamo_variables_tensor_prim_redistribute`, and others.
+    # 2. These decompositions capture (close over) values such as `placements` and other metadata.
+    #    For example, to find the placements to which the output will be redistributed by `torch__dynamo_variables_tensor_prim_redistribute`,
+    #    we need `inspect.getclosurevars(node.target)` to examine the values (like `placements`) that are captured and used during execution.
+    #    The reference for where this closure is created can be found at:
+    #    https://github.com/pytorch/pytorch/blob/0ab075a69e4577a60c4dcbff7bcc2ecd0a15ce46/torch/_dynamo/variables/tensor.py#L1186-L1210
+
+    from thunder.torch.experimental.dtensor_torch_and_prims import (
+        dtensor_from_local_prim,
+        dtensor_redistribute_prim,
+        dtensor_to_local_prim,
+    )
+
+    for node in gm.graph.nodes:
+        try:
+            closure_vars = inspect.getclosurevars(node.target)
+
+            if "from_local" in node.target.__name__:
+                mesh = closure_vars.nonlocals["args_as_value"][0]
+                placements = closure_vars.nonlocals["args_as_value"][1]
+
+                def dtensor_from_local_prim_wrapper(x, mesh=mesh, placements=placements):
+                    return dtensor_from_local_prim(x, mesh, placements)
+
+                dtensor_from_local_prim_wrapper.thunder_supported = True
+                node.target = dtensor_from_local_prim_wrapper
+            if "redistribute" in node.target.__name__:
+                kwargs = closure_vars.nonlocals["kwargs_as_value"]
+                placements = kwargs["placements"]
+
+                def dtensor_redistribute_prim_wrapper(x, placements=placements):
+                    return dtensor_redistribute_prim(x, placements=placements)
+
+                dtensor_redistribute_prim_wrapper.thunder_supported = True
+                node.target = dtensor_redistribute_prim_wrapper
+            if "to_local" in node.target.__name__:
+
+                def dtensor_to_local_prim_wrapper(x):
+                    return dtensor_to_local_prim(x)
+
+                dtensor_to_local_prim_wrapper.thunder_supported = True
+                node.target = dtensor_to_local_prim_wrapper
+        except Exception:
+            pass
diff --git a/thunder/tests/distributed/test_dtensor.py b/thunder/tests/distributed/test_dtensor.py
index 3f5276c00b..7bc9671cc8 100644
--- a/thunder/tests/distributed/test_dtensor.py
+++ b/thunder/tests/distributed/test_dtensor.py
@@ -15,6 +15,10 @@
 from torch.distributed._tensor import DeviceMesh, distribute_tensor
 from torch.distributed.tensor.placement_types import Shard, Replicate
 from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter
+from torch.distributed.tensor.parallel import (
+    parallelize_module,
+    ColwiseParallel,
+)
 
 from torch.testing._internal import common_utils
 
@@ -23,6 +27,7 @@
 from thunder.tests.utils import is_output_differentiable, filter_differentiable_outputs
 import thunder.core.dtypes as dtypes
 from thunder.core.pytree import tree_flatten
+from thunder.dynamo import thunderfx
 
 
 # NOTE: We run all these similar functions seperately
@@ -272,6 +277,43 @@ def fn(x):
 
         torch.testing.assert_close(actual, expected)
 
+    @common_utils.parametrize("jit_fn", (thunder.jit, thunderfx), name_fn=lambda jit_fn: jit_fn.__name__)
+    def test_dtensor_columnwise_parallel(self, jit_fn):
+        num_devices = self.world_size
+        mesh = DeviceMesh("cuda", list(range(num_devices)))
+        dim_size = 16
+        in_dtensor = torch.randn(dim_size, dim_size, requires_grad=False)
+        m = torch.nn.Linear(dim_size, dim_size)
+        m.requires_grad_(False)
+
+        parallelized_model = parallelize_module(m, mesh, ColwiseParallel())
+
+        # `parallelize_module` sets `requires_grad` back to True, so set it to False again.
+        parallelized_model.requires_grad_(False)
+
+        actual = parallelized_model(in_dtensor)
+        expected = m(in_dtensor)
+        torch.testing.assert_close(actual, expected)
+
+        tmodel = jit_fn(parallelized_model, nv_enable_linear=True)
+
+        if jit_fn == thunder.jit:
+            # Original error caught by the interpreter:
+            #   File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 444, in _general_jit_getattr_lookaside
+            #     obj.original_value.__dict__,
+            #     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+            #   AttributeError: 'object' object has no attribute '__dict__'. Did you mean: '__dir__'?
+            with self.assertRaises(thunder.core.interpreter.InterpreterError):
+                actual = tmodel(in_dtensor)
+        else:
+            actual = tmodel(in_dtensor)
+            torch.testing.assert_close(actual, expected)
+
+        if jit_fn == thunderfx:
+            assert len(tmodel._backend.subgraph_infos) == 1
+            assert len(tmodel._backend.subgraph_infos[0].thunder_compiled_fns) == 1
+            assert len(tmodel._backend.subgraph_infos[0].split_reasons) == 0
+
     @common_utils.parametrize("executor", tuple(executors_map.keys()))
     @common_utils.parametrize(
         "input_shardings",
diff --git a/thunder/torch/experimental/dtensor_proxy.py b/thunder/torch/experimental/dtensor_proxy.py
index 5243dcdebf..87239ac710 100644
--- a/thunder/torch/experimental/dtensor_proxy.py
+++ b/thunder/torch/experimental/dtensor_proxy.py
@@ -69,6 +69,41 @@ def placements(self):
     def device_mesh(self):
         return self.spec._o.device_mesh
 
+    @staticmethod
+    def from_local(
+        x,
+        mesh,
+        placements,
+        *,
+        run_check: bool = False,
+        shape: torch.Size | None = None,
+        stride: tuple[int, ...] | None = None,
+    ):
+        import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims
+
+        res = dtensor_torch_and_prims.dtensor_from_local_prim(
+            x, mesh, placements, run_check=run_check, shape=shape, stride=stride
+        )
+        return res
+
+    def redistribute(
+        self,
+        device_mesh: "Optional[DeviceMesh]" = None,
+        placements: "Optional[Sequence[Placement]]" = None,
+        *,
+        async_op: bool = False,
+    ) -> "DTensor":
+        import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims
+
+        res = dtensor_torch_and_prims.dtensor_redistribute_prim(self, device_mesh, placements, async_op=async_op)
+        return res
+
+    def to_local(self, *, grad_placements: "Optional[Sequence[Placement]]" = None):
+        import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims
+
+        res = dtensor_torch_and_prims.dtensor_to_local_prim(self, grad_placements=grad_placements)
+        return res
+
     def replace(self, **changes):
         r"""Return a copy of the TensorProxy object with new values for the specified fields as given to the constructor as arguments. Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``.
diff --git a/thunder/torch/experimental/dtensor_torch_and_prims.py b/thunder/torch/experimental/dtensor_torch_and_prims.py
index 393c6284bd..4297c0f3e1 100644
--- a/thunder/torch/experimental/dtensor_torch_and_prims.py
+++ b/thunder/torch/experimental/dtensor_torch_and_prims.py
@@ -1,6 +1,7 @@
 from functools import partial
 from collections.abc import Callable
 from enum import auto, Enum
+from collections.abc import Sequence
 from looseversion import LooseVersion
 
 from thunder.torch import torchsymbol, TensorLike, register_function
@@ -371,6 +372,100 @@ def dtensor_reciprocal(a: TensorLike) -> TensorLike:
     )
 
 
+if torch.distributed.is_available():
+    from torch.distributed.tensor import DTensor
+    from torch.distributed.tensor.placement_types import Placement, DeviceMesh
+
+    def dtensor_from_local_meta(
+        x,
+        mesh,
+        placements,
+        *,
+        run_check: bool = False,
+        shape: torch.Size | None = None,
+        stride: tuple[int, ...] | None = None,
+    ):
+        res = run_with_fake_tensor(
+            DTensor.from_local, x, mesh, placements, run_check=run_check, shape=shape, stride=stride
+        )
+        from thunder.torch.experimental.dtensor_proxy import proxify_dtensor
+
+        res = proxify_dtensor(res)
+        return res
+
+    dtensor_from_local_prim = make_prim("dtensor_from_local", "dtensor_from_local", meta=dtensor_from_local_meta)
+
+    dtensor_from_local_prim_impl = pytorchex.register_operator(
+        "dtensor_from_local", like=dtensor_from_local_prim, fn=DTensor.from_local
+    )
+
+    pytorchex.register_implementation(dtensor_from_local_prim, dtensor_from_local_prim_impl)
+
+    @dtensor_torchsymbol(DTensor.from_local, id="dtensor.torch.from_local")
+    def dtensor_from_local(
+        x,
+        mesh,
+        placements,
+        *,
+        run_check: bool = False,
+        shape: torch.Size | None = None,
+        stride: tuple[int, ...] | None = None,
+    ) -> DTensorProxy | None:
+        return dtensor_from_local_prim(x, mesh, placements, run_check=run_check, shape=shape, stride=stride)
+
+    def dtensor_redistribute_meta(
+        dtensor,
+        device_mesh: DeviceMesh | None = None,
+        placements: Sequence[Placement] | None = None,
+        *,
+        async_op: bool = False,
+    ) -> DTensorProxy | None:
+        res = run_with_fake_tensor(DTensor.redistribute, dtensor, device_mesh, placements, async_op=async_op)
+        from thunder.torch.experimental.dtensor_proxy import proxify_dtensor
+
+        res = proxify_dtensor(res)
+        return res
+
+    dtensor_redistribute_prim = make_prim(
+        "dtensor_redistribute", "dtensor_redistribute", meta=dtensor_redistribute_meta
+    )
+
+    dtensor_redistribute_prim_impl = pytorchex.register_operator(
+        "dtensor_redistribute", like=dtensor_redistribute_prim, fn=DTensor.redistribute
+    )
+
+    @dtensor_torchsymbol(DTensor.redistribute, id="dtensor.torch.redistribute")
+    def dtensor_redistribute(
+        dtensor,
+        device_mesh: DeviceMesh | None = None,
+        placements: Sequence[Placement] | None = None,
+        *,
+        async_op: bool = False,
+    ) -> DTensorProxy | None:
+        return dtensor_redistribute_prim(dtensor, device_mesh, placements, async_op=async_op)
+
+    pytorchex.register_implementation(dtensor_redistribute_prim, dtensor_redistribute_prim_impl)
+
+    def dtensor_to_local_meta(dtensor, *, grad_placements: Sequence[Placement] | None = None):
+        res = run_with_fake_tensor(DTensor.to_local, dtensor, grad_placements=grad_placements)
+        from thunder.core.proxies import proxy
+
+        res = proxy(res)
+        return res
+
+    dtensor_to_local_prim = make_prim("dtensor_to_local", "dtensor_to_local", meta=dtensor_to_local_meta)
+
+    dtensor_to_local_prim_impl = pytorchex.register_operator(
+        "dtensor_to_local", like=dtensor_to_local_prim, fn=DTensor.to_local
+    )
+
+    pytorchex.register_implementation(dtensor_to_local_prim, dtensor_to_local_prim_impl)
+
+    @dtensor_torchsymbol(DTensor.to_local, id="dtensor.torch.to_local")
+    def dtensor_to_local(dtensor, *, grad_placements: Sequence[Placement] | None = None) -> DTensorProxy | None:
+        return dtensor_to_local_prim(dtensor, grad_placements=grad_placements)
+
+
 expand = partial(expand_impl, broadcast_prim=dtensor_broadcast_in_dim_prim)
 maybe_broadcast = partial(maybe_broadcast_impl, expand_fn=expand)
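
For reference, below is a minimal end-to-end sketch (not part of the patch) of the path these changes enable. It mirrors `test_dtensor_columnwise_parallel` and the closure-inspection trick described in the `translate_dtensor_ops` comment. It assumes a multi-GPU node launched with `torchrun --nproc-per-node=<N>` (which sets `LOCAL_RANK`), an NCCL backend, and an nvFuser build that accepts `nv_enable_linear=True`; the `inspect_backend` helper is purely illustrative and is not part of the Thunder API.

```python
# Sketch: ColwiseParallel module compiled through thunderfx, plus a peek at the
# closures that translate_dtensor_ops reads from the dynamo-generated prims.
import inspect
import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

from thunder.dynamo import thunderfx

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

    dim_size = 16
    m = torch.nn.Linear(dim_size, dim_size).requires_grad_(False)

    # `parallelize_module` shards the weight column-wise and converts it to a DTensor;
    # it also flips `requires_grad` back on, so turn it off again (as in the test above).
    parallelized = parallelize_module(m, mesh, ColwiseParallel())
    parallelized.requires_grad_(False)

    x = torch.randn(dim_size, dim_size, device="cuda")
    expected = parallelized(x)  # eager Tensor Parallel forward

    # thunderfx captures the dynamo-generated DTensor prims (from_local / redistribute /
    # to_local); with translate_dtensor_ops the whole graph stays in one Thunder region.
    jitted = thunderfx(parallelized, nv_enable_linear=True)
    out = jitted(x)
    torch.testing.assert_close(out, expected)

    # Optional: inspect the closures. Dynamo's DTensor decompositions close over values
    # such as `placements`; inspect.getclosurevars exposes them on the FX node targets.
    def inspect_backend(gm, example_inputs):
        for node in gm.graph.nodes:
            name = getattr(node.target, "__name__", "")
            if any(k in name for k in ("from_local", "redistribute", "to_local")):
                print(name, inspect.getclosurevars(node.target).nonlocals.keys())
        return gm.forward

    torch._dynamo.reset()  # clear the thunderfx compilation before recompiling
    torch.compile(parallelized, backend=inspect_backend)(x)

    dist.destroy_process_group()
```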