litgpt/model.py: 6 additions & 3 deletions
@@ -309,7 +309,7 @@ def forward(
        attn = block.attn
        if attn.kv_cache.batch_size < eff_batch_size:
            raise ValueError(f"Batch size {eff_batch_size} is too large for KV cache layer {l_ix} (batch size {attn.kv_cache.batch_size}). Use 'assign_kv_caches' or `set_kv_cache'")
-       x = block(x, cos, sin, input_pos, self.mask_cache)
+       x = block(x, cos, sin, idx, input_pos, self.mask_cache)

    x = self.transformer.ln_f(x)
    clamp_head = partial(
@@ -428,6 +428,7 @@ def forward(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
+       token_idx: torch.Tensor,
        input_pos: Optional[int] = None,
        mask_cache: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
@@ -457,6 +458,7 @@ def forward(
            x_normed,
            cos=cos,
            sin=sin,
+           token_idx=token_idx,
            input_pos=input_pos,
            mask_cache=mask_cache,
        )
@@ -511,6 +513,7 @@ def forward(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
+       token_idx: torch.Tensor,
        input_pos: Optional[int] = None,
        mask_cache: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
@@ -596,12 +599,12 @@ def forward(
        # Instead of asking for the key and value tensors as such,
        # `k_and_v` allows access to them. Since they are never needed at
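For orientation, a minimal sketch of the call chain after this change: `GPT.forward` now passes the token ids (`idx`) to every transformer block, and `Block.forward` accepts them as `token_idx` and forwards them to the attention layer. The stub classes below are illustrative stand-ins with made-up shapes, not the litgpt implementation.

```python
from typing import Optional
import torch


class _AttentionStub(torch.nn.Module):
    # Stand-in for CausalSelfAttention: after this change its forward also
    # takes `token_idx` (the raw token ids) alongside the usual arguments.
    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        token_idx: torch.Tensor,
        input_pos: Optional[int] = None,
        mask_cache: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return x  # the real layer computes attention; omitted here


class _BlockStub(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attn = _AttentionStub()

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        token_idx: torch.Tensor,
        input_pos: Optional[int] = None,
        mask_cache: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Thread the token ids through to attention, as the diff does.
        return self.attn(
            x, cos=cos, sin=sin, token_idx=token_idx,
            input_pos=input_pos, mask_cache=mask_cache,
        )


# Caller side mirrors the changed line in GPT.forward:
#   x = block(x, cos, sin, idx, input_pos, self.mask_cache)
block = _BlockStub()
idx = torch.randint(0, 50_000, (1, 8))              # (batch, seq_len) token ids
x = torch.randn(1, 8, 64)                            # hidden states (illustrative size)
cos = torch.zeros(8, 32)                             # RoPE tables (illustrative shapes)
sin = torch.zeros(8, 32)
x = block(x, cos, sin, idx)
```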