
Commit 41dc47b

[TEST ONLY]Drop scaled_dot_product_attention
[ghstack-poisoned]
1 parent 7493aae commit 41dc47b

File tree

1 file changed: +5 -1 lines changed
examples/models/llama/llama_transformer.py

Lines changed: 5 additions & 1 deletion
@@ -7,6 +7,7 @@

 # Please refer to README.md in the same folder for more information.

+import math
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, Optional, Tuple
@@ -251,7 +252,10 @@ def forward(

         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        scores = scores + attn_mask
+        scores = F.softmax(scores.float(), dim=-1).type_as(q)
+        y = torch.matmul(scores, v)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

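For context (not part of the commit), below is a minimal standalone sketch showing that the unfused matmul/softmax path added above produces the same result as the F.scaled_dot_product_attention call it removes. The tensor shapes and the additive causal attn_mask are illustrative assumptions, not values from the repository.

# Sketch only: checks the manual attention path against the fused SDPA call.
# Shapes and the additive causal mask are assumptions for illustration.
import math

import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# Additive causal mask: 0 where attention is allowed, -inf where it is not.
attn_mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)

# Manual path, mirroring the replacement lines in the diff.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
scores = scores + attn_mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)
y_manual = torch.matmul(scores, v)

# Fused path that the commit removes.
y_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

print(torch.allclose(y_manual, y_sdpa, atol=1e-5))  # expected: True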
0 commit comments
