@@ -141,33 +141,33 @@ def init_target():
     mpi_group.comm.barrier()


-def create_connection(group: mscclpp_comm.CommGroup, transport: str):
-    if transport == "NVLS":
+def create_connection(group: mscclpp_comm.CommGroup, connection_type: str):
+    if connection_type == "NVLS":
         all_ranks = list(range(group.nranks))
-        tran = Transport.Nvls
-        connection = group.make_connection(all_ranks, tran)
+        tran = Transport.CudaIpc
+        connection = group.make_connection(all_ranks, tran, use_switch=True)
         return connection

     remote_nghrs = list(range(group.nranks))
     remote_nghrs.remove(group.my_rank)
-    if transport == "NVLink":
+    if connection_type == "NVLink":
         tran = Transport.CudaIpc
-    elif transport == "IB":
+    elif connection_type == "IB":
         tran = group.my_ib_device(group.my_rank % 8)
     else:
         assert False
     connections = group.make_connection(remote_nghrs, tran)
     return connections


-def create_group_and_connection(mpi_group: MpiGroup, transport: str):
-    if (transport == "NVLink" or transport == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
+def create_group_and_connection(mpi_group: MpiGroup, connection_type: str):
+    if (connection_type == "NVLink" or connection_type == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
         pytest.skip("cannot use nvlink/nvls for cross node")
     group = mscclpp_comm.CommGroup(mpi_group.comm)
     try:
-        connection = create_connection(group, transport)
+        connection = create_connection(group, connection_type)
     except Error as e:
-        if transport == "IB" and e.args[0] == ErrorCode.InvalidUsage:
+        if connection_type == "IB" and e.args[0] == ErrorCode.InvalidUsage:
             pytest.skip("IB not supported on this node")
         raise
     return group, connection
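For orientation, here is a minimal sketch (not part of this change) of how the renamed helper is driven; it mirrors the parametrized tests later in this diff, and the test name below is hypothetical:

# Hypothetical usage sketch, assuming this test module's imports and fixtures.
# create_group_and_connection() maps the string to a transport internally
# ("NVLink"/"NVLS" -> Transport.CudaIpc, "IB" -> the rank's IB device) and
# calls pytest.skip() when the requested connection type is unavailable.
@parametrize_mpi_groups(2)
@pytest.mark.parametrize("connection_type", ["NVLink", "IB", "NVLS"])
def test_connection_setup(mpi_group: MpiGroup, connection_type: str):
    group, connections = create_group_and_connection(mpi_group, connection_type)
    assert connections is not None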
@@ -194,10 +194,10 @@ def test_gpu_buffer(mpi_group: MpiGroup, nelem: int, dtype: cp.dtype):


 @parametrize_mpi_groups(2, 4, 8, 16)
-@pytest.mark.parametrize("transport", ["IB", "NVLink"])
+@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
-def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
-    group, connections = create_group_and_connection(mpi_group, transport)
+def test_connection_write(mpi_group: MpiGroup, connection_type: str, nelem: int):
+    group, connections = create_group_and_connection(mpi_group, connection_type)
     memory = GpuBuffer(nelem, dtype=cp.int32)
     nelemPerRank = nelem // group.nranks
     sizePerRank = nelemPerRank * memory.itemsize
@@ -229,16 +229,16 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)


 @parametrize_mpi_groups(2, 4, 8, 16)
-@pytest.mark.parametrize("transport", ["IB", "NVLink"])
+@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
-def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
+def test_connection_write_and_signal(mpi_group: MpiGroup, connection_type: str, nelem: int, device: str):
     # this test starts with a random tensor on rank 0 and rotates it all the way through all ranks
     # and finally, comes back to rank 0 to make sure it matches all the original values

-    if device == "cpu" and transport == "NVLink":
+    if device == "cpu" and connection_type == "NVLink":
         pytest.skip("nvlink doesn't work with host allocated memory")
-    group, connections = create_group_and_connection(mpi_group, transport)
+    group, connections = create_group_and_connection(mpi_group, connection_type)
     xp = cp if device == "cuda" else np
     if group.my_rank == 0:
         memory = xp.random.randn(nelem)
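The rotation described in the comment inside test_connection_write_and_signal can be pictured with plain mpi4py calls; this sketch shows only the data flow, not the mscclpp connection/semaphore API used in the actual test, and assumes at least two ranks:

# Illustration only: rank 0 seeds random data, each rank forwards it to
# (rank + 1) % nranks, and rank 0 checks the copy that comes back.
import numpy as np
from mpi4py import MPI

def rotate_through_ring(comm: MPI.Comm, nelem: int) -> None:
    rank, nranks = comm.Get_rank(), comm.Get_size()
    if rank == 0:
        data = np.random.randn(nelem)
        comm.Send(data, dest=1)
        result = np.empty(nelem)
        comm.Recv(result, source=nranks - 1)
        assert np.allclose(result, data)
    else:
        data = np.empty(nelem)
        comm.Recv(data, source=rank - 1)
        comm.Send(data, dest=(rank + 1) % nranks)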
@@ -339,7 +339,7 @@ def test_nvls_connection(mpi_group: MpiGroup):
         pytest.skip("cannot use nvls for cross node")
     group = mscclpp_comm.CommGroup(mpi_group.comm)
     all_ranks = list(range(group.nranks))
-    nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
+    nvls_connection = group.make_connection(all_ranks, Transport.CudaIpc, use_switch=True)
     memory1 = GpuBuffer(2**29, cp.int8)
     memory2 = GpuBuffer(2**29, cp.int8)
     memory3 = GpuBuffer(2**29, cp.int8)
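Taken together with the first hunk, the transport change reduces to this: both peer-to-peer NVLink and switch-based NVLS connections now request Transport.CudaIpc, with the multicast path selected via use_switch (before/after forms taken from this diff):

# NVLink peer-to-peer connections (unchanged): CudaIpc, one per remote rank.
p2p_connections = group.make_connection(remote_nghrs, Transport.CudaIpc)
# NVLS switch connection: previously Transport.Nvls, now CudaIpc + use_switch.
nvls_connection = group.make_connection(all_ranks, Transport.CudaIpc, use_switch=True)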
@@ -449,13 +449,13 @@ def __call__(self):


 @parametrize_mpi_groups(2, 4, 8, 16)
-@pytest.mark.parametrize("transport", ["NVLink", "IB"])
-def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
+@pytest.mark.parametrize("connection_type", ["NVLink", "IB"])
+def test_h2d_semaphores(mpi_group: MpiGroup, connection_type: str):
     def signal(semaphores):
         for rank in semaphores:
             semaphores[rank].signal()

-    group, connections = create_group_and_connection(mpi_group, transport)
+    group, connections = create_group_and_connection(mpi_group, connection_type)

     semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
     kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
@@ -530,9 +530,9 @@ def test_fifo(

 @parametrize_mpi_groups(2, 4, 8, 16)
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
-@pytest.mark.parametrize("transport", ["IB", "NVLink"])
-def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
-    group, connections = create_group_and_connection(mpi_group, transport)
+@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
+def test_proxy(mpi_group: MpiGroup, nelem: int, connection_type: str):
+    group, connections = create_group_and_connection(mpi_group, connection_type)

     memory = GpuBuffer(nelem, dtype=cp.int32)
     nelemPerRank = nelem // group.nranks
@@ -579,10 +579,10 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):

 @parametrize_mpi_groups(2, 4, 8, 16)
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
-@pytest.mark.parametrize("transport", ["NVLink", "IB"])
+@pytest.mark.parametrize("connection_type", ["NVLink", "IB"])
 @pytest.mark.parametrize("use_packet", [False, True])
-def test_port_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
-    group, connections = create_group_and_connection(mpi_group, transport)
+def test_port_channel(mpi_group: MpiGroup, nelem: int, connection_type: str, use_packet: bool):
+    group, connections = create_group_and_connection(mpi_group, connection_type)

     memory = GpuBuffer(nelem, dtype=cp.int32)
     if use_packet: