Skip to content

Commit 5ffa0c0

Browse files
authored
Add a workaround (WA) to avoid an XPU crash in the PagedAttention.reshape_and_cache_flash API (#1288)
1 parent 6634bbc commit 5ffa0c0

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

optimum/exporters/ipex/cache_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,15 @@ def reshape_and_cache(
104104
):
105105
# TODO: unify API definition between CPU and XPU in IPEX version > 2.6
106106
if self.device.type == "xpu" and self._supports_flash_decoding:
107+
# Workaround: `slots` is padded here, but XPU does not support slots whose length differs from the key length, so truncate `slots` to the key length. This will be fixed in IPEX 2.8.
108+
valid_len = key.shape[0]
109+
truncated_slots = slots[:valid_len]
107110
PagedAttention.reshape_and_cache_flash(
108111
key,
109112
value,
110113
key_cache,
111114
value_cache,
112-
slots,
115+
truncated_slots,
113116
)
114117
else:
115118
PagedAttention.reshape_and_cache(
@@ -127,7 +130,7 @@ def alloc_slot_for_prefill(self, input_lens: torch.Tensor, batch_size: int):
127130
num_blocks = (input_lens + self.block_size - 1) // self.block_size
128131
for i in range(batch_size):
129132
nb = num_blocks[i]
130-
scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1)
133+
scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1, device=self.device)
131134
block_table = torch.topk(scores, nb).indices
132135
self.block_tables[i][0:nb] = block_table
133136
self.free_blocks[block_table] = 0
@@ -154,7 +157,7 @@ def alloc_slot_for_decode(self, batch_size: int):
154157
b_idx = start_block_idx[i]
155158
if self.block_tables[i][b_idx] == -1:
156159
# Need a free block. Get indices of free blocks, select the first free block
157-
scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1)
160+
scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1, device=self.device)
158161
self.block_tables[i][b_idx] = scores.argmax()
159162
self.free_blocks[self.block_tables[i][b_idx]] = 0
160163
self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]

0 commit comments

Comments
 (0)