
Commit 2c54af6

move rope related logic together
ghstack-source-id: b209862
ghstack-comment-id: 2442844675
Pull Request resolved: #6542
1 parent 80807fd commit 2c54af6

File tree: 1 file changed, +101 -96 lines

examples/models/llama/llama_transformer.py (101 additions, 96 deletions)
@@ -143,16 +143,73 @@ def __post_init__(self):
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)


+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.dim // self.params.n_heads,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, input_pos: Optional[torch.LongTensor] = None):
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, use of the .item results in
+                # symints, due to querying the data from tensor.
+                # this path avoids that for mps backend, although probably mps backend
+                # can support dynamic shape?
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        return q, k
+
+
 class KVCache(nn.Module):
     def __init__(
-        self,
-        max_batch_size: int,
-        max_seq_length: int,
-        n_heads: int,
-        head_dim: int,
-        transpose_cache: bool,
-        enable_dynamic_shape: bool,
-        dtype=torch.float32,
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        transpose_cache: bool,
+        enable_dynamic_shape: bool,
+        dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
@@ -175,7 +232,7 @@ def __init__(
         )

     def update(
-        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
         if self.enable_dynamic_shape:
@@ -213,13 +270,13 @@ def update(

 class SDPA(nn.Module):
     def __init__(
-        self,
-        kv_cache: KVCache,
-        dim: int,
-        head_dim: int,
-        n_rep: int,
-        max_seq_len: int,
-        enable_dynamic_shape: bool,
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+        max_seq_len: int,
+        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.kv_cache = kv_cache
@@ -230,14 +287,14 @@ def __init__(
         self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
-        bsz,
-        seqlen,
-        mask: torch.Tensor,
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
+        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        bsz,
+        seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
@@ -262,7 +319,7 @@ def forward(


 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -284,6 +341,8 @@ def __init__(self, args: ModelArgs, layer_id: int):

         self.layer_id = layer_id

+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -300,7 +359,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,
+                # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
         self.SDPA = SDPA(
@@ -311,17 +371,11 @@ def __init__(self, args: ModelArgs, layer_id: int):
             max_seq_len=self.max_seq_len,
             enable_dynamic_shape=args.enable_dynamic_shape,
         )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
-        self,
-        x: torch.Tensor,
-        freqs_cos: torch.Tensor,
-        freqs_sin: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
+        self,
+        x: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape

@@ -333,7 +387,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, seqlen, input_pos)

         if self.use_kv_cache:
             assert input_pos is not None
@@ -421,13 +475,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -456,84 +510,35 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers

         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

     def forward(
-        self,
-        tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
-        h: Optional[torch.FloatTensor] = None,  # embeddings
+        self,
+        tokens: Optional[torch.LongTensor] = None,  # tokens
+        input_pos: Optional[
+            torch.LongTensor
+        ] = None,  # Scalar tensor indicating size of window of the caches
+        h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> torch.Tensor:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]

         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h = layer(h, input_pos)

         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
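For readers skimming the diff: the change moves all RoPE logic (frequency precomputation, buffer registration, and the prefill vs. KV-cache slicing of freqs_cos/freqs_sin) out of Transformer and Attention into a single Rope module that is constructed once and shared by every layer. Below is a minimal, self-contained sketch of that pattern. The names (SharedRope, precompute_freqs, apply_rotary) and the simplified rotation math are illustrative stand-ins, not the ExecuTorch implementation (Rope, precompute_freqs_cis, hf_apply_rotary_emb, etc. in llama_transformer.py).

# Hypothetical sketch of "one shared RoPE module", not the ExecuTorch code.
from typing import Optional

import torch


def precompute_freqs(head_dim: int, max_seq_len: int, base: float = 10000.0):
    # Standard RoPE table: cos/sin of position times inverse frequency.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, inv_freq)  # (max_seq_len, head_dim // 2)
    return freqs.cos(), freqs.sin()


def apply_rotary(q, k, cos, sin):
    # q, k: (bs, seqlen, n_heads, head_dim); cos/sin: (seqlen, head_dim // 2)
    def rotate(x):
        x1, x2 = x[..., 0::2], x[..., 1::2]
        c = cos.view(1, cos.shape[0], 1, -1)
        s = sin.view(1, sin.shape[0], 1, -1)
        out = torch.stack([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
        return out.flatten(-2)

    return rotate(q), rotate(k)


class SharedRope(torch.nn.Module):
    """Owns the freq buffers and the decode-vs-prefill slicing in one place."""

    def __init__(self, head_dim: int, max_seq_len: int):
        super().__init__()
        cos, sin = precompute_freqs(head_dim, max_seq_len)
        self.register_buffer("freqs_cos", cos, persistent=False)
        self.register_buffer("freqs_sin", sin, persistent=False)

    def forward(self, q, k, seq_len: int, input_pos: Optional[torch.Tensor] = None):
        if input_pos is not None:
            # KV-cache decode: slice seq_len rows starting at the current position.
            cos = self.freqs_cos.narrow(0, int(input_pos[-1]), seq_len)
            sin = self.freqs_sin.narrow(0, int(input_pos[-1]), seq_len)
        else:
            # Prefill / no cache: take the first seq_len rows.
            cos, sin = self.freqs_cos[:seq_len], self.freqs_sin[:seq_len]
        return apply_rotary(q, k, cos, sin)


if __name__ == "__main__":
    rope = SharedRope(head_dim=64, max_seq_len=128)  # built once by the model
    q = torch.randn(1, 8, 4, 64)  # (bs, seqlen, n_heads, head_dim)
    k = torch.randn(1, 8, 4, 64)
    q, k = rope(q, k, seq_len=8)  # each attention layer calls the shared module
    print(q.shape, k.shape)

As in the diff, the payoff is that Transformer.forward no longer threads freqs_cos/freqs_sin through every layer; each Attention just calls the shared module with (q, k, seqlen, input_pos).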
