Commit 8792a4d

Update
[ghstack-poisoned]
1 parent 411fb82 commit 8792a4d


examples/models/llama/llama_transformer.py

Lines changed: 41 additions & 35 deletions
@@ -169,10 +169,16 @@ def __init__(self, params: ModelArgs):
         else:
             self.apply_rotary_emb = RotaryEmbedding()
 
-    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, input_pos: Optional[torch.LongTensor] = None):
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seq_len: int,
+        input_pos: Optional[torch.LongTensor] = None,
+    ):
         if self.params.use_kv_cache:
             assert (
-                input_pos is not None
+                input_pos is not None
             ), "input_pos must be provided when use_kv_cache is True"
 
         if self.params.enable_dynamic_shape:
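This hunk only reflows Rope.forward's signature; the logic is untouched. For orientation, the snippet below is a generic rotary-embedding application in the style used by Llama-family models. It is a self-contained sketch, not the RotaryEmbedding used in this file; the helper names and shapes are assumptions for illustration.

import torch

def precompute_freqs_cis(head_dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Complex rotation factors, shape (seq_len, head_dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_emb_sketch(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor):
    # q, k: (bs, seq_len, n_heads, head_dim); rotate channel pairs by position-dependent angles.
    q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    freqs = freqs_cis.view(1, q_.shape[1], 1, q_.shape[-1])  # broadcast over batch and heads
    q_out = torch.view_as_real(q_ * freqs).flatten(3)
    k_out = torch.view_as_real(k_ * freqs).flatten(3)
    return q_out.type_as(q), k_out.type_as(k)

bs, seq_len, n_heads, head_dim = 1, 8, 4, 16
q = torch.randn(bs, seq_len, n_heads, head_dim)
k = torch.randn(bs, seq_len, n_heads, head_dim)
q, k = apply_rotary_emb_sketch(q, k, precompute_freqs_cis(head_dim, seq_len))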
@@ -202,14 +208,14 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, input_pos: Opt
 
 class KVCache(nn.Module):
     def __init__(
-        self,
-        max_batch_size: int,
-        max_seq_length: int,
-        n_heads: int,
-        head_dim: int,
-        transpose_cache: bool,
-        enable_dynamic_shape: bool,
-        dtype=torch.float32,
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        transpose_cache: bool,
+        enable_dynamic_shape: bool,
+        dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
@@ -232,7 +238,7 @@ def __init__(
         )
 
     def update(
-        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
         if self.enable_dynamic_shape:
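The update contract above takes input_pos (the positions being written), scatters k_val/v_val into preallocated buffers, and returns the full cache tensors. As a rough, self-contained illustration of that pattern (not the KVCache in this file: transpose_cache and the dynamic-shape path are ignored, and the buffer layout is assumed):

import torch
from torch import nn
from typing import Tuple

class TinyKVCache(nn.Module):
    # Minimal sketch with a fixed (B, H, S, D) cache layout; illustration only.
    def __init__(self, max_batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype=torch.float32):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [S] positions to overwrite; k_val, v_val: [B, H, S, D].
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache

cache = TinyKVCache(max_batch_size=1, max_seq_length=16, n_heads=4, head_dim=8)
k, v = cache.update(torch.tensor([0]), torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8))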
@@ -270,13 +276,13 @@ def update(
 
 class SDPA(nn.Module):
     def __init__(
-        self,
-        kv_cache: KVCache,
-        dim: int,
-        head_dim: int,
-        n_rep: int,
-        max_seq_len: int,
-        enable_dynamic_shape: bool,
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+        max_seq_len: int,
+        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.kv_cache = kv_cache
@@ -287,14 +293,14 @@ def __init__(
         self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
-        bsz,
-        seqlen,
-        mask: torch.Tensor,
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
+        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        bsz,
+        seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
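The two transposes at the top of SDPA.forward move the head axis in front of the sequence axis, which is the layout torch.nn.functional.scaled_dot_product_attention expects. A minimal standalone sketch of that shape bookkeeping, assuming a plain boolean causal mask, no KV cache, and no grouped-query replication (n_rep = 1):

import torch
import torch.nn.functional as F

bs, seqlen, n_heads, head_dim = 1, 8, 4, 16
q = torch.randn(bs, seqlen, n_heads, head_dim)
k = torch.randn(bs, seqlen, n_heads, head_dim)
v = torch.randn(bs, seqlen, n_heads, head_dim)

# (bs, seqlen, n_heads, head_dim) -> (bs, n_heads, seqlen, head_dim)
q, k, v = (t.transpose(1, 2) for t in (q, k, v))

mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))  # causal: attend to past only
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Back to (bs, seqlen, dim) before the output projection.
out = out.transpose(1, 2).contiguous().view(bs, seqlen, n_heads * head_dim)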
@@ -373,9 +379,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         )
 
     def forward(
-        self,
-        x: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
+        self,
+        x: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
 
@@ -523,12 +529,12 @@ def __init__(self, params: ModelArgs):
         self.output_prune_map = params.output_prune_map
 
     def forward(
-        self,
-        tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
-        h: Optional[torch.FloatTensor] = None,  # embeddings
+        self,
+        tokens: Optional[torch.LongTensor] = None,  # tokens
+        input_pos: Optional[
+            torch.LongTensor
+        ] = None,  # Scalar tensor indicating size of window of the caches
+        h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> torch.Tensor:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
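The (tokens is None) ^ (h is not None) guard accepts a call only when exactly one of tokens (token IDs) or h (precomputed embeddings) is supplied, and raises otherwise. A tiny standalone illustration of that truth table (the helper name is made up for this example):

from typing import Optional

def exactly_one_input(tokens: Optional[object] = None, h: Optional[object] = None) -> bool:
    # Mirrors the guard: invalid when both or neither of tokens/h are given.
    return not ((tokens is None) ^ (h is not None))

assert exactly_one_input(tokens="ids")                # tokens only: accepted
assert exactly_one_input(h="embeddings")              # embeddings only: accepted
assert not exactly_one_input()                        # neither: rejected
assert not exactly_one_input(tokens="ids", h="emb")   # both: rejected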
