diff --git a/mistral/model.py b/mistral/model.py
index 553bdb60..11df6b2e 100644
--- a/mistral/model.py
+++ b/mistral/model.py
@@ -14,6 +14,7 @@
 from mistral.moe import MoeArgs, MoeLayer
 
 from xformers.ops.fmha import memory_efficient_attention
+from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 
 @dataclass
@@ -80,6 +81,7 @@ def forward(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         cache: Optional[CacheView],
+        mask: Optional[torch.Tensor]=None,
     ) -> torch.Tensor:
         seqlen_sum, _ = x.shape
 
@@ -109,9 +111,9 @@ def forward(
         # xformers requires (B=1, S, H, D)
         xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
 
-        output = memory_efficient_attention(
-            xq, key, val, None if cache is None else cache.mask
-        )
+        if mask is None and cache is not None:
+            mask = cache.mask
+        output = memory_efficient_attention(xq, key, val, mask)
 
         return self.wo(output.view(seqlen_sum, self.n_heads * self.head_dim))
 
@@ -163,9 +165,10 @@ def __init__(self, args: ModelArgs):
         self.feed_forward = FeedForward(args=args)
 
     def forward(
-        self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView]
+        self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView],
+        mask: Optional[torch.Tensor]=None,
     ) -> torch.Tensor:
-        r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
+        r = self.attention.forward(self.attention_norm(x), freqs_cis, cache, mask)
         h = x + r
         r = self.feed_forward.forward(self.ffn_norm(h))
         out = h + r
@@ -251,8 +254,10 @@ def forward_partial(
         assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
         if cache is not None:
             input_metadata = cache.get_input_metadata(seqlens)
+            mask = None
         else:
             input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
+            mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(max(seqlens))
 
         if self.pipeline_rank == 0:
             assert self.tok_embeddings is not None
@@ -271,7 +276,7 @@ def forward_partial(
                 cache_view = cache.get_view(local_layer_id, input_metadata)
             else:
                 cache_view = None
-            h = layer(h, freqs_cis, cache_view)
+            h = layer(h, freqs_cis, cache_view, mask)
 
         if cache is not None:
             cache.update_seqlens(seqlens)
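
Note on the no-cache branch (not part of the patch): the mask built in forward_partial is a block-diagonal causal mask with a local attention window, so packed sequences cannot attend to each other and each token only sees positions at or before it in its own sequence. The sketch below is illustrative only; the sequence lengths and window size are made-up values, and it materializes the mask on CPU purely to visualize which positions may attend to which. In model.py the mask object is instead passed to memory_efficient_attention as the attention bias.

# Illustrative sketch only: seqlens and window below are made-up values.
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

seqlens = [3, 5]   # two packed sequences, 8 tokens total
window = 2         # hypothetical sliding-window size

# Block-diagonal: a token never attends across sequence boundaries.
# Causal + local: a token attends to itself and at most `window - 1`
# earlier tokens of its own sequence.
mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(window)

# Materialize to a dense (sum(seqlens), sum(seqlens)) matrix for inspection:
# 0.0 where attention is allowed, -inf where it is masked out.
dense = mask.materialize((sum(seqlens), sum(seqlens)))
print((dense == 0).int())

Since the patch passes max(seqlens) as the window size, the window is at least as long as every sequence in the batch, so during a cache-less forward pass the mask should behave like a plain per-sequence causal mask; a smaller window (for example the model's sliding-window size) would additionally limit how far back each token can look.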