From 596fb72dde749240a4a992739814928182f970ee Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 3 Oct 2025 07:04:55 -0700 Subject: [PATCH 1/5] DTensor: Add test with parallelize_module --- thunder/dynamo/splitter.py | 13 ++- thunder/dynamo/utils.py | 40 ++++++++ thunder/executors/nvfuserex_impl.py | 2 + thunder/tests/distributed/test_dtensor.py | 39 ++++++++ thunder/torch/experimental/dtensor_proxy.py | 35 +++++++ .../experimental/dtensor_torch_and_prims.py | 95 +++++++++++++++++++ 6 files changed, 221 insertions(+), 3 deletions(-) diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index 7949937ca5..b7f3d9fb9b 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -19,6 +19,7 @@ checkpoint_converter, _get_example_inputs_from_placeholder, _ThunderSplitGraphModule, + translate_dtensor_ops, ) if TYPE_CHECKING: @@ -98,6 +99,7 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"): split_reasons: list[SplitReason] = [] nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm) + translate_dtensor_ops(gm) def callback(node) -> int: nonlocal prev_value, partition_cnt, split_reasons, supported_partitions @@ -119,9 +121,14 @@ def callback(node) -> int: ) split_reasons.append(split_reason) else: - is_thunder_supported, split_reason = is_node_supported_by_thunder(node) - if split_reason is not None: - split_reasons.append(split_reason) + # To support dynamo generated prims for `parallelize_module`. + # `translate_dtensor_ops` will mark the target as thunder supported if it is a DTensor operation. + if hasattr(node.target, "thunder_supported") and node.target.thunder_supported: + is_thunder_supported, split_reason = True, None + else: + is_thunder_supported, split_reason = is_node_supported_by_thunder(node) + if split_reason is not None: + split_reasons.append(split_reason) if prev_value == is_thunder_supported: # We are in the same region. 
return partition_cnt diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index 3c58bc2ded..9e43e81354 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -1050,3 +1050,43 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn): err_msg = ", ".join([f"{x.name} raised exception: {x.compiled_fn}" for x in sorted_compiled_gm_to_measurement]) raise RuntimeError(f"No compiler was able to compile the graph module, {err_msg}") return sorted_compiled_gm_to_measurement[0].compiled_fn + + +def translate_dtensor_ops(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + from thunder.torch.experimental.dtensor_torch_and_prims import ( + dtensor_from_local_prim, + dtensor_redistribute_prim, + dtensor_to_local_prim, + ) + + try: + closure_vars = inspect.getclosurevars(node.target) + + if "from_local" in node.target.__name__: + mesh = closure_vars.nonlocals["args_as_value"][0] + placements = closure_vars.nonlocals["args_as_value"][1] + + def dtensor_from_local_prim_wrapper(x, mesh=mesh, placements=placements): + return dtensor_from_local_prim(x, mesh, placements) + + dtensor_from_local_prim_wrapper.thunder_supported = True + node.target = dtensor_from_local_prim_wrapper + if "redistribute" in node.target.__name__: + kwargs = closure_vars.nonlocals["kwargs_as_value"] + placements = kwargs["placements"] + + def dtensor_redistribute_prim_wrapper(x, placements=placements): + return dtensor_redistribute_prim(x, placements=placements) + + dtensor_redistribute_prim_wrapper.thunder_supported = True + node.target = dtensor_redistribute_prim_wrapper + if "to_local" in node.target.__name__: + + def dtensor_to_local_prim_wrapper(x): + return dtensor_to_local_prim(x) + + dtensor_to_local_prim_wrapper.thunder_supported = True + node.target = dtensor_to_local_prim_wrapper + except Exception: + pass diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 8cc0addca9..cdbd1ed1cc 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -368,6 +368,8 @@ def check_dtensor_tracing_and_runtime_metadata(inp): _, _, dtensor_metadata = y runtime_device_mesh_repr = dtensor_metadata[0] runtime_placements_repr = dtensor_metadata[1] + print(x.device_mesh, runtime_device_mesh_repr) + print(x.placements, runtime_placements_repr) return x.device_mesh == runtime_device_mesh_repr and x.placements == runtime_placements_repr utils.check( diff --git a/thunder/tests/distributed/test_dtensor.py b/thunder/tests/distributed/test_dtensor.py index b010644a47..ec44e3993c 100644 --- a/thunder/tests/distributed/test_dtensor.py +++ b/thunder/tests/distributed/test_dtensor.py @@ -14,6 +14,10 @@ from torch.distributed._tensor import DeviceMesh, distribute_tensor from torch.distributed.tensor.placement_types import Shard from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter +from torch.distributed.tensor.parallel import ( + parallelize_module, + ColwiseParallel, +) from torch.testing._internal import common_utils @@ -22,6 +26,7 @@ from thunder.tests.utils import is_output_differentiable, filter_differentiable_outputs import thunder.core.dtypes as dtypes from thunder.core.pytree import tree_flatten +from thunder.dynamo import thunderfx # NOTE: We run all these similar functions seperately @@ -249,6 +254,40 @@ def fn(x): torch.testing.assert_close(actual, expected) + @common_utils.parametrize("jit_fn", (thunder.jit, thunderfx), name_fn=lambda jit_fn: jit_fn.__name__) + def test_dtensor_columnwise_parallel(self, 
jit_fn): + if jit_fn == thunder.jit: + # File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 444, in _general_jit_getattr_lookaside + # obj.original_value.__dict__, + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + # AttributeError: 'object' object has no attribute '__dict__'. Did you mean: '__dir__'? + raise unittest.SkipTest("thunder.jit fails with AttributeError") + + num_devices = self.world_size + mesh = DeviceMesh("cuda", list(range(num_devices))) + dim_size = 16 + in_dtensor = torch.randn(dim_size, dim_size, requires_grad=False) + m = torch.nn.Linear(dim_size, dim_size) + m.requires_grad_(False) + + parallelized_model = parallelize_module(m, mesh, ColwiseParallel()) + + # `parallelize_module` sets `requires_grad` to True, set it to False again. + parallelized_model.requires_grad_(False) + + actual = parallelized_model(in_dtensor) + expected = m(in_dtensor) + torch.testing.assert_close(actual, expected) + + tmodel = jit_fn(parallelized_model, nv_enable_linear=True) + actual = tmodel(in_dtensor) + torch.testing.assert_close(actual, expected) + + if jit_fn == thunderfx: + assert len(tmodel._backend.subgraph_infos) == 1 + assert len(tmodel._backend.subgraph_infos[0].thunder_compiled_fns) == 1 + assert len(tmodel._backend.subgraph_infos[0].split_reasons) == 0 + @common_utils.parametrize( "op, executor", product(dtensor_supported_opinfos, tuple(executors_map.keys())), diff --git a/thunder/torch/experimental/dtensor_proxy.py b/thunder/torch/experimental/dtensor_proxy.py index 5243dcdebf..87239ac710 100644 --- a/thunder/torch/experimental/dtensor_proxy.py +++ b/thunder/torch/experimental/dtensor_proxy.py @@ -69,6 +69,41 @@ def placements(self): def device_mesh(self): return self.spec._o.device_mesh + @staticmethod + def from_local( + x, + mesh, + placements, + *, + run_check: bool = False, + shape: torch.Size | None = None, + stride: tuple[int, ...] | None = None, + ): + import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims + + res = dtensor_torch_and_prims.dtensor_from_local_prim( + x, mesh, placements, run_check=run_check, shape=shape, stride=stride + ) + return res + + def redistribute( + self, + device_mesh: "Optional[DeviceMesh]" = None, + placements: "Optional[Sequence[Placement]]" = None, + *, + async_op: bool = False, + ) -> "DTensor": + import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims + + res = dtensor_torch_and_prims.dtensor_redistribute_prim(self, device_mesh, placements, async_op=async_op) + return res + + def to_local(self, *, grad_placements: "Optional[Sequence[Placement]]" = None): + import thunder.torch.experimental.dtensor_torch_and_prims as dtensor_torch_and_prims + + res = dtensor_torch_and_prims.dtensor_to_local_prim(self, grad_placements=grad_placements) + return res + def replace(self, **changes): r"""Return a copy of the TensorProxy object with new values for the specified fields as given to the constructor as arguments. Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``. 
diff --git a/thunder/torch/experimental/dtensor_torch_and_prims.py b/thunder/torch/experimental/dtensor_torch_and_prims.py index 00202c62bb..12d7d07cc4 100644 --- a/thunder/torch/experimental/dtensor_torch_and_prims.py +++ b/thunder/torch/experimental/dtensor_torch_and_prims.py @@ -1,6 +1,7 @@ from functools import partial from collections.abc import Callable from enum import auto, Enum +from collections.abc import Sequence from thunder.torch import torchsymbol, TensorLike, register_function import thunder.torch as ltorch @@ -363,6 +364,100 @@ def dtensor_reciprocal(a: TensorLike) -> TensorLike: ) +if torch.distributed.is_available(): + from torch.distributed.tensor import DTensor + from torch.distributed.tensor.placement_types import Placement, DeviceMesh + + def dtensor_from_local_meta( + x, + mesh, + placements, + *, + run_check: bool = False, + shape: torch.Size | None = None, + stride: tuple[int, ...] | None = None, + ): + res = run_with_fake_tensor( + DTensor.from_local, x, mesh, placements, run_check=run_check, shape=shape, stride=stride + ) + from thunder.torch.experimental.dtensor_proxy import proxify_dtensor + + res = proxify_dtensor(res) + return res + + dtensor_from_local_prim = make_prim("dtensor_from_local", "dtensor_from_local", meta=dtensor_from_local_meta) + + dtensor_from_local_prim_impl = pytorchex.register_operator( + "dtensor_from_local", like=dtensor_from_local_prim, fn=DTensor.from_local + ) + + pytorchex.register_implementation(dtensor_from_local_prim, dtensor_from_local_prim_impl) + + @dtensor_torchsymbol(DTensor.from_local, id="dtensor.torch.from_local") + def dtensor_from_local( + x, + mesh, + placements, + *, + run_check: bool = False, + shape: torch.Size | None = None, + stride: tuple[int, ...] | None = None, + ) -> DTensorProxy | None: + return dtensor_from_local_prim(x, mesh, placements, run_check=run_check, shape=shape, stride=stride) + + def dtensor_redistribute_meta( + dtensor, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, + *, + async_op: bool = False, + ) -> DTensorProxy | None: + res = run_with_fake_tensor(DTensor.redistribute, dtensor, device_mesh, placements, async_op=async_op) + from thunder.torch.experimental.dtensor_proxy import proxify_dtensor + + res = proxify_dtensor(res) + return res + + dtensor_redistribute_prim = make_prim( + "dtensor_redistribute", "dtensor_redistribute", meta=dtensor_redistribute_meta + ) + + dtensor_redistribute_prim_impl = pytorchex.register_operator( + "dtensor_redistribute", like=dtensor_redistribute_prim, fn=DTensor.redistribute + ) + + @dtensor_torchsymbol(DTensor.redistribute, id="dtensor.torch.redistribute") + def dtensor_redistribute( + dtensor, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, + *, + async_op: bool = False, + ) -> DTensorProxy | None: + return dtensor_redistribute_prim(dtensor, device_mesh, placements, async_op=async_op) + + pytorchex.register_implementation(dtensor_redistribute_prim, dtensor_redistribute_prim_impl) + + def dtensor_to_local_meta(dtensor, *, grad_placements: Sequence[Placement] | None = None): + res = run_with_fake_tensor(DTensor.to_local, dtensor, grad_placements=grad_placements) + from thunder.core.proxies import proxy + + res = proxy(res) + return res + + dtensor_to_local_prim = make_prim("dtensor_to_local", "dtensor_to_local", meta=dtensor_to_local_meta) + + dtensor_to_local_prim_impl = pytorchex.register_operator( + "dtensor_to_local", like=dtensor_to_local_prim, fn=DTensor.to_local + ) + + 
pytorchex.register_implementation(dtensor_to_local_prim, dtensor_to_local_prim_impl)
+
+    @dtensor_torchsymbol(DTensor.to_local, id="dtensor.torch.to_local")
+    def dtensor_to_local(dtensor, *, grad_placements: Sequence[Placement] | None = None) -> DTensorProxy | None:
+        return dtensor_to_local_prim(dtensor, grad_placements=grad_placements)
+
+
 def register_dtensor_torch_and_prims():
     register_function_for_dtensor(torch.mul, ltorch.mul, dtensor_mul, is_method=True)
     register_function_for_dtensor(torch.reshape, ltorch.reshape, dtensor_reshape, is_method=True)

From e5fc243956cda7aee0eed23e9c5767e54190bfb4 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Fri, 3 Oct 2025 07:07:34 -0700
Subject: [PATCH 2/5] remove stray change

---
 thunder/executors/nvfuserex_impl.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index cdbd1ed1cc..8cc0addca9 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -368,8 +368,6 @@ def check_dtensor_tracing_and_runtime_metadata(inp):
             _, _, dtensor_metadata = y
             runtime_device_mesh_repr = dtensor_metadata[0]
             runtime_placements_repr = dtensor_metadata[1]
-            print(x.device_mesh, runtime_device_mesh_repr)
-            print(x.placements, runtime_placements_repr)
             return x.device_mesh == runtime_device_mesh_repr and x.placements == runtime_placements_repr
 
         utils.check(

From e23385e1202ea05798eb3a095ac61eda86926a55 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Tue, 7 Oct 2025 03:54:44 -0700
Subject: [PATCH 3/5] add comment

---
 thunder/dynamo/utils.py | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 9e43e81354..d88f854ea3 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -1053,6 +1053,41 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn):
 
 
 def translate_dtensor_ops(gm: torch.fx.GraphModule):
+    # We need this function because:
+    #
+    # For a program like:
+    # ```
+    # model = Model()  # a module with a submodule `fc1 = nn.Linear(hidden_size, hidden_size, bias=False)`
+    # parallel_model = parallelize_module(model, mesh, {"fc1": ColwiseParallel()})
+    # model.fc1.weight.requires_grad = False
+
+    # # parallelize_module will handle the conversion to DTensor
+    # x = torch.randn(hidden_size, hidden_size)
+    # ```
+    #
+    # Dynamo captures an FX-Graph like:
+    # ```
+    # def forward(self, L_x_: "f32[16, 16]", L_self_modules_fc1_parameters_weight_: "f32[16, 16]"):
+    #     l_x_ = L_x_
+    #     l_self_modules_fc1_parameters_weight_ = L_self_modules_fc1_parameters_weight_
+    #
+    #     input_tensor: "f32[16, 16]" = torch__dynamo_variables_torch_prim_from_local(l_x_); l_x_ = None
+    #
+    #     linear: "f32[16, 16]" = torch._C._nn.linear(input_tensor, l_self_modules_fc1_parameters_weight_, None); input_tensor = l_self_modules_fc1_parameters_weight_ = None
+    #
+    #     outputs: "f32[16, 16]" = torch__dynamo_variables_tensor_prim_redistribute(linear); linear = None
+    #
+    #     hook_result: "f32[16, 8]" = torch__dynamo_variables_tensor_prim_to_local(outputs); outputs = None
+    #     return (hook_result,)
+    # ```
+    # where:
+    # 1. In the FX Graph, the Tensor Parallel computation is decomposed into primitive operations such as `torch__dynamo_variables_torch_prim_from_local`, `torch__dynamo_variables_tensor_prim_redistribute`, and others.
+    # 2. These decompositions capture (close over) values such as `placements` and other metadata.
+    #    For example, to determine the placements that `torch__dynamo_variables_tensor_prim_redistribute` redistributes its output to,
+    #    we use `inspect.getclosurevars(node.target)` to examine the captured values (such as the placements) used during execution.
+    #    The closure is created here:
+    #    https://github.com/pytorch/pytorch/blob/0ab075a69e4577a60c4dcbff7bcc2ecd0a15ce46/torch/_dynamo/variables/tensor.py#L1186-L1210
+
     for node in gm.graph.nodes:
         from thunder.torch.experimental.dtensor_torch_and_prims import (
             dtensor_from_local_prim,

From 02d90efcbe9de650d2a3dcd54a67e1b505adae8a Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Tue, 7 Oct 2025 02:22:44 -0700
Subject: [PATCH 4/5] Enable MoE TP with thunderfx

---
 thunder/executors/nvfuserex_impl.py           | 25 ++++++-
 thunder/tests/distributed/test_moe.py         |  5 ++
 .../experimental/dtensor_torch_and_prims.py   | 73 ++++++++++---------
 3 files changed, 68 insertions(+), 35 deletions(-)

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index f51e303380..a5dadbbaba 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -536,6 +536,14 @@ def __call__(self, *args):
         if self.store_inputs:
             self.last_inputs = args
 
+        if dist.is_available():
+            # When using DTensor with FXGraph, arguments can be AsyncCollectiveTensor.
+            # E.g. https://github.com/pytorch/pytorch/blob/0ab075a69e4577a60c4dcbff7bcc2ecd0a15ce46/torch/distributed/tensor/parallel/style.py#L142-L145
+            args = tuple(
+                arg.wait() if isinstance(arg, torch.distributed._functional_collectives.AsyncCollectiveTensor) else arg
+                for arg in args
+            )
+
         if dist.is_available() and any(isinstance(t, torch.distributed.tensor.DTensor) for t in args):
             with annotate_for_profile(self.name):
                 output = execute_with_dtensors(fd, args)
@@ -838,7 +846,22 @@ def _can_fuse_node(n: Node):
                 cuda_in_or_out: bool = self.has_cuda_input_or_output(bsym)
                 return can_fuse and cuda_in_or_out
 
-            return _can_fuse_node(a) and _can_fuse_node(b)
+            def bsym_input_types_match(a, b):
+                # NOTE: Don't allow creating a Fusion with a mix of DTensor and Tensor inputs.
+ bsym_a: BoundSymbol = a.group_bsyms[0] + bsym_b: BoundSymbol = b.group_bsyms[0] + bsym_a_args_type = {type(arg) for arg in bsym_a.flat_proxy_args} + bsym_b_args_type = {type(arg) for arg in bsym_b.flat_proxy_args} + + if DTensorProxy in bsym_a_args_type: + assert bsym_a_args_type == {DTensorProxy} + return bsym_b_args_type == bsym_a_args_type + if DTensorProxy in bsym_b_args_type: + assert bsym_b_args_type == {DTensorProxy} + return bsym_a_args_type == bsym_b_args_type + return True + + return _can_fuse_node(a) and _can_fuse_node(b) and bsym_input_types_match(a, b) bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse) diff --git a/thunder/tests/distributed/test_moe.py b/thunder/tests/distributed/test_moe.py index 3cbb233432..fd54d15820 100644 --- a/thunder/tests/distributed/test_moe.py +++ b/thunder/tests/distributed/test_moe.py @@ -19,6 +19,7 @@ import thunder.tests.llama4_moe as llama4_moe from thunder.tests.distributed.helper import DistributedParallelTestCase +from thunder.dynamo import thunderfx # Referred from torchtitan: https://github.com/pytorch/torchtitan/blob/827255bb/torchtitan/experiments/llama4/infra/expert_parallel.py#L25 @@ -170,3 +171,7 @@ def test_llama4_moe_distributed(self): expected = model(inp) torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5) + + tmodel = thunderfx(parallelized_model, nv_enable_linear=True) + actual = tmodel(inp) + torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5) diff --git a/thunder/torch/experimental/dtensor_torch_and_prims.py b/thunder/torch/experimental/dtensor_torch_and_prims.py index 4297c0f3e1..3d848d18e9 100644 --- a/thunder/torch/experimental/dtensor_torch_and_prims.py +++ b/thunder/torch/experimental/dtensor_torch_and_prims.py @@ -127,40 +127,6 @@ def handle_check_dtensor_spec_in_prologue(prim, prologue_trace, args) -> bool: return False -def dtensor_mul_meta(a, b): - output = run_with_fake_tensor(torch.mul, a, b) - local_tensor_proxy = TensorProxy(like=a.local_tensor) - spec = output._spec - spec_proxy = AnyProxy(spec, history=a.history) - return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False) - - -dtensor_mul_prim = make_prim(DTensorPrimIDs.MUL, "dtensor_mul_prim", meta=dtensor_mul_meta) - -dtensor_mul_prim_impl = pytorchex.register_operator("dtensor_mul_prim", like=dtensor_mul_prim, fn=torch.mul) - -pytorchex.register_implementation(dtensor_mul_prim, dtensor_mul_prim_impl) - - -def _dtensor_mul_prim_grad(a: TensorLike, b: TensorLike) -> TensorLike: - fwd = dtensor_mul_prim(a, b) - - g = get_grad(fwd) - a_grad = dtensor_mul_prim(b, g) - b_grad = dtensor_mul_prim(a, g) - put_grads((a, b), (a_grad, b_grad)) - - return fwd - - -register_grad(dtensor_mul_prim, _dtensor_mul_prim_grad) - - -@dtensor_torchsymbol(torch.mul, id="dtensor.torch.mul") -def dtensor_mul(a: TensorLike, b: TensorLike) -> TensorLike: - return dtensor_mul_prim(a, b) - - def dtensor_reshape_meta(a, shape): output = run_with_fake_tensor(torch.reshape, a, shape) local_tensor_proxy = TensorProxy( @@ -507,6 +473,45 @@ def dtensor_add(a: TensorLike, b: TensorLike) -> TensorLike: ) +def dtensor_mul_meta(a, b): + output = run_with_fake_tensor(torch.mul, a, b) + local_tensor_proxy = TensorProxy(like=a.local_tensor) + spec = output._spec + spec_proxy = AnyProxy(spec, history=a.history) + return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False) + + +dtensor_mul_prim = make_prim(DTensorPrimIDs.MUL, "dtensor_mul_prim", meta=dtensor_mul_meta) + +dtensor_mul_prim_impl = 
pytorchex.register_operator("dtensor_mul_prim", like=dtensor_mul_prim, fn=torch.mul) + +pytorchex.register_implementation(dtensor_mul_prim, dtensor_mul_prim_impl) + + +def _dtensor_mul_prim_grad(a: TensorLike, b: TensorLike) -> TensorLike: + fwd = dtensor_mul_prim(a, b) + + g = get_grad(fwd) + a_grad = dtensor_mul_prim(b, g) + b_grad = dtensor_mul_prim(a, g) + put_grads((a, b), (a_grad, b_grad)) + + return fwd + + +register_grad(dtensor_mul_prim, _dtensor_mul_prim_grad) + + +@dtensor_torchsymbol(torch.mul, id="dtensor.torch.mul") +def dtensor_mul(a: TensorLike, b: TensorLike) -> TensorLike: + return _elementwise_binary_wrapper( + a, + b, + prim=dtensor_mul_prim, + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + ) + + if LooseVersion(torch.__version__) >= "2.8": def dtensor_grouped_mm_meta(a, b, offsets): From 6bed289666f0741f817e9197abeba2dda834e0a8 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 7 Oct 2025 05:15:31 -0700 Subject: [PATCH 5/5] add DTensor silu --- thunder/tests/distributed/test_dtensor.py | 9 +++++++++ thunder/tests/distributed/test_moe.py | 8 ++++++-- .../experimental/dtensor_torch_and_prims.py | 16 ++++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/thunder/tests/distributed/test_dtensor.py b/thunder/tests/distributed/test_dtensor.py index e50133a184..d777133e27 100644 --- a/thunder/tests/distributed/test_dtensor.py +++ b/thunder/tests/distributed/test_dtensor.py @@ -125,6 +125,15 @@ def __init__( # Ref:https://github.com/NVIDIA/Fuser/issues/5314 skip_for_executor=("nvfuser",), ), + DTensorOpInfo( + name="silu", + op=torch.nn.functional.silu, + torch_reference=torch.nn.functional.silu, + supports_grad=False, + sample_inputs=get_opinfo("silu").sample_inputs, + # Ref:https://github.com/NVIDIA/Fuser/pull/5124 + skip_noncontiguous_for_executor=("nvfuser",), + ), ) skip_opinfos = ( diff --git a/thunder/tests/distributed/test_moe.py b/thunder/tests/distributed/test_moe.py index fd54d15820..6a98ff1e8a 100644 --- a/thunder/tests/distributed/test_moe.py +++ b/thunder/tests/distributed/test_moe.py @@ -172,6 +172,10 @@ def test_llama4_moe_distributed(self): torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5) - tmodel = thunderfx(parallelized_model, nv_enable_linear=True) + tmodel = thunderfx(model, nv_enable_linear=True, nv_enable_scatter=True) actual = tmodel(inp) - torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5) + + assert len(tmodel._backend.subgraph_infos) == 1 + assert len(tmodel._backend.subgraph_infos[0].split_reasons) == 0 + + torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2) diff --git a/thunder/torch/experimental/dtensor_torch_and_prims.py b/thunder/torch/experimental/dtensor_torch_and_prims.py index 3d848d18e9..e31d2562b7 100644 --- a/thunder/torch/experimental/dtensor_torch_and_prims.py +++ b/thunder/torch/experimental/dtensor_torch_and_prims.py @@ -540,6 +540,21 @@ def dtensor_grouped_mm(a: TensorLike, b: TensorLike, offsets: TensorLike, *, bia return dtensor_grouped_mm_prim(a, b, offsets) +@dtensor_torchsymbol(torch.nn.functional.silu, id="dtensor.torch.nn.functional.silu") +def dtensor_silu(a: TensorLike, inplace: bool = False) -> TensorLike: + assert not inplace, "inplace is not supported" + + def sigmoid(x): + computation_dtype, result_dtype = utils.elementwise_type_promotion( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + x = dtensor_convert_element_type_prim(x, computation_dtype) + result = 
dtensor_reciprocal(dtensor_add(dtensor_exp(-x), 1.0)) + return dtensor_convert_element_type_prim(result, result_dtype) + + return a * sigmoid(a) + + def register_dtensor_torch_and_prims(): register_function_for_dtensor(torch.add, ltorch.add, dtensor_add, is_method=True) register_function_for_dtensor(torch.mul, ltorch.mul, dtensor_mul, is_method=True) @@ -548,5 +563,6 @@ def register_dtensor_torch_and_prims(): register_function_for_dtensor(torch.exp, ltorch.exp, dtensor_exp, is_method=True) register_function_for_dtensor(torch.neg, ltorch.neg, dtensor_neg, is_method=True) register_function_for_dtensor(torch.reciprocal, ltorch.reciprocal, dtensor_reciprocal, is_method=True) + register_function_for_dtensor(torch.nn.functional.silu, ltorch.silu, dtensor_silu, is_method=False) if LooseVersion(torch.__version__) >= "2.8": register_function_for_dtensor(torch._grouped_mm, ltorch._grouped_mm, dtensor_grouped_mm, is_method=False)