@@ -104,12 +104,15 @@ def reshape_and_cache(
 ):
     # TODO: unify API definition between CPU and XPU in IPEX version > 2.6
     if self.device.type == "xpu" and self._supports_flash_decoding:
+        # Workaround: `slots` is padded, but XPU cannot take a slots tensor whose length differs from the key length; to be fixed in IPEX 2.8
+        valid_len = key.shape[0]
+        truncated_slots = slots[:valid_len]
         PagedAttention.reshape_and_cache_flash(
             key,
             value,
             key_cache,
             value_cache,
-            slots,
+            truncated_slots,
         )
     else:
         PagedAttention.reshape_and_cache(
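For reference, a minimal standalone sketch of what the truncation does (shapes, the `-1` padding value, and all names other than `key` and `slots` are illustrative, not taken from IPEX): the padded `slots` tensor is cut down to exactly one slot per key row before the cache write.

```python
import torch

# Sketch only: `slots` is padded to a fixed length, while `key` holds just
# the valid tokens, so the XPU cache-write kernel must receive exactly one
# slot per key row. Shapes and the -1 padding value are assumptions.
key = torch.randn(3, 8, 64)                # 3 valid tokens, 8 heads, dim 64
slots = torch.tensor([5, 9, 12, -1, -1])   # padded to length 5

valid_len = key.shape[0]                   # number of real tokens
truncated_slots = slots[:valid_len]        # drop the padding entries
assert truncated_slots.shape[0] == key.shape[0]
```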
@@ -127,7 +130,7 @@ def alloc_slot_for_prefill(self, input_lens: torch.Tensor, batch_size: int):
     num_blocks = (input_lens + self.block_size - 1) // self.block_size
     for i in range(batch_size):
         nb = num_blocks[i]
-        scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1)
+        scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1, device=self.device)
         block_table = torch.topk(scores, nb).indices
         self.block_tables[i][0:nb] = block_table
         self.free_blocks[block_table] = 0
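The `device=self.device` argument matters because `torch.arange` allocates on the CPU by default, so multiplying it with a `free_blocks` tensor that lives on an XPU device raises a device-mismatch error. The scoring expression itself is a small trick worth spelling out: `free_blocks` acts as a 0/1 mask, and weighting it by a descending arange makes lower-indexed free blocks score higher, so `topk` returns the `nb` lowest-indexed free blocks. A self-contained sketch (the function name and mask values are illustrative):

```python
import torch

def pick_free_blocks(free_blocks: torch.Tensor, nb: int) -> torch.Tensor:
    # free_blocks is a 0/1 mask over the block pool. Weighting it with a
    # descending arange scores block i as (n - i) when free and 0 when used,
    # so topk picks the nb lowest-indexed free blocks. The arange must be
    # created on free_blocks.device, or the multiply fails off-CPU.
    n = free_blocks.shape[0]
    scores = free_blocks * torch.arange(n, 0, -1, device=free_blocks.device)
    return torch.topk(scores, nb).indices

free = torch.tensor([0, 1, 1, 0, 1])  # blocks 1, 2 and 4 are free
print(pick_free_blocks(free, 2))      # tensor([1, 2])
```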
@@ -154,7 +157,7 @@ def alloc_slot_for_decode(self, batch_size: int):
         b_idx = start_block_idx[i]
         if self.block_tables[i][b_idx] == -1:
             # Need a free block. Get indices of free blocks, select the first free block
-            scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1)
+            scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1, device=self.device)
             self.block_tables[i][b_idx] = scores.argmax()
             self.free_blocks[self.block_tables[i][b_idx]] = 0
         self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
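The decode-side hunk applies the same two ideas: identical descending-arange scoring, but with `scores.argmax()` instead of `topk`, since a decode step needs at most one new block per sequence (the argmax of the weighted mask is the lowest-indexed free block), and the same `device=self.device` argument to keep the arange on the same device as `self.free_blocks`.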