diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py b/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
index 28a5bb87a2..f0d5490363 100644
--- a/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
+++ b/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
@@ -42,8 +42,10 @@ class ExperimentConfig:

 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_ms: float
-    mxfp8_ms: float
+    fwd_bf16_ms: float
+    fwd_mxfp8_ms: float
+    bwd_bf16_ms: float
+    bwd_mxfp8_ms: float


 @dataclass(frozen=True)
@@ -55,6 +57,10 @@ class Experiment:
 def get_configs() -> List[ExperimentConfig]:
     # (batch_size, seq_len, dim)
     input_shapes = [
+        (1, 8192, 5120),
+        (2, 8192, 5120),
+        (4, 8192, 5120),
+        (8, 8192, 5120),
         (16, 8192, 5120),
     ]
     configs = []
@@ -67,9 +73,8 @@ def get_configs() -> List[ExperimentConfig]:
     return configs


-def default_a2a_fwd_bwd(
+def default_a2a_fwd(
     routed_input: torch.Tensor,
-    labels: torch.Tensor,
     output_splits_list: list[int],
     input_splits_list: list[int],
     device_mesh: DeviceMesh,
@@ -81,17 +86,12 @@ def default_a2a_fwd_bwd(
         device_mesh.get_group(),
     )
     routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
-
-    loss = F.mse_loss(routed_input, labels)
-    loss.backward()
-    torch.cuda.synchronize()
     return routed_input


-def mxfp8_a2a_fwd_bwd(
+def mxfp8_a2a_fwd(
     routed_input: torch.Tensor,
-    labels: torch.Tensor,
     output_splits_list: list[int],
     input_splits_list: list[int],
     device_mesh: DeviceMesh,
@@ -102,7 +102,14 @@
         input_splits_list,
         device_mesh.get_group(),
     )
+    torch.cuda.synchronize()
+    return routed_input
+
+
+def mse_loss_and_bwd(
+    routed_input: torch.Tensor,
+    labels: torch.Tensor,
+):
     loss = F.mse_loss(routed_input, labels)
     loss.backward()
     torch.cuda.synchronize()
@@ -110,8 +117,7 @@


 # Compile target funcs
-default_a2a_sync_compiled = torch.compile(default_a2a_fwd_bwd)
-mxfp8_a2a_sync_compiled = torch.compile(mxfp8_a2a_fwd_bwd)
+mse_loss_and_bwd_compiled = torch.compile(mse_loss_and_bwd)


 def run_experiment(
@@ -129,73 +135,94 @@
     # Set up device mesh
     mesh = init_device_mesh("cuda", (dist.get_world_size(),))

-    # Max output tokens per rank is worst case where one rank receives all tokens
-    input_tokens_per_rank = batch_size * seq_len
-
     def warmup(func_no_args):
         for _ in range(2):
             func_no_args()

+    input_tokens_per_rank = batch_size * seq_len
     num_experts_per_rank = 2
-    num_splits = dist.get_world_size() * num_experts_per_rank
-    input_splits = generate_split_sizes(
-        num_splits, input_tokens_per_rank, device=device
-    )
+    num_experts = dist.get_world_size() * num_experts_per_rank
+    input_tokens_per_expert = input_tokens_per_rank // num_experts
+    input_splits = torch.tensor(
+        input_tokens_per_expert, dtype=torch.int32, device=device
+    ).repeat(num_experts)
     input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)

     # Generate labels
     labels_shape = (sum(output_splits_list), dim)
     labels = x.new_ones(*labels_shape)

-    # Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
-    warmup(
-        lambda: default_a2a_sync_compiled(
-            ref_x, labels, output_splits_list, input_splits_list, mesh
-        )
-    )
+    # Bench default a2a fwd (exclude d2h sync from preparing input splits_list and output_splits_list)
+    warmup(lambda: default_a2a_fwd(ref_x, output_splits_list, input_splits_list, mesh))
     start_sec = time.perf_counter()
-    default_a2a_sync_compiled(
-        ref_x, labels, output_splits_list, input_splits_list, mesh
+    bf16_routed_input = default_a2a_fwd(
+        ref_x, output_splits_list, input_splits_list, mesh
     )
     end_sec = time.perf_counter()
-    bf16_ms = (end_sec - start_sec) * 1e3
+    fwd_bf16_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            default_a2a_sync_compiled,
+            default_a2a_fwd,
             ref_x,
-            labels,
             output_splits_list,
             input_splits_list,
             mesh,
             distributed=True,
-            profile_name="all_to_all_single_autograd",
+            profile_name="default_a2a_fwd",
         )

-    # Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
-    warmup(
-        lambda: mxfp8_a2a_sync_compiled(
-            x, labels, output_splits_list, input_splits_list, mesh
+    # Bench default a2a backward
+    warmup(lambda: mse_loss_and_bwd_compiled(bf16_routed_input, labels))
+    start_sec = time.perf_counter()
+    mse_loss_and_bwd_compiled(bf16_routed_input, labels)
+    end_sec = time.perf_counter()
+    bwd_bf16_ms = (end_sec - start_sec) * 1e3
+    if args.profile:
+        profile_fn(
+            mse_loss_and_bwd_compiled,
+            bf16_routed_input,
+            labels,
+            distributed=True,
+            profile_name="bf16_a2a_bwd",
         )
-    )
+
+    # Bench mxfp8 sync a2a fwd (exclude d2h sync from preparing input splits_list and output_splits_list)
+    warmup(lambda: mxfp8_a2a_fwd(x, output_splits_list, input_splits_list, mesh))
     start_sec = time.perf_counter()
-    mxfp8_a2a_sync_compiled(x, labels, output_splits_list, input_splits_list, mesh)
+    mxfp8_routed_input = mxfp8_a2a_fwd(x, output_splits_list, input_splits_list, mesh)
     end_sec = time.perf_counter()
-    mxfp8_ms = (end_sec - start_sec) * 1e3
+    fwd_mxfp8_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            mxfp8_a2a_sync_compiled,
+            mxfp8_a2a_fwd,
             x,
-            labels,
             output_splits_list,
             input_splits_list,
             mesh,
             distributed=True,
-            profile_name="to_mxfp8_a2a_dequant",
+            profile_name="mxfp8_a2a_fwd",
+        )
+
+    # Bench mxfp8 sync a2a backward
+    warmup(lambda: mse_loss_and_bwd_compiled(mxfp8_routed_input, labels))
+    start_sec = time.perf_counter()
+    mse_loss_and_bwd_compiled(mxfp8_routed_input, labels)
+    end_sec = time.perf_counter()
+    bwd_mxfp8_ms = (end_sec - start_sec) * 1e3
+    if args.profile:
+        profile_fn(
+            mse_loss_and_bwd_compiled,
+            mxfp8_routed_input,
+            labels,
+            distributed=True,
+            profile_name="mxfp8_a2a_bwd",
         )

     return ExperimentResult(
-        bf16_ms=bf16_ms,
-        mxfp8_ms=mxfp8_ms,
+        fwd_bf16_ms=fwd_bf16_ms,
+        fwd_mxfp8_ms=fwd_mxfp8_ms,
+        bwd_bf16_ms=bwd_bf16_ms,
+        bwd_mxfp8_ms=bwd_mxfp8_ms,
     )
@@ -203,8 +230,10 @@ def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape",
         "num_splits",
-        "bf16_ms",
-        "mxfp8_ms",
+        "fwd_bf16_ms",
+        "fwd_mxfp8_ms",
+        "bwd_bf16_ms",
+        "bwd_mxfp8_ms",
     ]
     rows = []
     num_splits = dist.get_world_size()
@@ -213,8 +242,10 @@ def print_results(experiments: List[Experiment]):
             [
                 str(experiment.config.input_shape),
                 num_splits,
-                experiment.result.bf16_ms,
-                experiment.result.mxfp8_ms,
+                experiment.result.fwd_bf16_ms,
+                experiment.result.fwd_mxfp8_ms,
+                experiment.result.bwd_bf16_ms,
+                experiment.result.bwd_mxfp8_ms,
             ]
         )
     print(tabulate(rows, headers=headers))
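
For reference, the new `input_splits` construction in `run_experiment` can be sanity-checked in isolation. The sketch below mirrors the code added in this diff, with a hypothetical `world_size = 8` standing in for `dist.get_world_size()` and one of the new `(batch_size, seq_len, dim)` shapes from `get_configs()`; it runs on CPU with no distributed setup:

```python
import torch

# Hypothetical stand-ins: world_size replaces dist.get_world_size(),
# and (batch_size, seq_len) is the (2, 8192, 5120) config's first two dims.
world_size = 8
batch_size, seq_len = 2, 8192

input_tokens_per_rank = batch_size * seq_len                     # 16384 tokens on this rank
num_experts_per_rank = 2
num_experts = world_size * num_experts_per_rank                  # 16 experts across the mesh
input_tokens_per_expert = input_tokens_per_rank // num_experts   # 1024 tokens per expert

# A 0-dim tensor repeated num_experts times yields one equal split per expert,
# replacing the random partition that generate_split_sizes() produced before.
input_splits = torch.tensor(input_tokens_per_expert, dtype=torch.int32).repeat(
    num_experts
)

assert input_splits.shape == (num_experts,)
assert int(input_splits.sum()) == input_tokens_per_rank  # splits cover every token
print(input_splits)  # tensor([1024, 1024, ..., 1024], dtype=torch.int32)
```

One consequence worth noting: the old `generate_split_sizes` path could exercise skewed splits (the deleted comment mentions the worst case where one rank receives all tokens), whereas this construction only times the perfectly balanced case.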