diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 4d39d131d1d..f6410a47ff2 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -7,6 +7,7 @@
 # Please refer to README.md in the same folder for more information.
 
+import math
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, Optional, Tuple
 
@@ -251,7 +252,10 @@ def forward(
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
 
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        scores = scores + attn_mask
+        scores = F.softmax(scores.float(), dim=-1).type_as(q)
+        y = torch.matmul(scores, v)
 
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
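
Note (not part of the patch): the change replaces the fused F.scaled_dot_product_attention call with its textbook decomposition, softmax(QK^T / sqrt(head_dim) + mask) V, expressed as plain matmul and softmax ops. A minimal standalone sketch to check that the two paths agree numerically; the shapes and the additive causal mask below are illustrative assumptions, not taken from the file:

    # Equivalence check: fused SDPA vs. the manual decomposition in the patch.
    # All names and shapes here are illustrative assumptions.
    import math

    import torch
    import torch.nn.functional as F

    bsz, n_heads, seqlen, head_dim = 2, 4, 8, 16
    q = torch.randn(bsz, n_heads, seqlen, head_dim)
    k = torch.randn(bsz, n_heads, seqlen, head_dim)
    v = torch.randn(bsz, n_heads, seqlen, head_dim)
    # Additive causal mask: 0 where attention is allowed, -inf above the diagonal.
    attn_mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)

    # The line the patch removes.
    y_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

    # The lines the patch adds.
    scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    scores = scores + attn_mask
    scores = F.softmax(scores.float(), dim=-1).type_as(q)
    y_manual = torch.matmul(scores, v)

    torch.testing.assert_close(y_fused, y_manual, rtol=1e-4, atol=1e-4)

The only intended behavioral difference is that the softmax is computed in float32 and cast back to the input dtype, which matters for half-precision inputs but is a no-op for float32.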