Commit 999eb7e

refactor-attention

1 parent c87a56e commit 999eb7e

1 file changed: +6 −1 lines changed


examples/models/llama/attention.py

Lines changed: 6 additions & 1 deletion
@@ -162,7 +162,12 @@ def forward(
 
 @register_attention("mha")
 class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+    ):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
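
For context, a minimal usage sketch of the reformatted constructor. Only the AttentionMHA(args, layer_id, rope) signature is taken from this diff; the import paths and the way ModelArgs and Rope are constructed below are assumptions for illustration, not confirmed by this commit.

# Hypothetical sketch -- import paths and the ModelArgs/Rope construction are
# assumptions; only the AttentionMHA signature comes from this commit.
from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

args = ModelArgs()  # assumed: config dataclass with usable defaults (n_heads, use_kv_cache, ...)
rope = Rope(args)   # assumed: rotary-embedding helper built from the same args
attn = AttentionMHA(
    args=args,      # model hyperparameters
    layer_id=0,     # index of this transformer layer
    rope=rope,      # rotary position embedding helper
)

One parameter per line also keeps any future signature change to a single-line diff.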
