@@ -470,54 +470,59 @@ def forward(
         )
         permuted_probs = permuted_probs.unsqueeze(-1)
 
+        # Helper to get local weights. In production, all weights are DTensor after parallelization.
+        # The isinstance check is for unit tests where parallelizer is mocked and weights remain Parameters.
+        def to_local(proj):
+            return proj.to_local() if isinstance(proj, DTensor) else proj
+
         if torch.count_nonzero(tokens_per_expert) > 0:
             # 1. Gate + Up Projection
             output1 = ops.gmm(
                 permuted_local_hidden_states,
-                self.gate_and_up_projs.to_local(),
+                to_local(self.gate_and_up_projs),
                 tokens_per_expert,
                 trans_b=False,
             )
 
             # Add LoRA
             lora_out1_A = ops.gmm(
                 permuted_local_hidden_states,
-                self.lora_gate_and_up_A.to_local(),
+                to_local(self.lora_gate_and_up_A),
                 tokens_per_expert,
                 trans_b=False,
             )
             # [T, R] @ [E_local, R, H] -> [T, H]
             lora_out1 = ops.gmm(
-                lora_out1_A, self.lora_gate_and_up_B.to_local(), tokens_per_expert, trans_b=False
+                lora_out1_A, to_local(self.lora_gate_and_up_B), tokens_per_expert, trans_b=False
             )
 
             output1 = output1 + lora_out1 * self.scale
 
             if self.expert_bias:
-                gate_and_up_bias = self.gate_up_proj_bias.to_local()
+                gate_and_up_bias = to_local(self.gate_up_proj_bias)
                 output1 = self._apply_bias(output1, gate_and_up_bias, tokens_per_expert)
 
             output1 = self.expert_activation(output1, permuted_probs)
 
             # 2. Down Projection
-            output2 = ops.gmm(output1, self.down_projs.to_local(), tokens_per_expert, trans_b=False)
+            output2 = ops.gmm(output1, to_local(self.down_projs), tokens_per_expert, trans_b=False)
 
             # Add LoRA
-            lora_out2_A = ops.gmm(output1, self.lora_down_A.to_local(), tokens_per_expert, trans_b=False)
-            lora_out2 = ops.gmm(lora_out2_A, self.lora_down_B.to_local(), tokens_per_expert, trans_b=False)
+            lora_out2_A = ops.gmm(output1, to_local(self.lora_down_A), tokens_per_expert, trans_b=False)
+            lora_out2 = ops.gmm(lora_out2_A, to_local(self.lora_down_B), tokens_per_expert, trans_b=False)
             output2 = output2 + lora_out2 * self.scale
 
             if self.expert_bias:
-                down_bias = self.down_proj_bias.to_local()
+                down_bias = to_local(self.down_proj_bias)
                 output2 = self._apply_bias(output2, down_bias, tokens_per_expert, permuted_probs)
         else:
             # Handle empty case for DeepEP - use [0] indexing to match base shapes exactly
-            W1 = self.gate_and_up_projs.to_local()[0]  # [dim, 2*inter]
-            W2 = self.down_projs.to_local()[0]  # [inter, dim]
-            A1 = self.lora_gate_and_up_A.to_local()[0]  # [dim, r]
-            B1 = self.lora_gate_and_up_B.to_local()[0]  # [r, 2*inter]
-            A2 = self.lora_down_A.to_local()[0]  # [inter, r]
-            B2 = self.lora_down_B.to_local()[0]  # [r, dim]
+            W1 = to_local(self.gate_and_up_projs)[0]  # [dim, 2*inter]
+            W2 = to_local(self.down_projs)[0]  # [inter, dim]
+            A1 = to_local(self.lora_gate_and_up_A)[0]  # [dim, r]
+            B1 = to_local(self.lora_gate_and_up_B)[0]  # [r, 2*inter]
+            A2 = to_local(self.lora_down_A)[0]  # [inter, r]
+            B2 = to_local(self.lora_down_B)[0]  # [r, dim]
 
             dummy_x = x[0] * 0  # [dim]
 
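For context on the pattern the new helper introduces, below is a minimal standalone sketch of DTensor unwrapping, assuming the public torch.distributed.tensor.DTensor import (older PyTorch builds expose it under torch.distributed._tensor); the weight and helper here are placeholders for illustration, not the projections from this file.

import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor

def to_local(proj):
    # After parallelization the expert weights are DTensor shards; unwrap to the
    # rank-local tensor so grouped GEMM receives a regular torch.Tensor. Plain
    # Parameters (e.g. unit tests with a mocked parallelizer) pass through unchanged.
    return proj.to_local() if isinstance(proj, DTensor) else proj

# Plain-Parameter path, the one mocked unit tests exercise: the helper is a no-op.
w = nn.Parameter(torch.randn(4, 8))
assert to_local(w) is w

# DTensor path: once the weight has been sharded over an initialized device mesh
# (e.g. via distribute_tensor), to_local(w) would instead return the local shard.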