Commit 57fb2e2

Update
[ghstack-poisoned]
1 parent 8792a4d commit 57fb2e2

File tree

1 file changed: +3 -6 lines changed


examples/models/llama/llama_transformer.py

Lines changed: 3 additions & 6 deletions
@@ -365,8 +365,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,
-                # if we are using the custom op don't transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
@@ -495,10 +494,8 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
-        )
+    def forward(self, x, input_pos=None):  # x: 1xN
+        h = self.attention.forward(self.attention_norm(x), input_pos)
 
         h = x + h
         if hasattr(self, "block_sparse_moe"):
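
For context on the second hunk: TransformerBlock.forward now takes only (x, input_pos), so callers no longer thread freqs_cos / freqs_sin through each layer; the attention is expected to derive them itself. Below is a minimal, hypothetical sketch of that call path, assuming the attention pulls the rotary cos/sin tables from a Rope-style module indexed by input_pos. The Sketch* names are illustrative and not the actual ExecuTorch classes.

# Minimal sketch (assumption), not the ExecuTorch implementation.
import torch
import torch.nn as nn


class SketchRope(nn.Module):
    """Precomputes rotary cos/sin tables; callers fetch slices by position."""

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
        self.register_buffer("freqs_cos", torch.cos(angles))
        self.register_buffer("freqs_sin", torch.sin(angles))

    def get(self, input_pos, seq_len):
        # Slice the precomputed tables instead of passing them down the stack.
        start = 0 if input_pos is None else int(input_pos[0])
        return self.freqs_cos[start:start + seq_len], self.freqs_sin[start:start + seq_len]


class SketchAttention(nn.Module):
    def __init__(self, dim: int, rope: SketchRope):
        super().__init__()
        self.rope = rope
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x, input_pos=None):
        # Frequencies are looked up here rather than received as arguments.
        freqs_cos, freqs_sin = self.rope.get(input_pos, x.shape[1])
        # ... rotate q/k with freqs_cos/freqs_sin, update the KV cache, run SDPA ...
        return self.wo(x)


class SketchBlock(nn.Module):
    def __init__(self, dim: int, rope: SketchRope):
        super().__init__()
        self.attention = SketchAttention(dim, rope)
        self.attention_norm = nn.LayerNorm(dim)  # stand-in for the repo's RMSNorm

    def forward(self, x, input_pos=None):  # x: 1xN, mirroring the new signature
        h = self.attention.forward(self.attention_norm(x), input_pos)
        return x + h

For example, SketchBlock(64, SketchRope(16, 128))(torch.randn(1, 8, 64), input_pos=torch.tensor([0])) runs one decode-style step without any freqs arguments at the block level.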

0 commit comments
