25 changes: 24 additions & 1 deletion thunder/executors/nvfuserex_impl.py
@@ -536,6 +536,14 @@ def __call__(self, *args):
if self.store_inputs:
self.last_inputs = args

if dist.is_available():
# When using DTensor with an FX graph, arguments can be AsyncCollectiveTensor instances.
# E.g. https://github.com/pytorch/pytorch/blob/0ab075a69e4577a60c4dcbff7bcc2ecd0a15ce46/torch/distributed/tensor/parallel/style.py#L142-L145
args = tuple(
arg.wait() if isinstance(arg, torch.distributed._functional_collectives.AsyncCollectiveTensor) else arg
for arg in args
)

if dist.is_available() and any(isinstance(t, torch.distributed.tensor.DTensor) for t in args):
with annotate_for_profile(self.name):
output = execute_with_dtensors(fd, args)
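
For context, the normalization above amounts to the helper sketched below. The helper name `materialize_async_collectives` is illustrative and not part of the change; the sketch only assumes a PyTorch build with `torch.distributed` available.

```python
# A minimal sketch of the normalization above: functional collectives
# (torch.distributed._functional_collectives) can return AsyncCollectiveTensor
# wrappers, and .wait() blocks until the collective finishes and returns the
# underlying torch.Tensor. Plain tensors (and non-tensor arguments) pass
# through untouched, so the normalization is safe to apply unconditionally
# before handing arguments to the nvFuser FusionDefinition.
import torch
import torch.distributed._functional_collectives as funcol


def materialize_async_collectives(args):
    """Replace any AsyncCollectiveTensor in args with its synchronized value."""
    return tuple(
        a.wait() if isinstance(a, funcol.AsyncCollectiveTensor) else a
        for a in args
    )


# No-op for ordinary inputs:
args = materialize_async_collectives((torch.ones(4), 2.5))
```
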
@@ -838,7 +846,22 @@ def _can_fuse_node(n: Node):
cuda_in_or_out: bool = self.has_cuda_input_or_output(bsym)
return can_fuse and cuda_in_or_out

return _can_fuse_node(a) and _can_fuse_node(b)
def bsym_input_types_match(a, b):
# NOTE: Don't allow creating a Fusion with a mix of DTensor and Tensor inputs.
bsym_a: BoundSymbol = a.group_bsyms[0]
bsym_b: BoundSymbol = b.group_bsyms[0]
bsym_a_args_type = {type(arg) for arg in bsym_a.flat_proxy_args}
bsym_b_args_type = {type(arg) for arg in bsym_b.flat_proxy_args}

if DTensorProxy in bsym_a_args_type:
assert bsym_a_args_type == {DTensorProxy}
return bsym_b_args_type == bsym_a_args_type
if DTensorProxy in bsym_b_args_type:
assert bsym_b_args_type == {DTensorProxy}
return bsym_a_args_type == bsym_b_args_type
return True

return _can_fuse_node(a) and _can_fuse_node(b) and bsym_input_types_match(a, b)

bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse)
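
For readers skimming the diff, here is a thunder-independent sketch of the new guard, with type names standing in for the proxy classes; it only illustrates the predicate, not the BoundSymbol machinery.

```python
# A sketch of the predicate implemented by bsym_input_types_match: a region
# whose inputs are all DTensorProxy may only fuse with another all-DTensorProxy
# region, and mixing DTensor and Tensor inputs within one region is not
# expected. Strings stand in for the actual proxy classes.
def input_types_match(types_a: set, types_b: set) -> bool:
    if "DTensorProxy" in types_a:
        assert types_a == {"DTensorProxy"}, "mixed DTensor/Tensor inputs are not expected"
        return types_b == types_a
    if "DTensorProxy" in types_b:
        assert types_b == {"DTensorProxy"}, "mixed DTensor/Tensor inputs are not expected"
        return types_a == types_b
    return True


assert input_types_match({"DTensorProxy"}, {"DTensorProxy"})
assert not input_types_match({"DTensorProxy"}, {"TensorProxy"})
assert input_types_match({"TensorProxy"}, {"TensorProxy"})
```
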

9 changes: 9 additions & 0 deletions thunder/tests/distributed/test_dtensor.py
@@ -125,6 +125,15 @@ def __init__(
# Ref:https://github.com/NVIDIA/Fuser/issues/5314
skip_for_executor=("nvfuser",),
),
DTensorOpInfo(
name="silu",
op=torch.nn.functional.silu,
torch_reference=torch.nn.functional.silu,
supports_grad=False,
sample_inputs=get_opinfo("silu").sample_inputs,
# Ref:https://github.com/NVIDIA/Fuser/pull/5124
skip_noncontiguous_for_executor=("nvfuser",),
),
)

skip_opinfos = (
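
Outside of the DTensorOpInfo machinery, the new entry roughly corresponds to the hedged sketch below (assumptions noted in the comments).

```python
# A condensed sketch of what the new OpInfo entry exercises, outside of the
# OpInfo machinery. It assumes a 2-GPU job launched with torchrun, the public
# torch.distributed.tensor namespace, and the DTensor silu support added in
# this PR; the test suite itself drives this through DTensorOpInfo and
# get_opinfo("silu").sample_inputs instead.
import torch
import torch.nn.functional as F
import thunder
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2,))
x = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])

jitted_silu = thunder.jit(F.silu)
torch.testing.assert_close(jitted_silu(x), F.silu(x))
```
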
9 changes: 9 additions & 0 deletions thunder/tests/distributed/test_moe.py
@@ -19,6 +19,7 @@

import thunder.tests.llama4_moe as llama4_moe
from thunder.tests.distributed.helper import DistributedParallelTestCase
from thunder.dynamo import thunderfx


# Referred from torchtitan: https://github.com/pytorch/torchtitan/blob/827255bb/torchtitan/experiments/llama4/infra/expert_parallel.py#L25
@@ -170,3 +171,11 @@ def test_llama4_moe_distributed(self):
expected = model(inp)

torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)

tmodel = thunderfx(model, nv_enable_linear=True, nv_enable_scatter=True)
actual = tmodel(inp)

assert len(tmodel._backend.subgraph_infos) == 1
assert len(tmodel._backend.subgraph_infos[0].split_reasons) == 0

torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
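
If these assertions ever fail, the split reasons recorded by the thunderfx backend explain why the FX graph was partitioned. A hedged debugging sketch:

```python
# Assumes `tmodel` was produced by thunderfx as in the test above; only the
# attributes already used by the test (subgraph_infos, split_reasons) are read.
for i, info in enumerate(tmodel._backend.subgraph_infos):
    print(f"subgraph {i}: {len(info.split_reasons)} split reason(s)")
    for reason in info.split_reasons:
        print("  ", reason)
```
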
89 changes: 55 additions & 34 deletions thunder/torch/experimental/dtensor_torch_and_prims.py
@@ -127,40 +127,6 @@ def handle_check_dtensor_spec_in_prologue(prim, prologue_trace, args) -> bool:
return False


def dtensor_mul_meta(a, b):
output = run_with_fake_tensor(torch.mul, a, b)
local_tensor_proxy = TensorProxy(like=a.local_tensor)
spec = output._spec
spec_proxy = AnyProxy(spec, history=a.history)
return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)


dtensor_mul_prim = make_prim(DTensorPrimIDs.MUL, "dtensor_mul_prim", meta=dtensor_mul_meta)

dtensor_mul_prim_impl = pytorchex.register_operator("dtensor_mul_prim", like=dtensor_mul_prim, fn=torch.mul)

pytorchex.register_implementation(dtensor_mul_prim, dtensor_mul_prim_impl)


def _dtensor_mul_prim_grad(a: TensorLike, b: TensorLike) -> TensorLike:
fwd = dtensor_mul_prim(a, b)

g = get_grad(fwd)
a_grad = dtensor_mul_prim(b, g)
b_grad = dtensor_mul_prim(a, g)
put_grads((a, b), (a_grad, b_grad))

return fwd


register_grad(dtensor_mul_prim, _dtensor_mul_prim_grad)


@dtensor_torchsymbol(torch.mul, id="dtensor.torch.mul")
def dtensor_mul(a: TensorLike, b: TensorLike) -> TensorLike:
return dtensor_mul_prim(a, b)


def dtensor_reshape_meta(a, shape):
output = run_with_fake_tensor(torch.reshape, a, shape)
local_tensor_proxy = TensorProxy(
@@ -507,6 +473,45 @@ def dtensor_add(a: TensorLike, b: TensorLike) -> TensorLike:
)


def dtensor_mul_meta(a, b):
output = run_with_fake_tensor(torch.mul, a, b)
local_tensor_proxy = TensorProxy(like=a.local_tensor)
spec = output._spec
spec_proxy = AnyProxy(spec, history=a.history)
return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)


dtensor_mul_prim = make_prim(DTensorPrimIDs.MUL, "dtensor_mul_prim", meta=dtensor_mul_meta)

dtensor_mul_prim_impl = pytorchex.register_operator("dtensor_mul_prim", like=dtensor_mul_prim, fn=torch.mul)

pytorchex.register_implementation(dtensor_mul_prim, dtensor_mul_prim_impl)


def _dtensor_mul_prim_grad(a: TensorLike, b: TensorLike) -> TensorLike:
fwd = dtensor_mul_prim(a, b)

g = get_grad(fwd)
a_grad = dtensor_mul_prim(b, g)
b_grad = dtensor_mul_prim(a, g)
put_grads((a, b), (a_grad, b_grad))

return fwd


register_grad(dtensor_mul_prim, _dtensor_mul_prim_grad)


@dtensor_torchsymbol(torch.mul, id="dtensor.torch.mul")
def dtensor_mul(a: TensorLike, b: TensorLike) -> TensorLike:
return _elementwise_binary_wrapper(
a,
b,
prim=dtensor_mul_prim,
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
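
The relocated dtensor_mul now routes through _elementwise_binary_wrapper with DEFAULT type promotion instead of calling the prim directly, matching dtensor_add above. The snippet below is a simplified, thunder-independent illustration of what that promotion step provides; it mirrors torch's promotion rules and is not thunder's actual wrapper code.

```python
# Both operands are promoted to a common computation dtype before the
# multiplication, so e.g. float16 * float32 yields float32 rather than
# silently downcasting.
import torch

a = torch.randn(4, dtype=torch.float16)
b = torch.randn(4, dtype=torch.float32)

common = torch.promote_types(a.dtype, b.dtype)  # torch.float32
out = a.to(common) * b.to(common)
assert out.dtype == torch.float32
```
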


if LooseVersion(torch.__version__) >= "2.8":

def dtensor_grouped_mm_meta(a, b, offsets):
@@ -535,6 +540,21 @@ def dtensor_grouped_mm(a: TensorLike, b: TensorLike, offsets: TensorLike, *, bias
return dtensor_grouped_mm_prim(a, b, offsets)


@dtensor_torchsymbol(torch.nn.functional.silu, id="dtensor.torch.nn.functional.silu")
def dtensor_silu(a: TensorLike, inplace: bool = False) -> TensorLike:
assert not inplace, "inplace is not supported"

def sigmoid(x):
computation_dtype, result_dtype = utils.elementwise_type_promotion(
x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)
x = dtensor_convert_element_type_prim(x, computation_dtype)
result = dtensor_reciprocal(dtensor_add(dtensor_exp(-x), 1.0))
return dtensor_convert_element_type_prim(result, result_dtype)

return a * sigmoid(a)
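
As a quick check that the decomposition above matches torch's silu:

```python
# A plain-torch sanity check of the decomposition used by dtensor_silu:
# silu(x) == x * sigmoid(x), with sigmoid expressed through the same
# reciprocal/exp/add primitives that the DTensor path composes. This runs on
# regular tensors and only illustrates the math, not the DTensor dispatch.
import torch
import torch.nn.functional as F

x = torch.randn(8, dtype=torch.float32)
sigmoid = torch.reciprocal(torch.exp(-x) + 1.0)
torch.testing.assert_close(x * sigmoid, F.silu(x))
```
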


def register_dtensor_torch_and_prims():
register_function_for_dtensor(torch.add, ltorch.add, dtensor_add, is_method=True)
register_function_for_dtensor(torch.mul, ltorch.mul, dtensor_mul, is_method=True)
@@ -543,5 +563,6 @@ def register_dtensor_torch_and_prims():
register_function_for_dtensor(torch.exp, ltorch.exp, dtensor_exp, is_method=True)
register_function_for_dtensor(torch.neg, ltorch.neg, dtensor_neg, is_method=True)
register_function_for_dtensor(torch.reciprocal, ltorch.reciprocal, dtensor_reciprocal, is_method=True)
register_function_for_dtensor(torch.nn.functional.silu, ltorch.silu, dtensor_silu, is_method=False)
if LooseVersion(torch.__version__) >= "2.8":
register_function_for_dtensor(torch._grouped_mm, ltorch._grouped_mm, dtensor_grouped_mm, is_method=False)