@@ -62,6 +62,11 @@ def parse_args():
6262 parser .add_argument ("--num_xcds" , type = int , default = None , help = "Number of XCDs (auto-detected if not set)" )
6363 parser .add_argument ("-r" , "--num_ranks" , type = int , default = 8 , help = "Number of ranks/processes" )
6464 parser .add_argument ("--use_gluon" , action = "store_true" , help = "Use Gluon implementation with traffic shaping" )
65+ parser .add_argument (
66+ "--benchmark_rccl" ,
67+ action = "store_true" ,
68+ help = "Also benchmark PyTorch RCCL (all_to_all) for comparison" ,
69+ )
6570
6671 return vars (parser .parse_args ())
6772
@@ -268,6 +273,69 @@ def run_experiment():
268273 # Wait for all to finish benchmarking
269274 shmem .barrier ()
270275
276+ # Benchmark RCCL (PyTorch all_to_all) for comparison
277+ if args .get ("benchmark_rccl" , False ):
278+ shmem .info ("Benchmarking PyTorch RCCL (all_to_all)..." )
279+
280+ # Create PyTorch tensors (not on Iris heap)
281+ # For all_to_all, we need a list of tensors to send and receive
282+ pytorch_input_list = [torch .zeros (M , N , dtype = datatype , device = f"cuda:{ rank } " ) for _ in range (world_size )]
283+ pytorch_output_list = [torch .zeros (M , N , dtype = datatype , device = f"cuda:{ rank } " ) for _ in range (world_size )]
284+
285+ # Fill input tensors with deterministic values
286+ for target_rank in range (world_size ):
287+ val = float (rank * 1000 + target_rank )
288+ pytorch_input_list [target_rank ].fill_ (val )
289+
290+ # Warmup
291+ for _ in range (10 ):
292+ dist .all_to_all (pytorch_output_list , pytorch_input_list )
293+ torch .cuda .synchronize ()
294+ dist .barrier ()
295+
296+ # Benchmark
297+ for target_rank in range (world_size ):
298+ pytorch_output_list [target_rank ].zero_ ()
299+ val = float (rank * 1000 + target_rank )
300+ pytorch_input_list [target_rank ].fill_ (val )
301+ dist .barrier ()
302+
303+ rccl_start = torch .cuda .Event (enable_timing = True )
304+ rccl_end = torch .cuda .Event (enable_timing = True )
305+
306+ num_iterations = 126 # Match Iris benchmark iterations
307+ dist .barrier ()
308+ rccl_start .record ()
309+ for _ in range (num_iterations ):
310+ dist .all_to_all (pytorch_output_list , pytorch_input_list )
311+ rccl_end .record ()
312+ torch .cuda .synchronize ()
313+ dist .barrier ()
314+
315+ rccl_ms = rccl_start .elapsed_time (rccl_end ) / num_iterations
316+ element_size = torch .tensor ([], dtype = datatype ).element_size ()
317+ total_bytes = (world_size - 1 ) * M * N * element_size
318+ total_bytes_gb = total_bytes / (1024 ** 3 )
319+ rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3 )
320+
321+ shmem .info (
322+ f"RCCL all_to_all (M={ M } , N={ N } , world_size={ world_size } , dtype={ args ['datatype' ]} ): "
323+ f"{ rccl_ms :.3f} ms, { rccl_bandwidth_gbps :.3f} GB/s"
324+ )
325+
326+ if args ["benchmark" ]:
327+ # Calculate performance ratio
328+ iris_bandwidth = bandwidth_gbps
329+ rccl_ratio = (iris_bandwidth / rccl_bandwidth_gbps ) * 100 if rccl_bandwidth_gbps > 0 else 0
330+ shmem .info (f"Performance ratio (Iris/RCCL): { rccl_ratio :.1f} %" )
331+
332+ json_writer .add_field ("rccl_bandwidth_gbps" , rccl_bandwidth_gbps )
333+ json_writer .add_field ("rccl_ms" , rccl_ms )
334+ json_writer .add_field ("rccl_ratio_percent" , rccl_ratio )
335+
336+ # Wait for all to finish RCCL benchmarking
337+ shmem .barrier ()
338+
271339 if rank == 0 :
272340 json_writer .flush ()
273341 json_writer .display ()
@@ -279,7 +347,7 @@ def run_experiment():
279347def main ():
280348 args = parse_args ()
281349 num_ranks = args ["num_ranks" ]
282- init_url = "tcp://127.0.0.1:29503 "
350+ init_url = "tcp://127.0.0.1:29569 "
283351
284352 mp .spawn (
285353 fn = _worker ,
0 commit comments