diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index f51e303380..a5dadbbaba 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -536,6 +536,14 @@ def __call__(self, *args):
         if self.store_inputs:
             self.last_inputs = args
 
+        if dist.is_available():
+            # When using DTensor with FXGraph, arguments can be AsyncCollectiveTensor.
+            # E.g. https://github.com/pytorch/pytorch/blob/0ab075a69e4577a60c4dcbff7bcc2ecd0a15ce46/torch/distributed/tensor/parallel/style.py#L142-L145
+            args = tuple(
+                arg.wait() if isinstance(arg, torch.distributed._functional_collectives.AsyncCollectiveTensor) else arg
+                for arg in args
+            )
+
         if dist.is_available() and any(isinstance(t, torch.distributed.tensor.DTensor) for t in args):
             with annotate_for_profile(self.name):
                 output = execute_with_dtensors(fd, args)
@@ -838,7 +846,22 @@ def _can_fuse_node(n: Node):
             cuda_in_or_out: bool = self.has_cuda_input_or_output(bsym)
             return can_fuse and cuda_in_or_out
 
-        return _can_fuse_node(a) and _can_fuse_node(b)
+        def bsym_input_types_match(a, b):
+            # NOTE: Don't allow creating a Fusion with a mix of DTensor and Tensor inputs.
+            bsym_a: BoundSymbol = a.group_bsyms[0]
+            bsym_b: BoundSymbol = b.group_bsyms[0]
+            bsym_a_args_type = {type(arg) for arg in bsym_a.flat_proxy_args}
+            bsym_b_args_type = {type(arg) for arg in bsym_b.flat_proxy_args}
+
+            if DTensorProxy in bsym_a_args_type:
+                assert bsym_a_args_type == {DTensorProxy}
+                return bsym_b_args_type == bsym_a_args_type
+            if DTensorProxy in bsym_b_args_type:
+                assert bsym_b_args_type == {DTensorProxy}
+                return bsym_a_args_type == bsym_b_args_type
+            return True
+
+        return _can_fuse_node(a) and _can_fuse_node(b) and bsym_input_types_match(a, b)
 
     bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse)
diff --git a/thunder/tests/distributed/test_moe.py b/thunder/tests/distributed/test_moe.py
index 3cbb233432..fd54d15820 100644
--- a/thunder/tests/distributed/test_moe.py
+++ b/thunder/tests/distributed/test_moe.py
@@ -19,6 +19,7 @@
 import thunder.tests.llama4_moe as llama4_moe
 from thunder.tests.distributed.helper import DistributedParallelTestCase
+from thunder.dynamo import thunderfx
 
 
 # Referred from torchtitan: https://github.com/pytorch/torchtitan/blob/827255bb/torchtitan/experiments/llama4/infra/expert_parallel.py#L25
@@ -170,3 +171,7 @@ def test_llama4_moe_distributed(self):
         expected = model(inp)
 
         torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
+
+        tmodel = thunderfx(parallelized_model, nv_enable_linear=True)
+        actual = tmodel(inp)
+        torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
diff --git a/thunder/torch/experimental/dtensor_torch_and_prims.py b/thunder/torch/experimental/dtensor_torch_and_prims.py
index 4297c0f3e1..3d848d18e9 100644
--- a/thunder/torch/experimental/dtensor_torch_and_prims.py
+++ b/thunder/torch/experimental/dtensor_torch_and_prims.py
@@ -127,40 +127,6 @@ def handle_check_dtensor_spec_in_prologue(prim, prologue_trace, args) -> bool:
     return False
 
 
-def dtensor_mul_meta(a, b):
-    output = run_with_fake_tensor(torch.mul, a, b)
-    local_tensor_proxy = TensorProxy(like=a.local_tensor)
-    spec = output._spec
-    spec_proxy = AnyProxy(spec, history=a.history)
-    return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)
-
-
-dtensor_mul_prim = make_prim(DTensorPrimIDs.MUL, "dtensor_mul_prim", meta=dtensor_mul_meta)
-
-dtensor_mul_prim_impl = pytorchex.register_operator("dtensor_mul_prim", like=dtensor_mul_prim, fn=torch.mul)
-
-pytorchex.register_implementation(dtensor_mul_prim, dtensor_mul_prim_impl)
-
-
-def _dtensor_mul_prim_grad(a: TensorLike, b: TensorLike) -> TensorLike:
-    fwd = dtensor_mul_prim(a, b)
-
-    g = get_grad(fwd)
-    a_grad = dtensor_mul_prim(b, g)
-    b_grad = dtensor_mul_prim(a, g)
-    put_grads((a, b), (a_grad, b_grad))
-
-    return fwd
-
-
-register_grad(dtensor_mul_prim, _dtensor_mul_prim_grad)
-
-
-@dtensor_torchsymbol(torch.mul, id="dtensor.torch.mul")
-def dtensor_mul(a: TensorLike, b: TensorLike) -> TensorLike:
-    return dtensor_mul_prim(a, b)
-
-
 def dtensor_reshape_meta(a, shape):
     output = run_with_fake_tensor(torch.reshape, a, shape)
     local_tensor_proxy = TensorProxy(
@@ -507,6 +473,45 @@ def dtensor_add(a: TensorLike, b: TensorLike) -> TensorLike:
     )
 
 
+def dtensor_mul_meta(a, b):
+    output = run_with_fake_tensor(torch.mul, a, b)
+    local_tensor_proxy = TensorProxy(like=a.local_tensor)
+    spec = output._spec
+    spec_proxy = AnyProxy(spec, history=a.history)
+    return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)
+
+
+dtensor_mul_prim = make_prim(DTensorPrimIDs.MUL, "dtensor_mul_prim", meta=dtensor_mul_meta)
+
+dtensor_mul_prim_impl = pytorchex.register_operator("dtensor_mul_prim", like=dtensor_mul_prim, fn=torch.mul)
+
+pytorchex.register_implementation(dtensor_mul_prim, dtensor_mul_prim_impl)
+
+
+def _dtensor_mul_prim_grad(a: TensorLike, b: TensorLike) -> TensorLike:
+    fwd = dtensor_mul_prim(a, b)
+
+    g = get_grad(fwd)
+    a_grad = dtensor_mul_prim(b, g)
+    b_grad = dtensor_mul_prim(a, g)
+    put_grads((a, b), (a_grad, b_grad))
+
+    return fwd
+
+
+register_grad(dtensor_mul_prim, _dtensor_mul_prim_grad)
+
+
+@dtensor_torchsymbol(torch.mul, id="dtensor.torch.mul")
+def dtensor_mul(a: TensorLike, b: TensorLike) -> TensorLike:
+    return _elementwise_binary_wrapper(
+        a,
+        b,
+        prim=dtensor_mul_prim,
+        type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    )
+
+
 if LooseVersion(torch.__version__) >= "2.8":
 
     def dtensor_grouped_mm_meta(a, b, offsets):
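
A minimal, self-contained sketch of the unwrap step added to FusionDefinitionWrapper.__call__: an AsyncCollectiveTensor carries a pending collective, and .wait() blocks until the result is materialized, so nvFuser only ever sees plain tensors. The real class requires an initialized process group, so AsyncValue and unwrap_args below are hypothetical stand-ins with the same .wait() shape, not names from this patch:

    # AsyncValue is a toy stand-in for
    # torch.distributed._functional_collectives.AsyncCollectiveTensor.
    class AsyncValue:
        def __init__(self, value):
            self._value = value

        def wait(self):
            # The real .wait() blocks until the pending collective finishes
            # and returns the underlying torch.Tensor.
            return self._value

    def unwrap_args(args):
        # Mirrors the generator expression added to __call__: wait on async
        # collective results, pass everything else through unchanged.
        return tuple(a.wait() if isinstance(a, AsyncValue) else a for a in args)

    assert unwrap_args((1, AsyncValue(2), 3)) == (1, 2, 3)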
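
Likewise, a sketch of the guard added to _should_fuse: a candidate fusion region whose first bound symbol takes any DTensorProxy input must take only DTensorProxy inputs, and two regions may merge only when they agree on that. input_types_match is an illustrative stand-in for the patch's bsym_input_types_match, with empty classes in place of Thunder's real proxy types; the sets model {type(arg) for arg in bsym.flat_proxy_args}:

    # Empty stand-ins for thunder's TensorProxy / DTensorProxy.
    class TensorProxy: ...
    class DTensorProxy(TensorProxy): ...

    def input_types_match(args_a: set[type], args_b: set[type]) -> bool:
        # Same logic as bsym_input_types_match in the patch, applied to
        # precomputed type sets.
        if DTensorProxy in args_a:
            assert args_a == {DTensorProxy}
            return args_b == args_a
        if DTensorProxy in args_b:
            assert args_b == {DTensorProxy}
            return args_a == args_b
        return True

    assert input_types_match({TensorProxy}, {TensorProxy})        # plain regions fuse
    assert input_types_match({DTensorProxy}, {DTensorProxy})      # DTensor regions fuse
    assert not input_types_match({DTensorProxy}, {TensorProxy})   # mixed: never fused together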