Commit e91a736

fix unit test: Add isinstance check for to_local() to support mocked unit tests

Signed-off-by: Yuhe Zhang <[email protected]>
1 parent: b7f9358
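
For context: the forward pass touched by this commit applies LoRA to Mixture-of-Experts weights through grouped matrix multiplications (ops.gmm), one weight slab per local expert, computing output = base + (x @ A @ B) * scale. Below is a minimal, loop-based sketch of that pattern; gmm_ref, the shapes (2 experts, dim=8, inter=16, rank 4), and the sample tensors are illustrative assumptions, not the library's actual ops.gmm API.

import torch

# gmm_ref is a hypothetical reference implementation of a grouped matmul:
# tokens arrive pre-sorted by expert, and tokens_per_expert[e] consecutive
# tokens are multiplied by expert e's weight slab w[e].
def gmm_ref(x, w, tokens_per_expert):
    out, start = [], 0
    for e, n in enumerate(tokens_per_expert.tolist()):
        out.append(x[start:start + n] @ w[e])
        start += n
    return torch.cat(out, dim=0)

# Made-up shapes: 2 local experts, dim=8, inter=16, LoRA rank r=4.
tokens_per_expert = torch.tensor([3, 5])
x = torch.randn(8, 8)            # [T, dim], T = 3 + 5 tokens, sorted by expert
W1 = torch.randn(2, 8, 32)       # gate_and_up_projs  [E, dim, 2*inter]
A1 = torch.randn(2, 8, 4)        # lora_gate_and_up_A [E, dim, r]
B1 = torch.randn(2, 4, 32)       # lora_gate_and_up_B [E, r, 2*inter]
scale = 2.0

base = gmm_ref(x, W1, tokens_per_expert)                                  # [T, 2*inter]
lora = gmm_ref(gmm_ref(x, A1, tokens_per_expert), B1, tokens_per_expert)  # [T, 2*inter]
output1 = base + lora * scale    # mirrors output1 + lora_out1 * self.scale in the diff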

File tree: 1 file changed, +19 −14 lines

nemo_automodel/components/_peft/lora_moe.py (19 additions, 14 deletions)
@@ -470,54 +470,59 @@ def forward(
         )
         permuted_probs = permuted_probs.unsqueeze(-1)
 
+        # Helper to get local weights. In production, all weights are DTensor after parallelization.
+        # The isinstance check is for unit tests where parallelizer is mocked and weights remain Parameters.
+        def to_local(proj):
+            return proj.to_local() if isinstance(proj, DTensor) else proj
+
         if torch.count_nonzero(tokens_per_expert) > 0:
             # 1. Gate + Up Projection
             output1 = ops.gmm(
                 permuted_local_hidden_states,
-                self.gate_and_up_projs.to_local(),
+                to_local(self.gate_and_up_projs),
                 tokens_per_expert,
                 trans_b=False,
             )
 
             # Add LoRA
             lora_out1_A = ops.gmm(
                 permuted_local_hidden_states,
-                self.lora_gate_and_up_A.to_local(),
+                to_local(self.lora_gate_and_up_A),
                 tokens_per_expert,
                 trans_b=False,
             )
             # [T, R] @ [E_local, R, H] -> [T, H]
             lora_out1 = ops.gmm(
-                lora_out1_A, self.lora_gate_and_up_B.to_local(), tokens_per_expert, trans_b=False
+                lora_out1_A, to_local(self.lora_gate_and_up_B), tokens_per_expert, trans_b=False
             )
 
             output1 = output1 + lora_out1 * self.scale
 
             if self.expert_bias:
-                gate_and_up_bias = self.gate_up_proj_bias.to_local()
+                gate_and_up_bias = to_local(self.gate_up_proj_bias)
                 output1 = self._apply_bias(output1, gate_and_up_bias, tokens_per_expert)
 
             output1 = self.expert_activation(output1, permuted_probs)
 
             # 2. Down Projection
-            output2 = ops.gmm(output1, self.down_projs.to_local(), tokens_per_expert, trans_b=False)
+            output2 = ops.gmm(output1, to_local(self.down_projs), tokens_per_expert, trans_b=False)
 
             # Add LoRA
-            lora_out2_A = ops.gmm(output1, self.lora_down_A.to_local(), tokens_per_expert, trans_b=False)
-            lora_out2 = ops.gmm(lora_out2_A, self.lora_down_B.to_local(), tokens_per_expert, trans_b=False)
+            lora_out2_A = ops.gmm(output1, to_local(self.lora_down_A), tokens_per_expert, trans_b=False)
+            lora_out2 = ops.gmm(lora_out2_A, to_local(self.lora_down_B), tokens_per_expert, trans_b=False)
             output2 = output2 + lora_out2 * self.scale
 
             if self.expert_bias:
-                down_bias = self.down_proj_bias.to_local()
+                down_bias = to_local(self.down_proj_bias)
                 output2 = self._apply_bias(output2, down_bias, tokens_per_expert, permuted_probs)
         else:
             # Handle empty case for DeepEP - use [0] indexing to match base shapes exactly
-            W1 = self.gate_and_up_projs.to_local()[0]  # [dim, 2*inter]
-            W2 = self.down_projs.to_local()[0]  # [inter, dim]
-            A1 = self.lora_gate_and_up_A.to_local()[0]  # [dim, r]
-            B1 = self.lora_gate_and_up_B.to_local()[0]  # [r, 2*inter]
-            A2 = self.lora_down_A.to_local()[0]  # [inter, r]
-            B2 = self.lora_down_B.to_local()[0]  # [r, dim]
+            W1 = to_local(self.gate_and_up_projs)[0]  # [dim, 2*inter]
+            W2 = to_local(self.down_projs)[0]  # [inter, dim]
+            A1 = to_local(self.lora_gate_and_up_A)[0]  # [dim, r]
+            B1 = to_local(self.lora_gate_and_up_B)[0]  # [r, 2*inter]
+            A2 = to_local(self.lora_down_A)[0]  # [inter, r]
+            B2 = to_local(self.lora_down_B)[0]  # [r, dim]
 
             dummy_x = x[0] * 0  # [dim]
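
Why the fallback matters for tests: a plain torch Tensor or nn.Parameter has no .to_local() method, so when the parallelizer is mocked in unit tests and the weights never become DTensor, the old unconditional .to_local() calls raised AttributeError. A minimal sketch of a test exercising the fallback path follows (the test name is made up; the public DTensor import location may differ on older PyTorch, where it lived under torch.distributed._tensor):

import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor

# Same fallback the commit adds: unwrap DTensor weights after
# parallelization, pass plain Parameters through untouched.
def to_local(proj):
    return proj.to_local() if isinstance(proj, DTensor) else proj

def test_to_local_passes_plain_parameters_through():
    # With a mocked parallelizer, weights stay plain nn.Parameter.
    weight = nn.Parameter(torch.randn(4, 8))
    assert to_local(weight) is weight  # no .to_local() call, no AttributeError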
