
Commit 6c4a9b9

fix ruff mypy and clang formatting errors
1 parent e06a0f3 commit 6c4a9b9

8 files changed: +39 -40 lines changed

8 files changed

+39
-40
lines changed

csrc/bindings/all_to_all_ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ fptr_t create_internode(
     hiddenDimScaleBytes
   );

-  // Needed to use host-side initialization information in device APIs.
+  // Needed to use host-side initialization information in device APIs.
   nvshmem_init();

   return (fptr_t)ptr;

csrc/bindings/bindings.cpp

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,8 @@

 using namespace pplx;

-TORCH_LIBRARY(pplx_kernels, m) {
+TORCH_LIBRARY(pplx_kernels, m)
+{
   register_all_to_all_ops(m);
 }

src/pplx_kernels/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -1,8 +1,6 @@
 from . import ops as ops
-from .all_to_all import (
-    AllToAll as AllToAll,
-)
+from .all_to_all import AllToAll as AllToAll
 from .nvshmem import (
-    nvshmem_init as nvshmem_init,
     PyTorchStreamWrapper as PyTorchStreamWrapper,
+    nvshmem_init as nvshmem_init,
 )
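
The `name as name` aliases here are deliberate: they mark the imports as explicit re-exports, which is what mypy's no-implicit-reexport check and ruff expect in an `__init__.py`, and the block is collapsed and re-sorted to satisfy ruff's import-sorting rules. A small self-contained sketch of the same convention, using stdlib modules so it runs on its own:

    # Illustrative only: the "explicit re-export" convention, shown with stdlib names.
    # "name as name" tells mypy/ruff the name is intentionally re-exported,
    # not an unused import.
    from json import dumps as dumps
    from os.path import (
        basename as basename,   # multi-name imports stay parenthesized,
        dirname as dirname,     # one name per line, in sorted order
    )

    print(dumps({"ok": True}), basename("/tmp/x.txt"), dirname("/tmp/x.txt"))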

src/pplx_kernels/nvshmem.py

Lines changed: 6 additions & 7 deletions
@@ -1,14 +1,13 @@
 # pyright: reportCallIssue=false

-from collections.abc import Sequence
+from typing import Any, Optional

-import torch
+import nvshmem.core as nvshmem  # type: ignore[import]
 import torch.distributed as dist

-import nvshmem.core as nvshmem

 ###### NVSHMEM ######
-def nvshmem_init(global_rank: int, local_rank: int, world_size: int, device, uid=None) -> None:
+def nvshmem_init(global_rank: int, local_rank: int, world_size: int, device: Any, uid: Optional[Any] = None) -> None:
     uniqueid = nvshmem.get_unique_id(empty=True)
     if local_rank == 0:
         uniqueid = nvshmem.get_unique_id()
@@ -20,16 +19,16 @@ def nvshmem_init(global_rank: int, local_rank: int, world_size: int, device, uid
     dist.barrier()

     nvshmem.init(device=device, uid=broadcast_objects[0], rank=global_rank, nranks=world_size, initializer_method="uid")
-
+

 # This stream wrapper returns the format required by CUDA Python. This workaround will be removed when nvshmem4py supports Torch stream interoperability.
 # For more information see: https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
 class PyTorchStreamWrapper:
-    def __init__(self, pt_stream):
+    def __init__(self, pt_stream: Any) -> None:
         self.pt_stream = pt_stream
         self.handle = pt_stream.cuda_stream

-    def __cuda_stream__(self):
+    def __cuda_stream__(self) -> tuple[int, int]:
         stream_id = self.pt_stream.cuda_stream
         return (0, stream_id)
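
A minimal usage sketch of the wrapper the comments above describe (not part of this commit; it assumes a CUDA-enabled torch build with a visible GPU):

    import torch
    from pplx_kernels import PyTorchStreamWrapper

    pt_stream = torch.cuda.current_stream()
    stream = PyTorchStreamWrapper(pt_stream)

    # The wrapper exposes the CUDA stream protocol that cuda-python/nvshmem4py expect:
    # a (version, handle) tuple built from the Torch stream's raw cuda_stream handle.
    assert stream.__cuda_stream__() == (0, pt_stream.cuda_stream)
    # It can then be passed wherever a CUDA-Python-style stream is accepted,
    # e.g. as the stream= argument of nvshmem.core collectives.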

src/pplx_kernels/ops.py

Lines changed: 3 additions & 1 deletion
@@ -5,6 +5,8 @@

 import torch

+logger = logging.getLogger(__name__)
+
 try:
     _lib_path = os.path.join(os.path.dirname(__file__), "libpplx_kernels.so")
     torch.ops.load_library(_lib_path)
@@ -13,4 +15,4 @@
     from types import SimpleNamespace

     _ops = SimpleNamespace()
-    logging.exception("Error loading pplx-kernels")
+    logger.exception("Error loading pplx-kernels")
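
The logging change replaces a call on the root `logging` module with a module-level logger, the pattern ruff's logging rules push toward. A small self-contained sketch of that pattern (module and message names are illustrative):

    import logging

    logger = logging.getLogger(__name__)  # named per-module logger instead of the root logger

    try:
        raise OSError("libexample.so: cannot open shared object file")  # stand-in for a failed load
    except OSError:
        # Like logging.exception(...), this records the message plus the active traceback,
        # but the record is attributed to this module's logger rather than the root logger.
        logger.exception("Error loading example kernels")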

tests/bench_all_to_all.py

Lines changed: 5 additions & 6 deletions
@@ -6,14 +6,13 @@
 from datetime import datetime
 from pathlib import Path

+import nvshmem.core as nvshmem  # type: ignore[import]
 import torch
-import torch.distributed as dist
-from cuda.core.experimental import Device
-import nvshmem.core as nvshmem
-from nvshmem.core import Teams
+from cuda.core.experimental import Device  # type: ignore[import]
+from nvshmem.core import Teams  # type: ignore[import]

+from pplx_kernels import PyTorchStreamWrapper, nvshmem_init
 from pplx_kernels.all_to_all import AllToAll
-from pplx_kernels import nvshmem_init, PyTorchStreamWrapper

 from .all_to_all_utils import MoEConfig, RankTestData
 from .distributed_utils import (
@@ -225,7 +224,7 @@ def run() -> tuple[float, ...]:

     # Cleanup
     ata.destroy()
-
+
     nvshmem.free_tensor(nvshmem_in)
     nvshmem.free_tensor(nvshmem_out)

tests/test_all_to_all.py

Lines changed: 6 additions & 5 deletions
@@ -1,13 +1,14 @@
 import dataclasses
 import logging

+import nvshmem.core as nvshmem  # type: ignore[import]
 import pytest
 import torch
 import torch.distributed as dist
-from cuda.core.experimental import Device
-import nvshmem.core as nvshmem
+from cuda.core.experimental import Device  # type: ignore[import]
+
+from pplx_kernels import nvshmem_init
 from pplx_kernels.all_to_all import AllToAll
-from pplx_kernels import nvshmem_init, PyTorchStreamWrapper

 from .all_to_all_utils import MoEConfig, RankTestData
 from .distributed_utils import (
@@ -299,7 +300,7 @@ def _worker_test_all_to_all(
     dev = Device(local_rank)
     dev.set_current()

-    stream = PyTorchStreamWrapper(torch.cuda.current_stream())
+

     nvshmem_init(global_rank=global_rank, local_rank=local_rank, world_size=num_ranks, device=dev)

@@ -316,7 +317,7 @@ def _worker_test_all_to_all(
         test_script_init_status, global_rank, local_rank
     )

-    _do_test_all_to_all(pgi, dp_size, moe_config, internode, stream)
+    _do_test_all_to_all(pgi, dp_size, moe_config, internode, use_compile)

     nvshmem.finalize()

tests/test_nvshmem.py

Lines changed: 14 additions & 15 deletions
@@ -1,5 +1,13 @@
+import logging
+
+import nvshmem.core as nvshmem  # type: ignore[import]
 import pytest
 import torch
+import torch.distributed as dist
+from cuda.core.experimental import Device  # type: ignore[import]
+from nvshmem.core import Teams  # type: ignore[import]
+
+from pplx_kernels import nvshmem_init

 from .distributed_utils import (
     ProcessGroupInfo,
@@ -8,19 +16,14 @@
     require_multi_node,
 )

-from cuda.core.experimental import Device
-import nvshmem.core as nvshmem
-import torch.distributed as dist
-from nvshmem.core import Teams
-from pplx_kernels import nvshmem_init, PyTorchStreamWrapper
+logger = logging.getLogger(__name__)

 def test_nvshmem_1_gpu() -> None:

     local_rank = 0
-    world_size = 1
+    rank_id = 0  # Define rank_id for single GPU test

     torch.cuda.set_device(local_rank)
-    device = torch.device("cuda", local_rank)
     dev = Device(local_rank)
     dev.set_current()

@@ -39,17 +42,15 @@ def test_nvshmem_1_gpu() -> None:
     assert nvshmem.n_pes() == 1

     nvshmem.finalize()
-


 def _worker_test_nvshmem_4_gpu(pgi: ProcessGroupInfo) -> None:
     local_rank = dist.get_rank()
-    world_size = dist.get_world_size()

     dev = Device(local_rank)
     dev.set_current()

-    nvshmem_init(global_rank=pgi.rank, local_rank=local_rank, world_size=world_size, device=dev)
+    nvshmem_init(global_rank=pgi.rank, local_rank=local_rank, world_size=pgi.world_size, device=dev)

     # Check host initialization status
     test_script_init_status = nvshmem.direct.init_status()
@@ -72,12 +73,10 @@ def test_nvshmem_4_gpu() -> None:

 def _worker_test_all_to_all(pgi: ProcessGroupInfo) -> None:
     local_rank = dist.get_rank()
-    world_size = dist.get_world_size()

     dev = Device(local_rank)
     dev.set_current()
-    stream = PyTorchStreamWrapper(torch.cuda.current_stream())
-
+
     num_ranks = dist.get_world_size()
     rank_id = dist.get_rank()

@@ -98,9 +97,9 @@ def _worker_test_all_to_all(pgi: ProcessGroupInfo) -> None:
     t_out = nvshmem.tensor( (pgi.world_size,), dtype=torch.int32 )

     team = Teams.TEAM_WORLD
-    nvshmem.collective.alltoall(team, t_out, t_in, stream=stream)
+    nvshmem.collective.alltoall(team, t_out, t_in)

-    nvshmem.collective.barrier(team, stream=stream)
+    nvshmem.collective.barrier(team)
     torch.cuda.synchronize()

     assert t_out.tolist() == list(range(pgi.world_size))
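
For readers following these tests, a condensed sketch of the uid-based startup they exercise, pieced together from the `nvshmem_init` hunks earlier in this commit; the `broadcast_object_list` step is an assumption about the portion of that function the diff does not show:

    import nvshmem.core as nvshmem  # type: ignore[import]
    import torch.distributed as dist

    def nvshmem_init_sketch(global_rank: int, local_rank: int, world_size: int, device) -> None:
        # Local rank 0 creates the real unique id; other ranks start with an empty placeholder.
        uniqueid = nvshmem.get_unique_id(empty=True)
        if local_rank == 0:
            uniqueid = nvshmem.get_unique_id()
        # Assumed step (not shown in the diff): share the id over the torch.distributed group.
        broadcast_objects = [uniqueid]
        dist.broadcast_object_list(broadcast_objects, src=0)
        dist.barrier()
        # Every PE then joins the same NVSHMEM job using the shared unique id.
        nvshmem.init(device=device, uid=broadcast_objects[0], rank=global_rank,
                     nranks=world_size, initializer_method="uid")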
