Commit aee4a74

[Dev] remove fp16 assert in moe_grouped_gemm & EP (#2494)
1 parent a6d86a6 commit aee4a74

3 files changed, 85 insertions(+), 3 deletions(-)

megatron/core/transformer/moe/experts.py

Lines changed: 1 addition & 0 deletions
@@ -238,6 +238,7 @@ def forward(
         permuted_probs: torch.Tensor,
     ):
         """Forward step of the GroupedMLP."""
+        assert self.config.bf16, "Currently GroupedGEMM for MoE only supports bf16."
         if self.activation_recompute:
             self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput()

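For context, this addition moves the dtype requirement from command-line validation into the module that actually needs it. A minimal sketch of the resulting behavior, using a hypothetical stand-in class rather than Megatron's real GroupedMLP:

import torch


class GroupedMLPSketch(torch.nn.Module):
    """Hypothetical stand-in for GroupedMLP, showing where the check now runs."""

    def __init__(self, config):
        super().__init__()
        self.config = config  # assumed to expose a boolean `bf16` field

    def forward(self, permuted_hidden_states, tokens_per_expert, permuted_probs):
        # The bf16 requirement fires only when GroupedGEMM is actually used,
        # instead of being rejected up front in validate_args.
        assert self.config.bf16, "Currently GroupedGEMM for MoE only supports bf16."
        raise NotImplementedError("grouped GEMM computation elided in this sketch")
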

megatron/training/arguments.py

Lines changed: 0 additions & 3 deletions
@@ -900,7 +900,6 @@ def validate_args(args, defaults={}):
             'residual connection in fp32 only supported when using fp16 or bf16.'

     if args.moe_grouped_gemm:
-        assert args.bf16, 'Currently GroupedGEMM for MoE only supports bf16 dtype.'
         dc = torch.cuda.get_device_capability()
         assert dc[0] >= 8, "Unsupported compute capability for GroupedGEMM kernels."

@@ -1084,8 +1083,6 @@ def validate_args(args, defaults={}):
         assert args.num_experts is not None, "num_experts must be non None to use expert model parallelism"
         assert args.num_experts % args.expert_model_parallel_size == 0, \
             "Number of experts should be a multiple of expert model parallel_size."
-        assert not args.fp16, \
-            "Expert parallelism is not supported with fp16 training."

     # MoE router check
     if isinstance(args.moe_router_load_balancing_type, list) and len(args.moe_router_load_balancing_type) == 1:
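To spell out the effect of the two deletions: argument validation no longer requires --bf16 when --moe-grouped-gemm is set, and no longer forbids --fp16 with expert parallelism. A hypothetical before/after sketch (not Megatron's actual code; `args` stands for the parsed argparse namespace, and the surrounding conditions are reconstructed from the diff context):

def moe_dtype_checks_before(args):
    # Behavior prior to this commit, reconstructed from the removed lines above.
    if args.moe_grouped_gemm:
        assert args.bf16, 'Currently GroupedGEMM for MoE only supports bf16 dtype.'
    if args.expert_model_parallel_size > 1:
        assert not args.fp16, "Expert parallelism is not supported with fp16 training."


def moe_dtype_checks_after(args):
    # After this commit both checks are gone from validate_args; the GroupedGEMM
    # bf16 requirement is asserted later, inside GroupedMLP.forward.
    pass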

tests/unit_tests/transformer/moe/test_moe_layer.py

Lines changed: 84 additions & 0 deletions
@@ -192,3 +192,87 @@ def test_interleave_transformer_block(self, moe_layer_freq):

     def teardown_method(self, method):
         Utils.destroy_model_parallel()
+
+
+class TestMoELayerFP16:
+    """Test MoE layer with FP16 precision."""
+
+    def setup_method(self, method):
+        pass
+
+    @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"])
+    @pytest.mark.parametrize("num_moe_experts", [2, 4])
+    @pytest.mark.parametrize("tp_size,ep_size", [(1, 1), (2, 2), (4, 2)])
+    def test_moe_layer_fp16_forward_backward(
+        self, num_moe_experts, moe_token_dispatcher_type, tp_size, ep_size
+    ):
+        """Test MoE layer forward and backward pass with fp16 params and inputs."""
+        Utils.initialize_model_parallel(
+            tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size
+        )
+        _set_random_seed(seed_=123, data_parallel_random_init=False)
+
+        hidden_size = 64
+        sequence_length = 32
+        micro_batch_size = 2
+
+        transformer_config = TransformerConfig(
+            num_layers=1,
+            hidden_size=hidden_size,
+            num_attention_heads=4,
+            num_moe_experts=num_moe_experts,
+            use_cpu_initialization=False,
+            moe_token_dispatcher_type=moe_token_dispatcher_type,
+            moe_router_load_balancing_type="aux_loss",
+            moe_router_topk=2,
+            moe_aux_loss_coeff=0.01,
+            moe_grouped_gemm=False,  # Use SequentialMLP for fp16 test
+            moe_ffn_hidden_size=256,
+            add_bias_linear=False,
+            tensor_model_parallel_size=tp_size,
+            expert_model_parallel_size=ep_size,
+            sequence_parallel=tp_size > 1,
+            fp16=True,
+            params_dtype=torch.float16,
+        )
+
+        transformer_layer_spec = get_gpt_layer_local_spec(
+            num_experts=num_moe_experts, moe_grouped_gemm=False
+        )
+
+        moe_layer = MoELayer(
+            transformer_config, transformer_layer_spec.submodules.mlp.submodules
+        ).cuda()
+
+        hidden_states = torch.randn(
+            sequence_length,
+            micro_batch_size,
+            hidden_size,
+            device=torch.cuda.current_device(),
+            dtype=torch.float16,
+            requires_grad=True,
+        )
+
+        # Forward pass
+        output, _ = moe_layer(hidden_states)
+
+        assert output.dtype == torch.float16, f"Expected fp16 output, got {output.dtype}"
+        assert output.shape == hidden_states.shape, f"Output shape mismatch"
+
+        # Backward pass
+        loss = output.sum()
+        loss.backward()
+
+        assert hidden_states.grad is not None, "Input gradients should exist"
+        assert (
+            hidden_states.grad.dtype == torch.float16
+        ), f"Expected fp16 gradients, got {hidden_states.grad.dtype}"
+
+        for name, param in moe_layer.named_parameters():
+            if param.requires_grad:
+                assert param.grad is not None, f"Gradient for {name} should exist"
+
+        Utils.destroy_model_parallel()
+
+    def teardown_method(self, method):
+        Utils.destroy_model_parallel()
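For reference, the property the new test exercises can be reproduced outside Megatron with a plain fp16 torch MLP; the model, shapes, and CUDA assumption below are illustrative, not the commit's code:

import torch

# Sketch only: with fp16 parameters and fp16 inputs, the forward output and the
# backward gradients should all stay in torch.float16 (assumes a CUDA device).
mlp = torch.nn.Sequential(
    torch.nn.Linear(64, 256, bias=False),
    torch.nn.GELU(),
    torch.nn.Linear(256, 64, bias=False),
).cuda().half()

x = torch.randn(32, 2, 64, device="cuda", dtype=torch.float16, requires_grad=True)
out = mlp(x)
out.sum().backward()

assert out.dtype == torch.float16
assert x.grad is not None and x.grad.dtype == torch.float16
assert all(p.grad is not None and p.grad.dtype == torch.float16 for p in mlp.parameters())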
