
Commit f5fd12b

Move RoPE-related logic together
ghstack-source-id: f688718
ghstack-comment-id: 2442844675
Pull Request resolved: #6542
1 parent 80807fd commit f5fd12b
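
This commit consolidates the RoPE (rotary position embedding) precomputation and application, previously split between Transformer and Attention, into a single Rope module. For background, the rotation RoPE applies to each query/key pair can be sketched as below; this is an illustrative, self-contained example using the interleaved-pair convention, not the exact layout used by RotaryEmbedding or hf_apply_rotary_emb in this file.

import torch

def rotate_interleaved(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    # Treat the last dimension as head_dim // 2 (real, imag) pairs and rotate each
    # pair by its per-position, per-frequency angle:
    #   (a, b) -> (a*cos - b*sin, a*sin + b*cos)
    a, b = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack(
        (a * freqs_cos - b * freqs_sin, a * freqs_sin + b * freqs_cos), dim=-1
    )
    return rotated.flatten(-2)

# Toy check: bsz=1, seqlen=1, n_heads=1, head_dim=4; a zero angle is a no-op.
q = torch.randn(1, 1, 1, 4)
cos, sin = torch.ones(1, 2), torch.zeros(1, 2)
assert torch.allclose(rotate_interleaved(q, cos, sin), q)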

File tree

1 file changed (+74, -63 lines)

examples/models/llama/llama_transformer.py

Lines changed: 74 additions & 63 deletions
@@ -143,6 +143,69 @@ def __post_init__(self):
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)


+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.dim // self.params.n_heads,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seq_len: int,
+        input_pos: Optional[torch.LongTensor] = None,
+    ):
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, use of the .item results in
+                # symints, due to querying the data from tensor.
+                # this path avoids that for mps backend, although probably mps backend
+                # can support dynamic shape?
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        return q, k
+
+
 class KVCache(nn.Module):
     def __init__(
         self,
@@ -262,7 +325,7 @@ def forward(


 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -284,6 +347,8 @@ def __init__(self, args: ModelArgs, layer_id: int):

         self.layer_id = layer_id

+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -300,7 +365,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,
+                # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
@@ -311,16 +377,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 max_seq_len=self.max_seq_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
         self,
         x: torch.Tensor,
-        freqs_cos: torch.Tensor,
-        freqs_sin: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
@@ -333,7 +393,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, seqlen, input_pos)

         if self.use_kv_cache:
             assert input_pos is not None
@@ -421,13 +481,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -456,33 +516,17 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers

         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

     def forward(
         self,
@@ -498,42 +542,9 @@ def forward(
         )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]

         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h = layer(h, input_pos)

         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token

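After this change, a single Rope instance owns the freqs_cos/freqs_sin buffers and is shared by every TransformerBlock, so the frequency tables are built once in Transformer.__init__ and the layers no longer thread freqs_cos/freqs_sin through their forward calls. A rough standalone usage sketch of the new module follows; the config values and the import path are illustrative assumptions, and the remaining ModelArgs fields are assumed to keep their dataclass defaults.

import torch

# Illustrative import; the actual module path depends on how the repo is installed.
from examples.models.llama.llama_transformer import ModelArgs, Rope

# Hypothetical small config; real runs populate ModelArgs from a checkpoint's params.json.
params = ModelArgs(dim=256, n_heads=8, max_seq_len=128, use_kv_cache=False)
rope = Rope(params)

bsz, seqlen, head_dim = 1, 16, params.dim // params.n_heads
q = torch.randn(bsz, seqlen, params.n_heads, head_dim)
k = torch.randn(bsz, seqlen, params.n_heads, head_dim)

# With use_kv_cache=False, input_pos is omitted and the first seq_len rows of the
# precomputed freqs_cos/freqs_sin buffers are used (mirrors the call in Attention.forward).
q_rot, k_rot = rope.forward(q, k, seqlen)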