
Commit 753f84d

Use free_table as a mask tensor (#1086)
* use free_table as a mask tensor

Signed-off-by: Wang, Yi A <[email protected]>

* fix beamsearch issue

Signed-off-by: Wang, Yi A <[email protected]>

---------

Signed-off-by: Wang, Yi A <[email protected]>
1 parent 1733791 commit 753f84d

File tree

1 file changed: +21 −20 lines changed


optimum/exporters/ipex/cache_utils.py

Lines changed: 21 additions & 20 deletions
```diff
@@ -49,7 +49,7 @@ def __init__(
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
         )
-        self.free_blocks = torch.arange(self.num_blocks, device=device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device)
         self.max_cache_len = max_cache_len
         self.num_kv_heads = config.num_key_value_heads
         self.num_hidden_layers = config.num_hidden_layers
```
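For orientation, here is a minimal sketch (not part of the diff) contrasting the two representations; `num_blocks` is a made-up pool size. The old `torch.arange` queue shrinks as blocks are handed out, while the new 0/1 mask stays fixed-size and is indexed directly by block id:

```python
import torch

num_blocks = 8  # toy pool size for illustration

# Old representation: a shrinking queue of free block ids.
free_queue = torch.arange(num_blocks)            # tensor([0, 1, ..., 7])
allocated, free_queue = free_queue[0], free_queue[1:]

# New representation: a fixed-size mask, free_blocks[b] == 1 <=> block b is free.
free_blocks = torch.ones([num_blocks], dtype=torch.int32)
block_id = free_blocks.nonzero().view(-1)[0]     # id of the first free block
free_blocks[block_id] = 0                        # claim it
free_blocks[block_id] = 1                        # release it: a single, idempotent store
```

A mask also makes release idempotent: setting `free_blocks[i] = 1` twice is harmless, whereas concatenating ids back onto a queue can introduce duplicate entries, which appears to be what the beam-search fix below addresses.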
```diff
@@ -88,12 +88,10 @@ def update_for_prefill(
         all_slot_offsets = []
         num_blocks = (input_lens + self.block_size - 1) // self.block_size
         for i in range(batch_size):
-            for b_idx in range(num_blocks[i]):
-                if self.block_tables[i][b_idx] == -1:
-                    # need a free block
-                    self.block_tables[i][b_idx] = self.free_blocks[0]
-                    self.free_blocks = self.free_blocks[1:]
-
+            nb = num_blocks[i]
+            block_table = self.free_blocks.nonzero().view(-1)[0:nb]
+            self.block_tables[i][0:nb] = block_table
+            self.free_blocks[block_table] = 0
             slots_range = torch.arange(input_lens[i], device=key_states.device)
             block_indices = slots_range // self.block_size
             slot_offsets = slots_range % self.block_size
```
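A hedged, runnable sketch of the new prefill allocation path with toy sizes (`block_size`, the tensor shapes, and `input_lens` are all made up): `nonzero()` lists the ids of still-free blocks, and the first `nb` are claimed in one vectorized step instead of a per-block Python loop:

```python
import torch

block_size = 4
free_blocks = torch.ones([8], dtype=torch.int32)             # toy pool, all free
block_tables = -1 * torch.ones([2, 4], dtype=torch.int32)    # two sequences, 4 slots each

input_lens = torch.tensor([5, 3])
num_blocks = (input_lens + block_size - 1) // block_size     # ceil division -> [2, 1]

for i in range(2):
    nb = num_blocks[i]
    picked = free_blocks.nonzero().view(-1)[0:nb]  # ids of the first nb free blocks
    block_tables[i][0:nb] = picked                 # map logical -> physical blocks
    free_blocks[picked] = 0                        # claim them in the mask

# block_tables is now [[0, 1, -1, -1], [2, -1, -1, -1]]
```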
```diff
@@ -103,7 +101,6 @@ def update_for_prefill(
         all_block_indices = torch.cat(all_block_indices)
         all_slot_offsets = torch.cat(all_slot_offsets)
         self.slots = all_block_indices * self.block_size + all_slot_offsets
-
         # Update the cache
         PagedAttention.reshape_and_cache(
             key_states,
```
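The surrounding context (unchanged by this commit) flattens per-token block ids and offsets into linear slot ids. A tiny worked example with made-up values:

```python
import torch

block_size = 4
# Physical block id each token lands in, and its offset inside that block:
all_block_indices = torch.tensor([0, 0, 0, 0, 1])
all_slot_offsets = torch.tensor([0, 1, 2, 3, 0])

slots = all_block_indices * block_size + all_slot_offsets
print(slots)  # tensor([0, 1, 2, 3, 4]) -> flat positions in the paged KV cache
```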
```diff
@@ -127,16 +124,16 @@ def update_for_decode(
     ):
         if layer_idx == 0:
             start_block_idx = self._seen_tokens // self.block_size
-            num_blocks = (self._seen_tokens + self.block_size) // self.block_size
             slot_offset_in_block = (self._seen_tokens) % self.block_size
             self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
             for i in range(batch_size):
-                for b_idx in range(start_block_idx[i], num_blocks[i]):
+                if slot_offset_in_block[i] == 0:
+                    # need a new block:
+                    b_idx = start_block_idx[i]
                     if self.block_tables[i][b_idx] == -1:
                         # need a free block
-                        self.block_tables[i][b_idx] = self.free_blocks[0]
-                        self.free_blocks = self.free_blocks[1:]
-
+                        self.block_tables[i][b_idx] = self.free_blocks.nonzero().view(-1)[0:1]
+                        self.free_blocks[self.block_tables[i][b_idx]] = 0
                 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
         # Update the cache
         PagedAttention.reshape_and_cache(
```
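The rewritten decode path drops the inner block loop: each sequence writes exactly one token per step, so a new block is needed only when the write offset wraps to 0. A sketch of that boundary test with illustrative counts:

```python
import torch

block_size = 4
seen_tokens = torch.tensor([4, 6])                # tokens already cached per sequence

start_block_idx = seen_tokens // block_size       # [1, 1]
slot_offset_in_block = seen_tokens % block_size   # [0, 2]

for i in range(2):
    if slot_offset_in_block[i] == 0:
        # next token starts a fresh block -> claim one from the free mask
        print(f"sequence {i} allocates block index {int(start_block_idx[i])}")
# only sequence 0 allocates; sequence 1 still has room in its current block
```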
```diff
@@ -196,7 +193,7 @@ def reset(self):
         """Resets the cache values while preserving the objects"""
         self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
         self.max_seq_len = 0

     def reorder_cache(self, beam_idx: torch.LongTensor):
```
```diff
@@ -206,16 +203,18 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
         updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
         mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
         num_blocks = mask.cumsum(-1)[:, -1]
-        updated_table = []
+        updated_table = torch.zeros_like(beam_idx)
         for i in range(beam_idx.shape[0]):
-            self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
-            updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
-        updated_table = torch.cat(tuple(updated_table), dim=0)
+            nb = num_blocks[i]
+            self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1]
+            updated_table[i] = self.block_tables[i][nb - 1]
         for layer_idx in range(self.num_hidden_layers):
             self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
             self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks = torch.cat((self.free_blocks, free_table))
+        for i in free_table:
+            if not (self.block_tables == i).any():
+                self.free_blocks[i] = 1

     def crop(self, maximum_length: int):
         """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
```
```diff
@@ -235,4 +234,6 @@ def crop(self, maximum_length: int):
             self._seen_tokens[bs] = new_tokens
         self.max_seq_len, _ = self._seen_tokens.max(dim=0)
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks = torch.cat((self.free_blocks, free_table))
+        for i in free_table:
+            if not (self.block_tables == i).any():
+                self.free_blocks[i] = 1
```
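`crop` reuses the same guarded release as `reorder_cache` above: candidate ids from `origin_table` go back into the `free_blocks` mask only once no row of `block_tables` still references them.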
