
Commit d471238

ruff formatting
1 parent d258021

File tree

4 files changed: +52 -20 lines
  src/pplx_kernels/nvshmem.py
  tests/bench_all_to_all.py
  tests/test_all_to_all.py
  tests/test_nvshmem.py

src/pplx_kernels/nvshmem.py

Lines changed: 14 additions & 4 deletions
@@ -7,7 +7,13 @@
 
 
 ###### NVSHMEM ######
-def nvshmem_init(global_rank: int, local_rank: int, world_size: int, device: Any, uid: Optional[Any] = None) -> None:
+def nvshmem_init(
+    global_rank: int,
+    local_rank: int,
+    world_size: int,
+    device: Any,
+    uid: Optional[Any] = None,
+) -> None:
     uniqueid = nvshmem.get_unique_id(empty=True)
     if local_rank == 0:
         uniqueid = nvshmem.get_unique_id()
@@ -18,7 +24,13 @@ def nvshmem_init(global_rank: int, local_rank: int, world_size: int, device: Any
     dist.broadcast_object_list(broadcast_objects, src=0)
     dist.barrier()
 
-    nvshmem.init(device=device, uid=broadcast_objects[0], rank=global_rank, nranks=world_size, initializer_method="uid")
+    nvshmem.init(
+        device=device,
+        uid=broadcast_objects[0],
+        rank=global_rank,
+        nranks=world_size,
+        initializer_method="uid",
+    )
 
 
 # This stream wrapper returns the format required by CUDA Python. This workaround will be removed when nvshmem4py supports Torch stream interoperability.
@@ -31,5 +43,3 @@ def __init__(self, pt_stream: Any) -> None:
     def __cuda_stream__(self) -> tuple[int, int]:
         stream_id = self.pt_stream.cuda_stream
         return (0, stream_id)
-
-

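For orientation, a minimal usage sketch of the helpers reformatted above, assuming a torch.distributed process group is already initialized. The wrapper class name is not visible in the hunk, so TorchStreamWrapper below is a placeholder, and the nvshmem4py import paths are assumptions; treat this as a sketch, not this module's actual API.

# Sketch only: wrapper class name and nvshmem4py import paths are assumptions.
import torch
import nvshmem.core as nvshmem          # import path assumed
from nvshmem.core import Teams          # import path assumed

from pplx_kernels.nvshmem import nvshmem_init, TorchStreamWrapper  # wrapper name assumed


def example_worker(global_rank: int, local_rank: int, world_size: int, device) -> None:
    # UID bootstrap (torch.distributed must already be initialized):
    # rank 0 creates the unique id, broadcasts it, and every rank calls nvshmem.init.
    nvshmem_init(
        global_rank=global_rank,
        local_rank=local_rank,
        world_size=world_size,
        device=device,
    )

    # Symmetric-memory tensors with one int32 slot per peer.
    t_in = nvshmem.tensor((world_size,), dtype=torch.int32).fill_(global_rank)
    t_out = nvshmem.tensor((world_size,), dtype=torch.int32)

    # Wrap the current torch stream so nvshmem4py can read it via __cuda_stream__.
    wrapped = TorchStreamWrapper(torch.cuda.current_stream())
    nvshmem.collective.alltoall(Teams.TEAM_WORLD, t_out, t_in, stream=wrapped)

    nvshmem.finalize()
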
tests/bench_all_to_all.py

Lines changed: 9 additions & 4 deletions
@@ -119,8 +119,8 @@ def bench_all_to_all(
     )
     a2a_out_tensor = torch.empty_like(a2a_tensor)
 
-    nvshmem_in = nvshmem.tensor( a2a_shape, dtype=torch.uint8 )
-    nvshmem_out = nvshmem.tensor( a2a_shape, dtype=torch.uint8 )
+    nvshmem_in = nvshmem.tensor(a2a_shape, dtype=torch.uint8)
+    nvshmem_out = nvshmem.tensor(a2a_shape, dtype=torch.uint8)
 
     # Compute stats
     dispatch_bytes = (
@@ -176,7 +176,9 @@ def run() -> tuple[float, ...]:
 
         e3.record(torch_stream_)
 
-        nvshmem.collective.alltoall(team, nvshmem_out, nvshmem_in, stream=torch_stream_wrapped)
+        nvshmem.collective.alltoall(
+            team, nvshmem_out, nvshmem_in, stream=torch_stream_wrapped
+        )
 
         e4.record(torch_stream_)
 
@@ -233,6 +235,7 @@ def run() -> tuple[float, ...]:
         result,
     )
 
+
 def _worker_bench_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
@@ -246,7 +249,9 @@ def _worker_bench_all_to_all(
     dev = Device(local_rank)
     dev.set_current()
 
-    nvshmem_init(global_rank=global_rank, local_rank=local_rank, world_size=num_ranks, device=dev)
+    nvshmem_init(
+        global_rank=global_rank, local_rank=local_rank, world_size=num_ranks, device=dev
+    )
 
     in_dtype = getattr(torch, in_dtype_str)
     out_dtype = getattr(torch, out_dtype_str)

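A small sketch of the event-timing pattern that brackets the reformatted alltoall above. It assumes nvshmem, team, nvshmem_in, nvshmem_out, and torch_stream_wrapped are set up as in the diff; only the torch.cuda.Event usage is spelled out.

# Sketch: time a single NVSHMEM alltoall with CUDA events (names taken from the diff above).
import torch

torch_stream_ = torch.cuda.current_stream()
e3 = torch.cuda.Event(enable_timing=True)
e4 = torch.cuda.Event(enable_timing=True)

e3.record(torch_stream_)
nvshmem.collective.alltoall(
    team, nvshmem_out, nvshmem_in, stream=torch_stream_wrapped
)
e4.record(torch_stream_)

torch_stream_.synchronize()
alltoall_ms = e3.elapsed_time(e4)  # milliseconds between the two recorded events
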
tests/test_all_to_all.py

Lines changed: 8 additions & 4 deletions
@@ -283,6 +283,7 @@ def _do_test_all_to_all(
         ref_y[i_token] += rank_data.x[i_token].to(device).to(y.dtype) * val * weight
     torch.testing.assert_close(y[: rank_data.num_tokens], ref_y)
 
+
 def _worker_test_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
@@ -300,9 +301,9 @@ def _worker_test_all_to_all(
     dev = Device(local_rank)
     dev.set_current()
 
-
-
-    nvshmem_init(global_rank=global_rank, local_rank=local_rank, world_size=num_ranks, device=dev)
+    nvshmem_init(
+        global_rank=global_rank, local_rank=local_rank, world_size=num_ranks, device=dev
+    )
 
     moe_config = dataclasses.replace(
         moe_config,
@@ -314,13 +315,16 @@ def _worker_test_all_to_all(
     if test_script_init_status < 2 and local_rank == 0:
         logger.warning(
             "NVSHMEM hostlib initialization incomplete - status: %d (rank: %d, local_rank: %d)",
-            test_script_init_status, global_rank, local_rank
+            test_script_init_status,
+            global_rank,
+            local_rank,
         )
 
     _do_test_all_to_all(pgi, dp_size, moe_config, internode, use_compile)
 
     nvshmem.finalize()
 
+
 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])

tests/test_nvshmem.py

Lines changed: 21 additions & 8 deletions
@@ -18,8 +18,8 @@
 
 logger = logging.getLogger(__name__)
 
-def test_nvshmem_1_gpu() -> None:
 
+def test_nvshmem_1_gpu() -> None:
     local_rank = 0
     rank_id = 0  # Define rank_id for single GPU test
 
@@ -35,7 +35,9 @@ def test_nvshmem_1_gpu() -> None:
     if test_script_init_status < 2 and local_rank == 0:
         logger.warning(
             "NVSHMEM hostlib initialization incomplete - status: %d (rank: %d, local_rank: %d)",
-            test_script_init_status, rank_id, local_rank
+            test_script_init_status,
+            rank_id,
+            local_rank,
         )
 
     assert nvshmem.my_pe() == 0
@@ -50,14 +52,21 @@ def _worker_test_nvshmem_4_gpu(pgi: ProcessGroupInfo) -> None:
     dev = Device(local_rank)
     dev.set_current()
 
-    nvshmem_init(global_rank=pgi.rank, local_rank=local_rank, world_size=pgi.world_size, device=dev)
+    nvshmem_init(
+        global_rank=pgi.rank,
+        local_rank=local_rank,
+        world_size=pgi.world_size,
+        device=dev,
+    )
 
     # Check host initialization status
     test_script_init_status = nvshmem.direct.init_status()
     if test_script_init_status < 2 and local_rank == 0:
         logger.warning(
             "NVSHMEM hostlib initialization incomplete - status: %d (rank: %d, local_rank: %d)",
-            test_script_init_status, pgi.rank, local_rank
+            test_script_init_status,
+            pgi.rank,
+            local_rank,
         )
 
     assert nvshmem.my_pe() == pgi.rank
@@ -80,21 +89,25 @@ def _worker_test_all_to_all(pgi: ProcessGroupInfo) -> None:
     num_ranks = dist.get_world_size()
     rank_id = dist.get_rank()
 
-    nvshmem_init(global_rank=rank_id, local_rank=local_rank, world_size=num_ranks, device=dev)
+    nvshmem_init(
+        global_rank=rank_id, local_rank=local_rank, world_size=num_ranks, device=dev
+    )
 
     # Check NVSHMEM host initialization status
     test_script_init_status = nvshmem.direct.init_status()
     if test_script_init_status < 2 and local_rank == 0:
         logger.warning(
             "NVSHMEM hostlib initialization incomplete - status: %d (rank: %d, local_rank: %d)",
-            test_script_init_status, rank_id, local_rank
+            test_script_init_status,
+            rank_id,
+            local_rank,
         )
 
     # all-to-all test
     try:
         # Allocate a PyTorch tensor backed by NVSHMEM symmetric memory
-        t_in = nvshmem.tensor( (pgi.world_size,), dtype=torch.int32 ).fill_(pgi.rank)
-        t_out = nvshmem.tensor( (pgi.world_size,), dtype=torch.int32 )
+        t_in = nvshmem.tensor((pgi.world_size,), dtype=torch.int32).fill_(pgi.rank)
+        t_out = nvshmem.tensor((pgi.world_size,), dtype=torch.int32)
 
         team = Teams.TEAM_WORLD
         nvshmem.collective.alltoall(team, t_out, t_in)

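One way to read the 4-GPU alltoall test above: each rank fills t_in with its own rank id, so after the world-team alltoall, slot i of t_out should contain i on every rank. A short check along those lines, as a sketch that reuses pgi and t_out from the test rather than code from this commit:

# Sketch: expected contents of t_out after the alltoall in _worker_test_all_to_all.
expected = torch.arange(pgi.world_size, dtype=torch.int32, device=t_out.device)
torch.cuda.synchronize()
torch.testing.assert_close(t_out, expected)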