
Commit 658411c

update pytest and python API to fix ut failure (#598)
1 parent 334b232 commit 658411c

File tree

3 files changed: +30 -29 lines changed


python/mscclpp/comm.py

Lines changed: 2 additions & 1 deletion
@@ -87,6 +87,7 @@ def make_connection(
         self,
         all_ranks: list[int],
         endpoints: EndpointConfig | Transport | dict[int, EndpointConfig] | dict[int, Transport],
+        use_switch: bool = False,
     ) -> dict[int, Connection]:
         if type(endpoints) is Transport:
             endpoints = EndpointConfig(endpoints)
@@ -98,7 +99,7 @@ def make_connection(
                 endpoint = endpoints[rank]
             else:
                 endpoint = endpoints
-            if endpoint.transport == Transport.Nvls:
+            if endpoint.transport == Transport.CudaIpc and use_switch:
                 return connect_nvls_collective(self.communicator, all_ranks, 2**30)
             else:
                 connections[rank] = self.communicator.connect(endpoint, rank)
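The change above alters how callers of make_connection request an NVLS connection: instead of passing a dedicated Transport.Nvls value, they now pass Transport.CudaIpc together with the new use_switch flag, which routes the call to connect_nvls_collective. A minimal before/after sketch of a call site, assuming Transport is imported from the mscclpp package as in the test file below and that group names an existing CommGroup (the variable name is illustrative, not part of this commit):

from mscclpp import Transport

all_ranks = list(range(group.nranks))

# Before this commit: a dedicated transport value selected the NVLS path.
# nvls_connection = group.make_connection(all_ranks, Transport.Nvls)

# After this commit: reuse Transport.CudaIpc and opt in to the switch path explicitly.
nvls_connection = group.make_connection(all_ranks, Transport.CudaIpc, use_switch=True)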

python/mscclpp_benchmark/mscclpp_op.py

Lines changed: 1 addition & 1 deletion
@@ -440,7 +440,7 @@ def __init__(
         self.group.barrier()
         # create a connection for each remote neighbor
         self.nvlink_connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
-        self.nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
+        self.nvls_connection = group.make_connection(all_ranks, Transport.CudaIpc, use_switch=True)
         self.memory = GpuBuffer(nelem, memory_dtype)
         self.nvls_mem_handle = self.nvls_connection.bind_allocated_memory(
             self.memory.data.ptr, self.memory.data.mem.size

python/test/test_mscclpp.py

Lines changed: 27 additions & 27 deletions
@@ -141,33 +141,33 @@ def init_target():
     mpi_group.comm.barrier()
 
 
-def create_connection(group: mscclpp_comm.CommGroup, transport: str):
-    if transport == "NVLS":
+def create_connection(group: mscclpp_comm.CommGroup, connection_type: str):
+    if connection_type == "NVLS":
         all_ranks = list(range(group.nranks))
-        tran = Transport.Nvls
-        connection = group.make_connection(all_ranks, tran)
+        tran = Transport.CudaIpc
+        connection = group.make_connection(all_ranks, tran, use_switch=True)
         return connection
 
     remote_nghrs = list(range(group.nranks))
     remote_nghrs.remove(group.my_rank)
-    if transport == "NVLink":
+    if connection_type == "NVLink":
         tran = Transport.CudaIpc
-    elif transport == "IB":
+    elif connection_type == "IB":
         tran = group.my_ib_device(group.my_rank % 8)
     else:
         assert False
     connections = group.make_connection(remote_nghrs, tran)
     return connections
 
 
-def create_group_and_connection(mpi_group: MpiGroup, transport: str):
-    if (transport == "NVLink" or transport == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
+def create_group_and_connection(mpi_group: MpiGroup, connection_type: str):
+    if (connection_type == "NVLink" or connection_type == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
         pytest.skip("cannot use nvlink/nvls for cross node")
     group = mscclpp_comm.CommGroup(mpi_group.comm)
     try:
-        connection = create_connection(group, transport)
+        connection = create_connection(group, connection_type)
     except Error as e:
-        if transport == "IB" and e.args[0] == ErrorCode.InvalidUsage:
+        if connection_type == "IB" and e.args[0] == ErrorCode.InvalidUsage:
             pytest.skip("IB not supported on this node")
         raise
     return group, connection
@@ -194,10 +194,10 @@ def test_gpu_buffer(mpi_group: MpiGroup, nelem: int, dtype: cp.dtype):
 
 
 @parametrize_mpi_groups(2, 4, 8, 16)
-@pytest.mark.parametrize("transport", ["IB", "NVLink"])
+@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
-def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
-    group, connections = create_group_and_connection(mpi_group, transport)
+def test_connection_write(mpi_group: MpiGroup, connection_type: str, nelem: int):
+    group, connections = create_group_and_connection(mpi_group, connection_type)
     memory = GpuBuffer(nelem, dtype=cp.int32)
     nelemPerRank = nelem // group.nranks
     sizePerRank = nelemPerRank * memory.itemsize
@@ -229,16 +229,16 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)
 
 
 @parametrize_mpi_groups(2, 4, 8, 16)
-@pytest.mark.parametrize("transport", ["IB", "NVLink"])
+@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
-def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
+def test_connection_write_and_signal(mpi_group: MpiGroup, connection_type: str, nelem: int, device: str):
     # this test starts with a random tensor on rank 0 and rotates it all the way through all ranks
     # and finally, comes back to rank 0 to make sure it matches all the original values
 
-    if device == "cpu" and transport == "NVLink":
+    if device == "cpu" and connection_type == "NVLink":
         pytest.skip("nvlink doesn't work with host allocated memory")
-    group, connections = create_group_and_connection(mpi_group, transport)
+    group, connections = create_group_and_connection(mpi_group, connection_type)
     xp = cp if device == "cuda" else np
     if group.my_rank == 0:
         memory = xp.random.randn(nelem)
@@ -339,7 +339,7 @@ def test_nvls_connection(mpi_group: MpiGroup):
         pytest.skip("cannot use nvls for cross node")
     group = mscclpp_comm.CommGroup(mpi_group.comm)
     all_ranks = list(range(group.nranks))
-    nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
+    nvls_connection = group.make_connection(all_ranks, Transport.CudaIpc, use_switch=True)
     memory1 = GpuBuffer(2**29, cp.int8)
     memory2 = GpuBuffer(2**29, cp.int8)
     memory3 = GpuBuffer(2**29, cp.int8)
@@ -449,13 +449,13 @@ def __call__(self):
 
 
 @parametrize_mpi_groups(2, 4, 8, 16)
-@pytest.mark.parametrize("transport", ["NVLink", "IB"])
-def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
+@pytest.mark.parametrize("connection_type", ["NVLink", "IB"])
+def test_h2d_semaphores(mpi_group: MpiGroup, connection_type: str):
     def signal(semaphores):
         for rank in semaphores:
             semaphores[rank].signal()
 
-    group, connections = create_group_and_connection(mpi_group, transport)
+    group, connections = create_group_and_connection(mpi_group, connection_type)
 
     semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
     kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
@@ -530,9 +530,9 @@ def test_fifo(
 
 @parametrize_mpi_groups(2, 4, 8, 16)
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
-@pytest.mark.parametrize("transport", ["IB", "NVLink"])
-def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
-    group, connections = create_group_and_connection(mpi_group, transport)
+@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
+def test_proxy(mpi_group: MpiGroup, nelem: int, connection_type: str):
+    group, connections = create_group_and_connection(mpi_group, connection_type)
 
     memory = GpuBuffer(nelem, dtype=cp.int32)
     nelemPerRank = nelem // group.nranks
@@ -579,10 +579,10 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
 
 @parametrize_mpi_groups(2, 4, 8, 16)
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
-@pytest.mark.parametrize("transport", ["NVLink", "IB"])
+@pytest.mark.parametrize("connection_type", ["NVLink", "IB"])
 @pytest.mark.parametrize("use_packet", [False, True])
-def test_port_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
-    group, connections = create_group_and_connection(mpi_group, transport)
+def test_port_channel(mpi_group: MpiGroup, nelem: int, connection_type: str, use_packet: bool):
+    group, connections = create_group_and_connection(mpi_group, connection_type)
 
     memory = GpuBuffer(nelem, dtype=cp.int32)
     if use_packet:
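Taken together, the test updates rename the pytest parameter from "transport" to "connection_type" (the string now names a connection flavor rather than a Transport enum member) and route the "NVLS" case through Transport.CudaIpc plus use_switch=True. Since the change is spread across several hunks, here is a condensed sketch of the updated helper as it reads after this commit, assuming the same imports as test_mscclpp.py (mscclpp_comm, Transport); it is a readability aid, not additional code in the commit:

def create_connection(group: mscclpp_comm.CommGroup, connection_type: str):
    # "NVLS" no longer maps to Transport.Nvls; it reuses CudaIpc and opts in
    # to the switch path via the new use_switch flag on make_connection.
    if connection_type == "NVLS":
        all_ranks = list(range(group.nranks))
        return group.make_connection(all_ranks, Transport.CudaIpc, use_switch=True)
    # "NVLink" and "IB" keep their per-peer connections, as before this commit.
    remote_nghrs = [r for r in range(group.nranks) if r != group.my_rank]
    if connection_type == "NVLink":
        tran = Transport.CudaIpc
    elif connection_type == "IB":
        tran = group.my_ib_device(group.my_rank % 8)
    else:
        assert False
    return group.make_connection(remote_nghrs, tran)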
