@@ -192,3 +192,87 @@ def test_interleave_transformer_block(self, moe_layer_freq):
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
+
+
+class TestMoELayerFP16:
+    """Test MoE layer with FP16 precision."""
+
+    def setup_method(self, method):
+        pass
+
+    @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"])
+    @pytest.mark.parametrize("num_moe_experts", [2, 4])
+    @pytest.mark.parametrize("tp_size,ep_size", [(1, 1), (2, 2), (4, 2)])
+    def test_moe_layer_fp16_forward_backward(
+        self, num_moe_experts, moe_token_dispatcher_type, tp_size, ep_size
+    ):
+        """Test MoE layer forward and backward pass with fp16 params and inputs."""
+        Utils.initialize_model_parallel(
+            tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size
+        )
+        _set_random_seed(seed_=123, data_parallel_random_init=False)
+
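+        # Small shapes keep the test cheap while still exercising the TP/EP splits above.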
+        hidden_size = 64
+        sequence_length = 32
+        micro_batch_size = 2
+
+        transformer_config = TransformerConfig(
+            num_layers=1,
+            hidden_size=hidden_size,
+            num_attention_heads=4,
+            num_moe_experts=num_moe_experts,
+            use_cpu_initialization=False,
+            moe_token_dispatcher_type=moe_token_dispatcher_type,
+            moe_router_load_balancing_type="aux_loss",
+            moe_router_topk=2,
+            moe_aux_loss_coeff=0.01,
+            moe_grouped_gemm=False,  # Use SequentialMLP for the fp16 test
+            moe_ffn_hidden_size=256,
+            add_bias_linear=False,
+            tensor_model_parallel_size=tp_size,
+            expert_model_parallel_size=ep_size,
+            sequence_parallel=tp_size > 1,
+            fp16=True,
+            params_dtype=torch.float16,
+        )
+
+        transformer_layer_spec = get_gpt_layer_local_spec(
+            num_experts=num_moe_experts, moe_grouped_gemm=False
+        )
+
+        moe_layer = MoELayer(
+            transformer_config, transformer_layer_spec.submodules.mlp.submodules
+        ).cuda()
+
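+        # fp16 input in [sequence_length, micro_batch_size, hidden_size] layout; a leaf
+        # tensor with requires_grad=True so hidden_states.grad is populated by backward().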
+        hidden_states = torch.randn(
+            sequence_length,
+            micro_batch_size,
+            hidden_size,
+            device=torch.cuda.current_device(),
+            dtype=torch.float16,
+            requires_grad=True,
+        )
+
+        # Forward pass
+        output, _ = moe_layer(hidden_states)
+
+        assert output.dtype == torch.float16, f"Expected fp16 output, got {output.dtype}"
+        assert output.shape == hidden_states.shape, "Output shape mismatch"
+
+        # Backward pass
+        loss = output.sum()
+        loss.backward()
+
+        assert hidden_states.grad is not None, "Input gradients should exist"
+        assert (
+            hidden_states.grad.dtype == torch.float16
+        ), f"Expected fp16 gradients, got {hidden_states.grad.dtype}"
+
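+        # Every trainable parameter in the MoE layer should have received a gradient.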
+        for name, param in moe_layer.named_parameters():
+            if param.requires_grad:
+                assert param.grad is not None, f"Gradient for {name} should exist"
+
+        Utils.destroy_model_parallel()
+
+    def teardown_method(self, method):
+        Utils.destroy_model_parallel()