Commit f17c075

[Model] Switch to Fused RMSNorm in GLM-4.1V model (vllm-project#24733)
Signed-off-by: SamitHuang <[email protected]>
Parent: b0d1213

1 file changed: +3 −2 lines
vllm/model_executor/models/glm4_1v.py

Lines changed: 3 additions & 2 deletions

@@ -419,15 +419,16 @@ def forward(
         max_seqlen: Optional[int] = None,  # Only used for Flash Attention
         seqlens: Optional[list[int]] = None,  # Only used for xFormers
     ) -> torch.Tensor:
-        x = x + self.attn(
+        x_attn = self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             max_seqlen=max_seqlen,
             seqlens=seqlens,
         )
+        x_fused_norm, residual = self.norm2(x, residual=x_attn)
+        x = residual + self.mlp(x_fused_norm)

-        x = x + self.mlp(self.norm2(x))
         return x
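The diff works because vLLM's RMSNorm accepts an optional residual: when one is passed, the kernel adds it to the input and normalizes in a single fused pass, returning both the normalized tensor and the updated residual. The sketch below is a hedged, pure-PyTorch stand-in for that behavior (the real vLLM op is a fused GPU kernel; the `RMSNorm` and `mlp` here are illustrative, not the model's actual modules) and checks that the new block structure computes the same result as the old unfused path.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Reference RMSNorm mimicking vLLM's residual-fusion semantics."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight

    def forward(self, x: torch.Tensor, residual: torch.Tensor = None):
        if residual is None:
            return self._norm(x)
        x = x + residual            # residual add, fused with the norm in vLLM
        return self._norm(x), x     # (normalized output, new residual)


torch.manual_seed(0)
norm2 = RMSNorm(8)
mlp = nn.Linear(8, 8)           # stand-in for the block's MLP
x = torch.randn(2, 8)           # block input
x_attn = torch.randn(2, 8)      # stand-in for the attention output

# Old path: x = x + attn; x = x + mlp(norm2(x))
old = x + x_attn
old = old + mlp(norm2(old))

# New path, matching the diff: fused add + norm, then residual + MLP
x_fused_norm, residual = norm2(x, residual=x_attn)
new = residual + mlp(x_fused_norm)

assert torch.allclose(old, new, atol=1e-6)
```

Both paths compute `(x + attn) + mlp(norm2(x + attn))`; the fused form just avoids materializing the intermediate sum separately from the normalization.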
