
Commit c8ef0c4

[DTensor] Add torch symbol and prim for _grouped_mm (#2503)
1 parent ad8ed75 commit c8ef0c4
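
This change teaches Thunder's DTensor path about torch._grouped_mm: a torch symbol and a prim are added so the op can be traced, and claimed by the nvFuser or PyTorch executor, when its inputs are DTensors. A minimal usage sketch adapted from the new test below; the mesh and placements are illustrative, and it assumes torch >= 2.8 with an already-initialized multi-process group:

    import torch
    import thunder
    from torch.distributed._tensor import DeviceMesh, distribute_tensor
    from torch.distributed.tensor.placement_types import Shard, Replicate

    # Assumes torch.distributed is already initialized (e.g. two ranks on one node).
    mesh = DeviceMesh("cuda", list(range(torch.distributed.get_world_size())))

    M, N, K, G = 16, 64, 32, 2
    a = distribute_tensor(torch.randn(M, K, dtype=torch.bfloat16), mesh, [Shard(-1)])
    b = distribute_tensor(torch.randn(G, K, N, dtype=torch.bfloat16), mesh, [Shard(1)])
    offsets = distribute_tensor(torch.tensor([0, 16], dtype=torch.int32), mesh, [Replicate()])

    jitted = thunder.jit(torch._grouped_mm)
    out = jitted(a, b, offsets)  # traced through the new dtensor_grouped_mm symbol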

File tree

4 files changed: +91 −1 lines changed

thunder/executors/nvfuserex_impl.py
thunder/tests/distributed/helper.py
thunder/tests/distributed/test_dtensor.py
thunder/torch/experimental/dtensor_torch_and_prims.py

thunder/executors/nvfuserex_impl.py

Lines changed: 1 addition & 0 deletions

@@ -3191,6 +3191,7 @@ def _grouped_mm_transform(
 
 
 register_supported(prims._grouped_mm, _grouped_mm_transform, _grouped_mm_check)
+register_supported(DTensorPrimIDs._GROUPED_MM, _grouped_mm_transform, _grouped_mm_check)
 
 
 def _cumsum_check(a: TensorProxy, dim: int, /, dtype: dtypes.dtype | None = None) -> bool:

thunder/tests/distributed/helper.py

Lines changed: 6 additions & 0 deletions

@@ -187,7 +187,13 @@ def _run(cls, rank, test_name, file_name, pipe, *, fake_pg=False):
 
         local_rank = self.rank % torch.cuda.device_count()
         torch.cuda.set_device(local_rank)
+
+        # nvFuser Multi-GPU expects these environment variables to be set
         os.environ["LOCAL_RANK"] = str(local_rank)
+        # We only have single node tests, so `LOCAL_WORLD_SIZE` is the same as `WORLD_SIZE`
+        os.environ["LOCAL_WORLD_SIZE"] = str(self.world_size)
+        os.environ["RANK"] = str(self.rank)
+        os.environ["WORLD_SIZE"] = str(self.world_size)
 
         torch.distributed.barrier()
         try:
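
For reference, these four variables are the same ones a torchrun launch would export for each worker; the helper sets them by hand because these tests spawn their worker processes directly. A small sanity-check sketch of the single-node invariant noted in the comment above (illustrative only):

    import os

    # Single-node assumption: local values coincide with the global ones.
    # LOCAL_RANK == RANK holds as long as world_size <= torch.cuda.device_count().
    assert os.environ["LOCAL_WORLD_SIZE"] == os.environ["WORLD_SIZE"]
    assert os.environ["LOCAL_RANK"] == os.environ["RANK"]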

thunder/tests/distributed/test_dtensor.py

Lines changed: 52 additions & 1 deletion

@@ -1,6 +1,7 @@
 import unittest
 from itertools import product
 from collections.abc import Sequence
+from looseversion import LooseVersion
 
 import pytest
 import torch
@@ -12,7 +13,7 @@
 
 from thunder.tests.distributed.helper import DistributedParallelTestCase
 from torch.distributed._tensor import DeviceMesh, distribute_tensor
-from torch.distributed.tensor.placement_types import Shard
+from torch.distributed.tensor.placement_types import Shard, Replicate
 from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter
 
 from torch.testing._internal import common_utils
@@ -249,6 +250,56 @@ def fn(x):
 
         torch.testing.assert_close(actual, expected)
 
+    @common_utils.parametrize("executor", tuple(executors_map.keys()))
+    @common_utils.parametrize(
+        "input_shardings",
+        [
+            (
+                [
+                    Shard(
+                        -1,
+                    )
+                ],
+                [
+                    Shard(1),
+                ],
+                [Replicate()],
+            ),
+        ],
+    )
+    def test_dtensor_grouped_mm(self, executor, input_shardings):
+        if LooseVersion(torch.__version__) < "2.8":
+            raise unittest.SkipTest("test_dtensor_grouped_mm: torch._grouped_mm is not available in torch < 2.8")
+
+        num_devices = self.world_size
+        mesh = DeviceMesh("cuda", list(range(num_devices)))
+
+        if (torch.cuda.get_device_capability() < (9, 0)) and executor == "torch":
+            raise unittest.SkipTest(
+                "test_dtensor_grouped_mm: torch._grouped_mm doesn't support device capability < 9.0"
+            )
+
+        M = 16
+        N = 64
+        K = 32
+        G = 2
+
+        inp_shard, w_shard, offsets_shard = input_shardings
+        in_dtensor = distribute_tensor(torch.randn(M, K, requires_grad=False, dtype=torch.bfloat16), mesh, inp_shard)
+        w_dtensor = distribute_tensor(torch.randn(G, K, N, requires_grad=False, dtype=torch.bfloat16), mesh, w_shard)
+        offsets_dtensor = distribute_tensor(torch.tensor([0, 16], dtype=torch.int32), mesh, offsets_shard)
+
+        tfn = thunder.jit(torch._grouped_mm, executors=executors_map[executor].executors_list())
+
+        tfn(in_dtensor, w_dtensor, offsets_dtensor)
+
+        trcs = thunder.last_traces(tfn)
+        init_trc = trcs[0]
+
+        from thunder.torch.experimental.dtensor_torch_and_prims import dtensor_grouped_mm
+
+        assert any(bsym.sym == dtensor_grouped_mm for bsym in init_trc.bound_symbols)
+
     @common_utils.parametrize(
         "op, executor",
         product(dtensor_supported_opinfos, tuple(executors_map.keys())),
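
For readers new to the op: given a of shape (M, K), b of shape (G, K, N), and int32 offsets of shape (G,), torch._grouped_mm multiplies each contiguous row-group of a by the corresponding matrix in b. A plain-PyTorch reference sketch of that behaviour, assuming offsets holds cumulative end-row indices per group (matching the [0, 16] used in the test above):

    import torch


    def grouped_mm_reference(a, b, offsets):
        # Assumed semantics: rows offsets[g-1]:offsets[g] of `a` are multiplied by b[g].
        out = a.new_empty(a.shape[0], b.shape[-1])
        start = 0
        for g in range(b.shape[0]):
            end = int(offsets[g])
            out[start:end] = a[start:end] @ b[g]
            start = end
        return out


    M, N, K, G = 16, 64, 32, 2
    a = torch.randn(M, K, dtype=torch.bfloat16)
    b = torch.randn(G, K, N, dtype=torch.bfloat16)
    offsets = torch.tensor([0, 16], dtype=torch.int32)
    ref = grouped_mm_reference(a, b, offsets)  # group 0 is empty; group 1 covers all 16 rows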

thunder/torch/experimental/dtensor_torch_and_prims.py

Lines changed: 32 additions & 0 deletions

@@ -1,6 +1,7 @@
 from functools import partial
 from collections.abc import Callable
 from enum import auto, Enum
+from looseversion import LooseVersion
 
 from thunder.torch import torchsymbol, TensorLike, register_function
 import thunder.torch as ltorch
@@ -36,6 +37,7 @@ class DTensorPrimIDs(Enum):
     RESHAPE = auto()
     CONVERT_ELEMENT_TYPE = auto()
     BROADCAST_IN_DIM = auto()
+    _GROUPED_MM = auto()
     EXP = auto()
     LINEAR = auto()
     NEG = auto()
@@ -363,10 +365,40 @@ def dtensor_reciprocal(a: TensorLike) -> TensorLike:
     )
 
 
+if LooseVersion(torch.__version__) >= "2.8":
+
+    def dtensor_grouped_mm_meta(a, b, offsets):
+        output = run_with_fake_tensor(torch._grouped_mm, a, b, offsets)
+        local_tensor_proxy = TensorProxy(
+            like=a.local_tensor, dtype=dtypes.to_dtype(output._local_tensor.dtype), shape=output._local_tensor.shape
+        )
+        spec = output._spec
+        spec_proxy = AnyProxy(spec, history=a.history)
+        return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)
+
+    dtensor_grouped_mm_prim = make_prim(
+        DTensorPrimIDs._GROUPED_MM, "dtensor_grouped_mm_prim", meta=dtensor_grouped_mm_meta
+    )
+
+    dtensor_grouped_mm_prim_impl = pytorchex.register_operator(
+        "dtensor_grouped_mm_prim", like=dtensor_grouped_mm_prim, fn=torch._grouped_mm
+    )
+
+    pytorchex.register_implementation(dtensor_grouped_mm_prim, dtensor_grouped_mm_prim_impl)
+
+    @dtensor_torchsymbol(torch._grouped_mm, id="dtensor.torch._grouped_mm")
+    def dtensor_grouped_mm(a: TensorLike, b: TensorLike, offsets: TensorLike, *, bias=None, dtype=None) -> TensorLike:
+        assert bias is None, "bias is not supported"
+        assert dtype is None, "dtype is not supported"
+        return dtensor_grouped_mm_prim(a, b, offsets)
+
+
 def register_dtensor_torch_and_prims():
     register_function_for_dtensor(torch.mul, ltorch.mul, dtensor_mul, is_method=True)
     register_function_for_dtensor(torch.reshape, ltorch.reshape, dtensor_reshape, is_method=True)
     register_function_for_dtensor(torch.nn.functional.linear, ltorch.linear, dtensor_linear, is_method=False)
     register_function_for_dtensor(torch.exp, ltorch.exp, dtensor_exp, is_method=True)
     register_function_for_dtensor(torch.neg, ltorch.neg, dtensor_neg, is_method=True)
     register_function_for_dtensor(torch.reciprocal, ltorch.reciprocal, dtensor_reciprocal, is_method=True)
+    if LooseVersion(torch.__version__) >= "2.8":
+        register_function_for_dtensor(torch._grouped_mm, ltorch._grouped_mm, dtensor_grouped_mm, is_method=False)
