
Commit 4f72db4

Improve CCL performance (#298)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent: 731e34f

File tree

12 files changed (+950, -210 lines)


benchmark/ccl/all_gather/benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -336,7 +336,7 @@ def run_experiment():
 def main():
     args = parse_args()
     num_ranks = args["num_ranks"]
-    init_url = "tcp://127.0.0.1:29503"
+    init_url = "tcp://127.0.0.1:29234"

     mp.spawn(
         fn=_worker,
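The all_gather benchmark previously shared rendezvous port 29503 with the other CCL benchmarks; moving it to 29234 keeps runs from colliding on the same TCP endpoint. As an aside, a minimal sketch of the failure mode this avoids, assuming a localhost rendezvous (the pick_port helper is hypothetical and not part of this commit):

    # Sketch: prefer a fixed rendezvous port, fall back to an OS-assigned one
    # if it is already bound (e.g. by another benchmark). Note there is still
    # a small race between checking the port and the workers binding it.
    import socket

    def pick_port(preferred: int) -> int:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("127.0.0.1", preferred))
                return preferred
            except OSError:          # port in use
                s.bind(("127.0.0.1", 0))  # let the OS choose a free port
                return s.getsockname()[1]

    init_url = f"tcp://127.0.0.1:{pick_port(29234)}"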

benchmark/ccl/all_reduce/benchmark.py

Lines changed: 4 additions & 3 deletions
@@ -84,6 +84,9 @@ def parse_args():
         default=None,
         help="Column slice size for ring variant (power of two, must divide block_size_n)",
     )
+    parser.add_argument(
+        "--init_url", type=str, default="tcp://127.0.0.1:29527", help="Initialization URL for distributed setup"
+    )

     return vars(parser.parse_args())

@@ -100,10 +103,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     )

     shmem = iris.iris(args["heap_size"])
-
     rank = shmem.get_rank()
     world_size = shmem.get_num_ranks()
-
     # Datatype mapping
     datatype = torch.float32
     if args["datatype"] == "fp16":
@@ -374,7 +375,7 @@ def run_experiment():
 def main():
     args = parse_args()
     num_ranks = args["num_ranks"]
-    init_url = "tcp://127.0.0.1:29503"
+    init_url = args["init_url"]

     mp.spawn(
         fn=_worker,
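With the new flag, the all_reduce benchmark's rendezvous address is configurable rather than hard-coded. A hypothetical invocation (the port value here is arbitrary, and any other flags follow the file's existing parser):

    python benchmark/ccl/all_reduce/benchmark.py --init_url tcp://127.0.0.1:29600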

benchmark/ccl/all_to_all/benchmark.py

Lines changed: 69 additions & 1 deletion
@@ -62,6 +62,11 @@ def parse_args():
     parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)")
     parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes")
     parser.add_argument("--use_gluon", action="store_true", help="Use Gluon implementation with traffic shaping")
+    parser.add_argument(
+        "--benchmark_rccl",
+        action="store_true",
+        help="Also benchmark PyTorch RCCL (all_to_all) for comparison",
+    )

     return vars(parser.parse_args())

@@ -268,6 +273,69 @@ def run_experiment():
     # Wait for all to finish benchmarking
     shmem.barrier()

+    # Benchmark RCCL (PyTorch all_to_all) for comparison
+    if args.get("benchmark_rccl", False):
+        shmem.info("Benchmarking PyTorch RCCL (all_to_all)...")
+
+        # Create PyTorch tensors (not on Iris heap)
+        # For all_to_all, we need a list of tensors to send and receive
+        pytorch_input_list = [torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)]
+        pytorch_output_list = [torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)]
+
+        # Fill input tensors with deterministic values
+        for target_rank in range(world_size):
+            val = float(rank * 1000 + target_rank)
+            pytorch_input_list[target_rank].fill_(val)
+
+        # Warmup
+        for _ in range(10):
+            dist.all_to_all(pytorch_output_list, pytorch_input_list)
+        torch.cuda.synchronize()
+        dist.barrier()
+
+        # Benchmark
+        for target_rank in range(world_size):
+            pytorch_output_list[target_rank].zero_()
+            val = float(rank * 1000 + target_rank)
+            pytorch_input_list[target_rank].fill_(val)
+        dist.barrier()
+
+        rccl_start = torch.cuda.Event(enable_timing=True)
+        rccl_end = torch.cuda.Event(enable_timing=True)
+
+        num_iterations = 126  # Match Iris benchmark iterations
+        dist.barrier()
+        rccl_start.record()
+        for _ in range(num_iterations):
+            dist.all_to_all(pytorch_output_list, pytorch_input_list)
+        rccl_end.record()
+        torch.cuda.synchronize()
+        dist.barrier()
+
+        rccl_ms = rccl_start.elapsed_time(rccl_end) / num_iterations
+        element_size = torch.tensor([], dtype=datatype).element_size()
+        total_bytes = (world_size - 1) * M * N * element_size
+        total_bytes_gb = total_bytes / (1024**3)
+        rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3)
+
+        shmem.info(
+            f"RCCL all_to_all (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): "
+            f"{rccl_ms:.3f} ms, {rccl_bandwidth_gbps:.3f} GB/s"
+        )
+
+        if args["benchmark"]:
+            # Calculate performance ratio
+            iris_bandwidth = bandwidth_gbps
+            rccl_ratio = (iris_bandwidth / rccl_bandwidth_gbps) * 100 if rccl_bandwidth_gbps > 0 else 0
+            shmem.info(f"Performance ratio (Iris/RCCL): {rccl_ratio:.1f}%")
+
+            json_writer.add_field("rccl_bandwidth_gbps", rccl_bandwidth_gbps)
+            json_writer.add_field("rccl_ms", rccl_ms)
+            json_writer.add_field("rccl_ratio_percent", rccl_ratio)
+
+        # Wait for all to finish RCCL benchmarking
+        shmem.barrier()
+
     if rank == 0:
         json_writer.flush()
         json_writer.display()
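The bandwidth math in the hunk above counts only remote traffic: each rank exchanges an M x N tile with each of the other world_size - 1 ranks, so the local self-copy is excluded from total_bytes. A worked instance of the same arithmetic as a standalone sketch (the tensor sizes and the 1.5 ms timing are illustrative assumptions, not measurements from this commit):

    # Worked example of the RCCL bandwidth computation above.
    # All concrete numbers here are illustrative assumptions.
    M, N, world_size = 8192, 8192, 8
    element_size = 2   # bytes per fp16 element
    rccl_ms = 1.5      # assumed per-iteration time in milliseconds

    total_bytes = (world_size - 1) * M * N * element_size   # self-copy excluded
    total_bytes_gb = total_bytes / (1024**3)                # GiB per iteration
    rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3)

    print(f"{total_bytes_gb:.3f} GiB in {rccl_ms} ms -> {rccl_bandwidth_gbps:.1f} GB/s")
    # 7 * 8192 * 8192 * 2 B = 0.875 GiB; 0.875 / 0.0015 s ~= 583.3 GB/s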
@@ -279,7 +347,7 @@ def run_experiment():
 def main():
     args = parse_args()
     num_ranks = args["num_ranks"]
-    init_url = "tcp://127.0.0.1:29503"
+    init_url = "tcp://127.0.0.1:29569"

     mp.spawn(
         fn=_worker,
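With the port change and the comparison path in place, a hypothetical side-by-side run looks like (-r is the rank-count flag defined in parse_args above):

    python benchmark/ccl/all_to_all/benchmark.py -r 8 --benchmark_rccl

Note the timing discipline in the RCCL loop: dist.barrier() aligns the ranks before rccl_start.record(), and torch.cuda.synchronize() after rccl_end.record() ensures both CUDA events have completed before elapsed_time is read.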
