Commit 42a5e48

[Pending multi device CI] Add symmetric memory sync test.
stack-info: PR: #375, branch: joydddd/stack/18
1 parent cee26aa commit 42a5e48

File tree

1 file changed: +86 −0 lines
test/test_distributed.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
from __future__ import annotations

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import run_tests

import helion
from helion._testing import code_and_output
import helion.language as hl


@helion.jit
def symm_mem_sync_kernel(
    remote_signal_pad_ptrs: torch.Tensor,
    local_signal_pad: torch.Tensor,
    rank: hl.constexpr,
) -> None:
    N, world_size = local_signal_pad.size()
    world_size = hl.specialize(world_size)
    for n in hl.grid(N):
        ptr_tile = remote_signal_pad_ptrs[:]
        multicast_signalpad = hl.multicast_like(local_signal_pad, ptr_tile)
        # Signal slot [n, rank] on every peer's signal pad.
        hl.signal(multicast_signalpad, [n, rank], signal=1, wait_for=0, scope="sys")
        # Wait until all ranks have signaled slot [n, :] on the local pad,
        # resetting each slot back to 0.
        for world in hl.tile(world_size, block_size=world_size):
            hl.wait(local_signal_pad, [n, world], signal=1, update=0, scope="sys")


@instantiate_parametrized_tests
class SymmMemBarrier(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        # world_size > 2 is needed to verify accumulation order
        return 4

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch.manual_seed(42 + self.rank)

    @skip_if_lt_x_gpu(4)
    def test_symm_mem_barrier(self):
        self._init_process()
        t = symm_mem.empty(4096, device=self.device)
        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
        local_signal_pad_t = symm_mem_hdl.get_signal_pad(
            symm_mem_hdl.rank, (32, symm_mem_hdl.world_size), dtype=torch.int32
        )
        signal_pad_pointers_t = torch.as_tensor(
            symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64
        ).to(self.device)

        code, result = code_and_output(
            symm_mem_sync_kernel,
            (
                signal_pad_pointers_t,
                local_signal_pad_t,
                symm_mem_hdl.rank,
            ),
        )

        # The barrier should leave every slot of the local signal pad back at 0.
        signal_pad = symm_mem_hdl.get_signal_pad(symm_mem_hdl.rank)
        assert signal_pad.eq(0).all().item()

        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
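
Note (not part of the commit): the kernel above is a device-side version of the signal-pad barrier that the symmetric-memory handle exposes on the host. Below is a minimal host-side sketch for comparison, assuming the handle provides a barrier() method as in recent PyTorch builds; barrier_via_handle is a hypothetical helper name, not from this commit.

# Hypothetical host-side sketch: each rank signals its peers and waits on its
# own signal pad, which is what symm_mem_sync_kernel does inside the kernel.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


def barrier_via_handle(device: torch.device) -> None:
    # Assumes dist.init_process_group(backend="nccl") has already run for this rank.
    t = symm_mem.empty(4096, device=device)
    hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
    hdl.barrier()  # block until every rank in the group reaches this point

Since MultiProcessTestCase spawns one subprocess per rank itself, a plausible way to run the new test on a machine with at least 4 GPUs is simply `python test/test_distributed.py` (the exact invocation is not stated in the commit).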
