
Commit 02cae54

DTensor: skip nvfuser test if NVFUSER_DISABLE=multidevice is set (#2724)
Co-authored-by: beverlylytle <[email protected]>
1 parent ad15cc2 commit 02cae54
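
Both test files gain the same guard: before running, they check whether the NVFUSER_DISABLE environment variable contains "multidevice" and skip via unittest.SkipTest if it does. A substring check is used rather than an equality test, presumably because NVFUSER_DISABLE can hold more than one option. As a self-contained sketch of the pattern (the helper name and example test below are ours, not part of the commit):

    import os
    import unittest


    def nvfuser_multidevice_disabled() -> bool:
        # The tests only ask whether "multidevice" appears anywhere in
        # NVFUSER_DISABLE; an unset variable defaults to the empty string.
        return "multidevice" in os.environ.get("NVFUSER_DISABLE", "")


    class ExampleTest(unittest.TestCase):
        def test_requiring_nvfuser_multidevice(self):
            if nvfuser_multidevice_disabled():
                raise unittest.SkipTest("nvfuser multidevice is disabled")
            # ... test body that exercises nvFuser multidevice support ...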

2 files changed: +9 -0 lines


thunder/tests/distributed/test_dtensor.py

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,7 @@
 from itertools import product
 from collections.abc import Sequence
 from looseversion import LooseVersion
+import os
 
 import pytest
 import torch
@@ -341,6 +342,9 @@ def test_dtensor_columnwise_parallel(self, jit_fn):
         ],
     )
     def test_dtensor_grouped_mm(self, executor, input_shardings):
+        if executor == "nvfuser" and "multidevice" in os.environ.get("NVFUSER_DISABLE", ""):
+            raise unittest.SkipTest("test_dtensor_grouped_mm: nvfuser multidevice is disabled")
+
         if LooseVersion(torch.__version__) < "2.8":
             raise unittest.SkipTest("test_dtensor_grouped_mm: torch._grouped_mm is not available in torch < 2.8")
346350

thunder/tests/distributed/test_moe.py

Lines changed: 5 additions & 0 deletions
@@ -1,5 +1,7 @@
 from functools import partial
 import copy
+import os
+import unittest
 
 import torch
 from torch.distributed.tensor.placement_types import Placement, Shard, Replicate
@@ -143,6 +145,9 @@ def parallelize_moe_model(model: llama4_moe.Llama4MoE, device_mesh: torch.distri
 
 class TestLlama4MoEDistributed(DistributedParallelTestCase):
     def test_llama4_moe_distributed(self):
+        if "multidevice" in os.environ.get("NVFUSER_DISABLE", ""):
+            raise unittest.SkipTest("test_llama4_moe_distributed: nvfuser multidevice is disabled")
+
         # Get world size
         world_size = self.world_size
         device = f"cuda:{self.rank}"
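
One design note: the `in` check matches substrings, so it would also fire if NVFUSER_DISABLE named some other option that merely contained "multidevice". A stricter variant (our sketch, assuming the variable holds a comma-separated list; the commit itself keeps the simpler check) would split on commas and compare whole tokens:

    import os


    def feature_disabled(feature: str) -> bool:
        # Split NVFUSER_DISABLE on commas and compare whole tokens, so
        # "multidevice" is not matched inside an unrelated option name.
        value = os.environ.get("NVFUSER_DISABLE", "")
        return feature in (token.strip() for token in value.split(","))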
