diff --git a/src/mistral_inference/transformer_layers.py b/src/mistral_inference/transformer_layers.py index 4ee23f56..1069ec39 100644 --- a/src/mistral_inference/transformer_layers.py +++ b/src/mistral_inference/transformer_layers.py @@ -162,7 +162,7 @@ def forward( cache: Optional[CacheView] = None, mask: Optional[BlockDiagonalMask] = None, ) -> torch.Tensor: - r = self.attention.forward(self.attention_norm(x), freqs_cis, cache) + r = self.attention.forward(x=self.attention_norm(x), freqs_cis=freqs_cis, cache=cache, mask=mask) h = x + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r