diff --git a/litgpt/model.py b/litgpt/model.py index 01ea83ad4a..0321f3fad0 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -806,7 +806,7 @@ def build_rope_cache( theta = theta / factor # Create position indices `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, device=device) / condense_ratio + seq_idx = torch.arange(seq_len, device=device).float() / condense_ratio # Calculate the product of position index and $\theta_i$ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)