
Commit 3ea4c3b

Enable compile for ipex patched model with paged attention (#1253)
* fix max input_length
* make static address for key value cache
* fix max cache len
* fix typo
* fix typo
* fix max_seq_len
* use where to replace nonzero
* avoid dynamic shape tensor of blocks
* fix block tables
* rm useless dim
* fix .max
* fix small error
* fix pretrain
* fix input len stride size mismatch
* rm warmup
* put the alloc cache outside forward
* fix update cache
* rm condition in model forward
* fix index
* fix attn interface condition
* fix input kwargs
* use index select
* rm max_seq_lens
* add get_seq_length function
* fix inputs
* make slots static address
* use inplace op to avoid recompile
* register flash attention as a pytorch op; fix block tables input
* Revert "register flash attention as a pytorch op" (reverts commit 489cd0d)
* pass int value max_input_lens through config
* fix slots address
* fix slots
* Assign device type earlier to avoid recompile in ipex
* fix other models
* fix style
* disable cpp wrapper for paged attention model to get better performance
* add patch gpt2 lm head model forward to pass kwargs
* fix convert func
* rm static cache check
* fix static cache check
* fix qwen
* fix typo
* patch falcon causal lm forward
* general optimization
* fix attn output
* fix max_input_lens
* fix max_input_lens
* fix attn
* fix slots to multi batch and reorder kv cache for beam search
* fix comment
* fix reorder kv cache
* fix typo
* rm export
* fix pipeline
* check pkv before update
* rm export=True
* rm export in ST model
* fix ruff
* fix from_pretrained config
* fix beam search
* rm static cache
* add torch 2.7 tests
* use phi3 to check with or without pkv
* fix format
* fix format
* output more details for pytest
* fix text-generation tests
* fix test name
* fix name
* check patched models results separately
* fix
* fix patched models tests
* fix falcon test
* add patched mistral test
* disable torch 2.6 tests
* fix input format
* fix typo
* fix format

---------

Signed-off-by: jiqing-feng <[email protected]>
1 parent 19d6073 commit 3ea4c3b

10 files changed (+512 / -335 lines)

.github/workflows/test_ipex.yml

Lines changed: 1 addition & 1 deletion
@@ -65,4 +65,4 @@ jobs:
 
       - name: Test with Pytest
         run: |
-          pytest tests/ipex/${{ matrix.test-file }}
+          pytest tests/ipex/${{ matrix.test-file }} -rsx -v

optimum/exporters/ipex/cache_utils.py

Lines changed: 63 additions & 87 deletions
@@ -54,7 +54,10 @@ def __init__(
         # Used in `generate` to keep tally of how many tokens the cache has seen
 
         self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
-        default_block_size = 16
+        self.slots = torch.zeros([max_cache_len * max_batch_size], dtype=torch.int32, device=device)
+        torch._dynamo.mark_static_address(self._seen_tokens)
+        torch._dynamo.mark_static_address(self.slots)
+        default_block_size = 16 if max_cache_len <= 64 else 64
         self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
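The mark_static_address calls above (and the ones added to the per-layer KV tensors further down) let the preallocated buffers be written in place and reused across torch.compile runs instead of being treated as fresh inputs each call, which is what the commit messages "make static address for key value cache" and "use inplace op to avoid recompile" refer to. A minimal, self-contained sketch of that pattern, with illustrative names not taken from the patch:

import torch

buffer = torch.zeros(8, 16)                    # preallocated once, like the KV cache or slots
torch._dynamo.mark_static_address(buffer)      # tell dynamo this storage will not move

@torch.compile
def step(x: torch.Tensor) -> torch.Tensor:
    # Writing in place keeps the same storage, so the compiled graph stays valid.
    buffer[: x.shape[0]].copy_(x)
    return buffer.sum()

for _ in range(3):
    step(torch.randn(4, 16))                   # repeated calls with the same shape reuse one graph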
@@ -69,7 +72,6 @@ def __init__(
         else:
             head_size = config.hidden_size // config.num_attention_heads
         self.head_size = head_size
-        self.max_seq_len = 0
 
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
@@ -87,6 +89,8 @@ def __init__(
         for i in range(config.num_hidden_layers):
             new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
             new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
 
@@ -116,79 +120,50 @@ def reshape_and_cache(
             slots,
         )
 
-    def update_for_prefill(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        batch_size: int,
-        input_lens: torch.Tensor,
-    ):
-        if layer_idx == 0:
-            all_block_indices = []
-            all_slot_offsets = []
-            num_blocks = (input_lens + self.block_size - 1) // self.block_size
-            for i in range(batch_size):
-                nb = num_blocks[i]
-                block_table = self.free_blocks.nonzero().view(-1)[0:nb]
-                self.block_tables[i][0:nb] = block_table
-                self.free_blocks[block_table] = 0
-                slots_range = torch.arange(input_lens[i], device=self.device)
-                block_indices = slots_range // self.block_size
-                slot_offsets = slots_range % self.block_size
-                all_block_indices.append(self.block_tables[i][block_indices])
-                all_slot_offsets.append(slot_offsets)
-
-            all_block_indices = torch.cat(all_block_indices)
-            all_slot_offsets = torch.cat(all_slot_offsets)
-            self.slots = all_block_indices * self.block_size + all_slot_offsets
-        # Update the cache
-        self.reshape_and_cache(
-            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
-        )
-
-        # Update the number of seen tokens
-        if layer_idx == self.num_hidden_layers - 1:
-            self._seen_tokens = self._seen_tokens + input_lens
-            self.max_seq_len, _ = self._seen_tokens.max(dim=0)
-
-    def update_for_decode(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        batch_size: int,
-    ):
-        if layer_idx == 0:
-            start_block_idx = self._seen_tokens // self.block_size
-            slot_offset_in_block = (self._seen_tokens) % self.block_size
-            self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32)
-            for i in range(batch_size):
-                if slot_offset_in_block[i] == 0:
-                    # need a new block:
-                    b_idx = start_block_idx[i]
-                    if self.block_tables[i][b_idx] == -1:
-                        # need a free block
-                        self.block_tables[i][b_idx] = self.free_blocks.nonzero().view(-1)[0:1]
-                        self.free_blocks[self.block_tables[i][b_idx]] = 0
-                self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
-        # Update the cache
-        self.reshape_and_cache(
-            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
-        )
-
-        # Update the number of seen tokens
-        if layer_idx == self.num_hidden_layers - 1:
-            self._seen_tokens = self._seen_tokens + 1
-            self.max_seq_len = self.max_seq_len + 1
+    # outside the model forward
+    def alloc_slot_for_prefill(self, input_lens: torch.Tensor, batch_size: int):
+        all_block_indices = []
+        all_slot_offsets = []
+        num_blocks = (input_lens + self.block_size - 1) // self.block_size
+        for i in range(batch_size):
+            nb = num_blocks[i]
+            scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1)
+            block_table = torch.topk(scores, nb).indices
+            self.block_tables[i][0:nb] = block_table
+            self.free_blocks[block_table] = 0
+            slots_range = torch.arange(input_lens[i], device=self.device)
+            block_indices = slots_range // self.block_size
+            slot_offsets = slots_range % self.block_size
+            all_block_indices.append(self.block_tables[i][block_indices])
+            all_slot_offsets.append(slot_offsets)
+
+        all_block_indices = torch.cat(all_block_indices)
+        all_slot_offsets = torch.cat(all_slot_offsets).int()
+        # Use inplace op to keep the same memory address, avoid recompile
+        self.slots[: all_block_indices.shape[0]].copy_(all_block_indices * self.block_size + all_slot_offsets)
+
+    # outside the model forward
+    def alloc_slot_for_decode(self, batch_size: int):
+        start_block_idx = self._seen_tokens // self.block_size
+        slot_offset_in_block = (self._seen_tokens) % self.block_size
+        # Use inplace op to keep the same memory address, avoid recompile
+        self.slots.zero_()
+        for i in range(batch_size):
+            if slot_offset_in_block[i] == 0:
+                # need a new block:
+                b_idx = start_block_idx[i]
+                if self.block_tables[i][b_idx] == -1:
+                    # Need a free block. Get indices of free blocks, select the first free block
+                    scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1)
+                    self.block_tables[i][b_idx] = scores.argmax()
+                    self.free_blocks[self.block_tables[i][b_idx]] = 0
+            self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
 
     def update(
         self,
         key_states: torch.Tensor,
         value_states: torch.Tensor,
         layer_idx: int,
-        attention_mask: torch.Tensor,
-        input_lens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
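The scores/topk selection in alloc_slot_for_prefill (and the argmax variant in alloc_slot_for_decode) replaces the earlier free_blocks.nonzero() lookup, whose output shape depends on how many blocks happen to be free and therefore forces recompilation under torch.compile. A small standalone illustration of the same idea, with made-up values, assuming at least nb blocks are actually free:

import torch

free_blocks = torch.tensor([0, 1, 0, 1, 1, 1], dtype=torch.int32)  # 1 = free, 0 = in use

# Weight free blocks by a descending ramp so lower-indexed free blocks score higher,
# then topk(nb) always returns exactly nb indices, i.e. a fixed output shape.
scores = free_blocks * torch.arange(free_blocks.shape[0], 0, -1)
nb = 2
block_table = torch.topk(scores, nb).indices
print(block_table)  # tensor([1, 3]): the two lowest-indexed free blocks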
@@ -204,45 +179,46 @@ def update(
             A tuple containing the updated key and value states.
         """
 
-        batch_size = input_lens.shape[-1]
-        if self.get_seq_length() == 0:
-            # prefill
-            self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens)
-        else:
-            # decode
-            self.update_for_decode(key_states, value_states, layer_idx, batch_size)
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
+        )
 
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+    def get_seq_length(self) -> int:
         """Returns the sequence length of the cached states that were seen by the model."""
-        return self.max_seq_len
+        return self._seen_tokens.max()
 
     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states."""
         return self.max_cache_len
 
     def reset(self):
         """Resets the cache values while preserving the objects"""
-        self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.device)
+        self._seen_tokens.zero_()
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device)
-        self.max_seq_len = 0
+        self.free_blocks.fill_(1)
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         origin_table = self.block_tables.clone()
         updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device))
-        mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
-        num_blocks = mask.cumsum(-1)[:, -1]
+        mask = torch.where(self.block_tables == -1, 0, 1)
+        num_blocks = mask.sum(-1)
         updated_table = torch.zeros_like(beam_idx)
         for i in range(beam_idx.shape[0]):
             nb = num_blocks[i]
             self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1]
             updated_table[i] = self.block_tables[i][nb - 1]
         for layer_idx in range(self.num_hidden_layers):
-            self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
-            self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
+            # The updated_table cannot contain the whole block table, otherwise will cause core-dump.
+            self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx].index_select(
+                0, updated_table[beam_idx]
+            )
+            self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx].index_select(
+                0, updated_table[beam_idx]
+            )
+
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
         for i in free_table:
             if not (self.block_tables == i).any():
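In reorder_cache, the masked_fill/cumsum pair is replaced with torch.where plus a sum, which yields the same per-sequence block counts more directly. For example, with an illustrative block table where -1 marks unallocated entries:

import torch

block_tables = torch.tensor([[3, 7, -1, -1],
                             [2, 5, 9, -1]])

mask = torch.where(block_tables == -1, 0, 1)  # 1 for allocated blocks, 0 otherwise
num_blocks = mask.sum(-1)
print(num_blocks)  # tensor([2, 3])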
@@ -252,7 +228,7 @@ def crop(self, maximum_length: int):
         """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
         negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
 
-        max_seq_len = self.get_seq_length()
+        max_seq_len = self._seen_tokens.max()
        if maximum_length < 0:
             maximum_length = max_seq_len - abs(maximum_length)
 
@@ -264,7 +240,7 @@ def crop(self, maximum_length: int):
             num_blocks = (new_tokens + self.block_size - 1) // self.block_size
             self.block_tables[bs, num_blocks:] = -1
             self._seen_tokens[bs] = new_tokens
-            self.max_seq_len, _ = self._seen_tokens.max(dim=0)
+
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
         for i in free_table:
             if not (self.block_tables == i).any():
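Taken together, these cache changes move the data-dependent bookkeeping (block and slot allocation, marked "outside the model forward") ahead of the compiled region, so that update() inside the graph only writes into preallocated, statically addressed storage. A toy sketch of that split, with illustrative names that are not the optimum-intel API:

import torch

slots = torch.zeros(8, dtype=torch.int64)
cache = torch.zeros(8, 4)
torch._dynamo.mark_static_address(slots)
torch._dynamo.mark_static_address(cache)

def alloc(n: int):
    # Data-dependent bookkeeping stays eager, outside the compiled region.
    slots[:n].copy_(torch.arange(n))

@torch.compile
def forward(x: torch.Tensor) -> torch.Tensor:
    # Inside the graph only a fixed scatter into preallocated storage remains.
    cache.index_copy_(0, slots[: x.shape[0]], x)
    return cache.sum()

alloc(3)
out = forward(torch.randn(3, 4))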

optimum/exporters/ipex/model_patcher.py

Lines changed: 6 additions & 0 deletions
@@ -33,7 +33,9 @@
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
+    _falcon_for_causal_lm_forward,
     _falcon_model_forward,
+    _gpt2_lm_head_model_forward,
     _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
@@ -104,6 +106,7 @@ def _patch_falcon_model(model):
         model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
     )
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_func(model, "forward", _falcon_for_causal_lm_forward)
     convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
     convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
@@ -118,6 +121,7 @@ def _patch_gpt2_model(model):
     """
     num_key_value_heads = model.config.num_attention_heads
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_func(model, "forward", _gpt2_lm_head_model_forward)
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
     convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
     return model
@@ -129,6 +133,8 @@ def _patch_qwen2_model(model):
     1. Use IPEX rope and paged cache
     2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
     """
+    # To avoid call _ignore_causal_mask_sdpa which will cause recompile
+    model.config._attn_implementation = "ipex_paged"
     convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
     convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
     convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.device, model.config)
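The new convert_func calls bind replacement forwards (_falcon_for_causal_lm_forward, _gpt2_lm_head_model_forward) onto the top-level model so that extra paged-attention kwargs can be passed through. convert_func's implementation is not part of this diff; instance-level patching of a forward method generally looks like the following sketch, where all names are illustrative:

import types

class ToyModel:
    def forward(self, x):
        return x

def patched_forward(self, x, **kwargs):
    # A replacement forward can accept and route extra kwargs
    # (for example paged-cache arguments) that the stock forward would reject.
    return x * 2

model = ToyModel()
model.forward = types.MethodType(patched_forward, model)
print(model.forward(3))  # 6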
