diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
index 273a6ca42c..94b1422e2a 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -212,7 +212,7 @@ def get_total_slots():
         elif is_unpaged_prefill:
             # prepare some params of unpaged_prefill attention stage.
             q_start_loc_cpu, kv_seqlens_cpu = None, None
-            q_seqlens_cpu = step_context.q_seqlens.cpu()
+            q_seqlens_cpu = step_context.q_seqlens.cpu().to(torch.int32)
             if SocVersion.is_Ascend910():
                 single_attention_mask = torch.logical_not(
                     torch.tril(
@@ -251,7 +251,7 @@ def get_total_slots():
             step_context.block_offsets = step_context.block_offsets\
                 .repeat_interleave(step_context.q_seqlens, 0)
             dynamo.mark_dynamic(step_context.block_offsets, [0, 1])
-        kv_seqlens = step_context.kv_seqlens.to(torch.int32)
+        kv_seqlens = step_context.kv_seqlens.cpu().to(torch.int32)
         if not step_context.is_decoding:
             if is_unpaged_prefill:
                 if SocVersion.is_Ascend910():
@@ -270,10 +270,10 @@ def get_total_slots():
                     raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.")
                 kv_seqlens = kv_seqlens.repeat_interleave(step_context.q_seqlens, 0)
         if not is_unpaged_prefill and AscendOpsBackend.enable_aclgraph():
-            kv_seqlens = kv_seqlens.cpu().tolist()
+            kv_seqlens = kv_seqlens.cpu().to(torch.int32)
     else:
         if step_context.is_decoding:
-            kv_seqlens_cpu = step_context.kv_seqlens.cpu()
+            kv_seqlens_cpu = step_context.kv_seqlens.cpu().to(torch.int32)
         elif is_unpaged_prefill:
             pass
         else:
diff --git a/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py b/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
index 16dfc99dea..87d4626a15 100644
--- a/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
+++ b/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
@@ -40,14 +40,13 @@ def _rotary_embedding_fwd(position_ids: torch.Tensor,
 
 class DlinferRotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module):
     """Base rotary embedding."""
 
-    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0):
+    def __init__(self, dim: int, base: float = 10000.0, scaling_factor: float = 1.0):
         super().__init__()
         self.scaling_factor = scaling_factor
         self.dim = dim
         self.base = base
         # yapf: disable
-        inv_freq = 1.0 / (self.base
-                          ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)).float().cuda()
+        inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2, dtype=torch.float, device='cuda') / self.dim))
         # yapf: enable
         self.register_buffer('inv_freq', inv_freq, persistent=False)
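
Note on the rotary-embedding hunk (an illustration only, not part of the diff): the rewritten inv_freq line computes the exponent directly in float32 on the device, instead of building an int64 arange, casting it, and then applying a redundant .float() to the already-float result. A minimal sketch of the numerical equivalence, run on CPU for simplicity:

    import torch

    dim, base = 128, 10000.0

    # old formulation: int64 arange cast to float, plus a redundant .float()
    old = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)).float()

    # new formulation: compute directly in float32
    new = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.float) / dim))

    assert torch.allclose(old, new)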