This folder provides an example of a distributed All-Gather + GEMM kernel. It explores two distinct patterns for fusing communication and computation: a Pull model and a Push model.
The core kernel implementations are located in examples/14_all_gather_gemm/.
Both patterns are compared against a baseline that runs the RCCL All-Gather collective followed by torch.matmul.
The two patterns differ in which side drives the data movement:
In the Pull model, the consumer (GEMM kernel) takes full control. It actively "pulls" data from remote GPUs as it is needed using an iris.load instruction. The communication is fused directly into a single, persistent compute kernel.
The Push model decouples communication and computation. A dedicated producer kernel "pushes" data to a remote inbox using iris.store, and the consumer (GEMM kernel) waits for a synchronization signal before performing a fast local load from that inbox.
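The difference between the two patterns can be sketched on the host in plain Python. This is only a single-process illustration, not the actual kernels: it assumes A is sharded along K across ranks (the real example may shard differently), the list reads and copies stand in for iris.load and iris.store, and every function name below is hypothetical.

```python
def matmul(a, b):
    """Plain-Python product of two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def accumulate(c, partial):
    """Add a partial product into the output tile in place."""
    for i, row in enumerate(partial):
        for j, value in enumerate(row):
            c[i][j] += value


def zeros(m, n):
    return [[0.0] * n for _ in range(m)]


def pull_all_gather_gemm(a_shards, b_blocks):
    """Pull: the GEMM loop itself fetches each remote shard when it needs it."""
    c = zeros(len(a_shards[0]), len(b_blocks[0][0]))
    for src in range(len(a_shards)):
        a_tile = a_shards[src]  # stands in for a remote load issued by the consumer
        accumulate(c, matmul(a_tile, b_blocks[src]))
    return c


def push_all_gather_gemm(a_shards, b_blocks):
    """Push: a producer fills a local inbox and raises a flag; the consumer reads locally."""
    world = len(a_shards)
    inbox = [None] * world
    ready = [False] * world

    # Producer phase: each source writes its shard into the inbox, then signals.
    for src in range(world):
        inbox[src] = [row[:] for row in a_shards[src]]  # stands in for a remote store
        ready[src] = True

    # Consumer phase: check the signal, then do a fast local read from the inbox.
    c = zeros(len(a_shards[0]), len(b_blocks[0][0]))
    for src in range(world):
        if not ready[src]:  # a real kernel would spin-wait on the flag instead
            raise RuntimeError(f"shard {src} has not arrived")
        accumulate(c, matmul(inbox[src], b_blocks[src]))
    return c


# Two "ranks": A = [[1, 2, 3, 4], [5, 6, 7, 8]] split along K into two 2x2 shards.
a_shards = [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
# B = [[1, 0], [0, 1], [1, 0], [0, 1]] split into the matching K blocks.
b_blocks = [[[1, 0], [0, 1]], [[1, 0], [0, 1]]]

pull_result = pull_all_gather_gemm(a_shards, b_blocks)
push_result = push_all_gather_gemm(a_shards, b_blocks)
```

Both functions produce the same output; what changes is who initiates the transfer and whether the consumer reads remote memory directly or a local copy guarded by a flag.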
To run a minimal, standalone example that demonstrates the kernel's functionality and validates its output for a single configuration, use the example_run scripts.
Pull Model:
python examples/14_all_gather_gemm/example_run_pull.py --num_ranks 8
Push Model:
python examples/14_all_gather_gemm/example_run_push.py --num_ranks 8
For more comprehensive testing, dedicated scripts in the benchmark/examples/ directory handle both correctness validation and performance benchmarking across a range of configurations. The -v/--validate and -b/--benchmark flags described below select which mode runs.
The scripts run a sweep of configurations defined in the JSON file at dataset/ag_gemm.json.
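As a rough illustration of how such a sweep might be consumed, the sketch below parses a list of problem sizes and derives a per-rank shard size. The schema shown (keys "m", "n", "k") and the helper iter_configs are hypothetical and are not taken from dataset/ag_gemm.json, whose actual layout may differ; the K-sharding is likewise an assumption.

```python
import json

# Hypothetical sweep contents; the real schema of dataset/ag_gemm.json may differ.
EXAMPLE_SWEEP = """
[
  {"m": 1024, "n": 4096, "k": 8192},
  {"m": 2048, "n": 4096, "k": 8192}
]
"""


def iter_configs(text, num_ranks):
    """Yield each configuration with the per-rank K extent filled in."""
    for cfg in json.loads(text):
        if cfg["k"] % num_ranks != 0:
            continue  # skip sizes that do not divide evenly across ranks
        yield {**cfg, "k_per_rank": cfg["k"] // num_ranks}


configs = list(iter_configs(EXAMPLE_SWEEP, num_ranks=8))
for cfg in configs:
    print(cfg)
```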
To verify the numerical correctness of an implementation against a PyTorch reference, run its benchmark script with the -v or --validate flag.
Pull Model:
python benchmark/examples/benchmark_all_gather_gemm_pull.py --num_ranks 8 -v
Push Model:
python benchmark/examples/benchmark_all_gather_gemm_push.py --num_ranks 8 -v
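Conceptually, validation compares the fused kernel's output with an unfused reference: gather the shards, multiply, and check the result within a tolerance. The sketch below shows that idea in plain Python rather than PyTorch. The tolerances, the K-sharded layout, and all function names are assumptions for illustration, and math.isclose uses a slightly different tolerance formula than torch.allclose.

```python
import math


def all_gather(shards):
    """Reference all-gather: concatenate each rank's column shard into full rows."""
    return [sum((shard[i] for shard in shards), []) for i in range(len(shards[0]))]


def matmul(a, b):
    """Plain-Python product of two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def allclose(x, y, rtol=1e-2, atol=1e-2):
    """Element-wise closeness check; shapes must match exactly."""
    if len(x) != len(y) or any(len(rx) != len(ry) for rx, ry in zip(x, y)):
        return False
    return all(
        math.isclose(u, v, rel_tol=rtol, abs_tol=atol)
        for rx, ry in zip(x, y)
        for u, v in zip(rx, ry)
    )


# Two ranks, each holding a 2x2 column shard of a 2x4 matrix A.
a_shards = [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
b = [[1, 0], [0, 1], [1, 0], [0, 1]]

reference = matmul(all_gather(a_shards), b)

# Stand-ins for kernel outputs: one with small rounding error, one clearly wrong.
good_output = [[4.001, 6.0], [12.0, 14.0]]
bad_output = [[5.0, 6.0], [12.0, 14.0]]

print("good passes:", allclose(good_output, reference))
print("bad passes:", allclose(bad_output, reference))
```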
To run the full performance benchmark sweep and save the results as .json files into the results/ directory, use the -b or --benchmark flag.
Pull Model:
python benchmark/examples/benchmark_all_gather_gemm_pull.py --num_ranks 8 -b
Push Model:
python benchmark/examples/benchmark_all_gather_gemm_push.py --num_ranks 8 -b
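The general shape of a benchmark run (warm up, time several iterations, write a JSON record) can be sketched as follows. This is a host-side illustration only: the field names and helpers are hypothetical rather than the scripts' actual output format, and timing real GPU kernels additionally requires device synchronization or event-based timers, which this sketch omits.

```python
import json
import os
import statistics
import tempfile
import time


def benchmark(fn, warmup=2, iters=5):
    """Run fn a few times untimed, then collect wall-clock samples in milliseconds."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return {"mean_ms": statistics.mean(samples), "min_ms": min(samples), "iters": iters}


def save_result(result, out_dir, name):
    """Write one result record as <out_dir>/<name>.json and return its path."""
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, name + ".json")
    with open(path, "w") as f:
        json.dump(result, f, indent=2)
    return path


# A temporary directory stands in for results/ so the sketch leaves nothing behind.
with tempfile.TemporaryDirectory() as out_dir:
    stats = benchmark(lambda: sum(range(10_000)))
    path = save_result({"num_ranks": 8, **stats}, out_dir, "example")
    with open(path) as f:
        loaded = json.load(f)

print(loaded)
```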
To validate and benchmark the RCCL + torch.matmul baseline, run its script with the same -v and -b flags used for the Pull and Push versions. For example, to benchmark:
python examples/benchmark/reference/all_gather_gemm/benchmark_rccl_torch.py --num_ranks 8 -b