We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b0d1213 · commit f17c075 — Copy full SHA for f17c075
vllm/model_executor/models/glm4_1v.py
@@ -419,15 +419,16 @@ def forward(
419
max_seqlen: Optional[int] = None, # Only used for Flash Attention
420
seqlens: Optional[list[int]] = None, # Only used for xFormers
421
) -> torch.Tensor:
422
- x = x + self.attn(
+ x_attn = self.attn(
423
self.norm1(x),
424
cu_seqlens=cu_seqlens,
425
rotary_pos_emb=rotary_pos_emb,
426
max_seqlen=max_seqlen,
427
seqlens=seqlens,
428
)
429
+ x_fused_norm, residual = self.norm2(x, residual=x_attn)
430
+ x = residual + self.mlp(x_fused_norm)
431
- x = x + self.mlp(self.norm2(x))
432
return x
433
434
0 commit comments