Skip to content

Commit 15461ff

Browse files
Cast tensors in KVCache only when needed (#2017)
1 parent 3d66f32 commit 15461ff

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

litgpt/model.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -791,8 +791,10 @@ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) ->
 
         """
         # move the buffer to the activation dtype for when AMP is used
-        self.k = self.k.to(k.dtype)
-        self.v = self.v.to(v.dtype)
+        if self.k.dtype != k.dtype:
+            self.k = self.k.to(k.dtype)
+        if self.v.dtype != v.dtype:
+            self.v = self.v.to(v.dtype)
         # update the cache
         bs = k.size(0)
         k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k)

0 commit comments

Comments
 (0)