
Commit 0464d9e

[Cache] lfm2 cache: allocate empty kv layers during init (#41396)
* [Cache] lfm2 cache: allocate empty kv layers during init
  Signed-off-by: Paul Pak <[email protected]>

* [Cache] lfm2_cache: update modular file
  Signed-off-by: Paul Pak <[email protected]>

---------

Signed-off-by: Paul Pak <[email protected]>
1 parent da7b8ce commit 0464d9e
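
To make the intent of the commit concrete, here is a minimal standalone sketch (not the actual `Lfm2HybridConvCache` class; the `SketchKVCache` name, shapes, and layer count below are illustrative only). Each layer now receives an empty placeholder tensor at init, so `update()` no longer has to backfill skipped layers and can branch solely on whether the layer's tensor is empty:

```python
import torch

# Hypothetical sketch of the pattern this commit adopts: every layer gets an
# empty placeholder tensor at construction time, so update() never has to
# backfill skipped layers and only checks numel().
class SketchKVCache:
    def __init__(self, num_layers: int):
        # One empty tensor per layer; attention layers fill theirs on first
        # update, conv-only layers simply keep the empty placeholder.
        self.key_cache = [torch.tensor([]) for _ in range(num_layers)]
        self.value_cache = [torch.tensor([]) for _ in range(num_layers)]

    def update(self, key_states, value_states, layer_idx):
        if self.key_cache[layer_idx].numel() == 0:
            # First write to this layer: store the states directly.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            # Subsequent writes: append along the sequence dimension.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


cache = SketchKVCache(num_layers=2)
k = torch.randn(1, 4, 3, 8)  # (batch, heads, seq, head_dim)
v = torch.randn(1, 4, 3, 8)
cache.update(k, v, layer_idx=0)
cache.update(k, v, layer_idx=0)
print(cache.key_cache[0].shape)   # torch.Size([1, 4, 6, 8])
print(cache.key_cache[1].numel())  # 0 -- layer 1 was never updated
```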

File tree

2 files changed (+34, -46 lines)


src/transformers/models/lfm2/modeling_lfm2.py

Lines changed: 17 additions & 23 deletions
@@ -165,6 +165,8 @@ def __init__(
             )
             torch._dynamo.mark_static_address(conv_state)
             self.conv_cache.append(conv_state)
+            self.key_cache.append(torch.tensor([]))
+            self.value_cache.append(torch.tensor([]))

     def update(
         self,
@@ -190,35 +192,27 @@ def update(
             A tuple containing the updated key and value states.
         """
         # Update the cache
-        if key_states is not None:
-            if len(self.key_cache) <= layer_idx:
-                # There may be skipped layers, fill them with empty lists
-                for _ in range(len(self.key_cache), layer_idx):
-                    self.key_cache.append(torch.tensor([]))
-                    self.value_cache.append(torch.tensor([]))
-                self.key_cache.append(key_states)
-                self.value_cache.append(value_states)
-            elif (
-                not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
-            ):  # fills previously skipped layers; checking for tensor causes errors
-                self.key_cache[layer_idx] = key_states
-                self.value_cache[layer_idx] = value_states
-            else:
-                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+        if self.key_cache[layer_idx].numel() == 0:
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
+        else:
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

         return self.key_cache[layer_idx], self.value_cache[layer_idx]

     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-            device = self.conv_cache[layer_idx].device
-            self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
+            if self.key_cache[layer_idx].numel():
+                device = self.key_cache[layer_idx].device
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+                device = self.value_cache[layer_idx].device
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+            if self.conv_cache[layer_idx].numel():
+                device = self.conv_cache[layer_idx].device
+                self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""

src/transformers/models/lfm2/modular_lfm2.py

Lines changed: 17 additions & 23 deletions
@@ -123,6 +123,8 @@ def __init__(
             )
             torch._dynamo.mark_static_address(conv_state)
             self.conv_cache.append(conv_state)
+            self.key_cache.append(torch.tensor([]))
+            self.value_cache.append(torch.tensor([]))

     def update(
         self,
@@ -148,35 +150,27 @@ def update(
             A tuple containing the updated key and value states.
         """
         # Update the cache
-        if key_states is not None:
-            if len(self.key_cache) <= layer_idx:
-                # There may be skipped layers, fill them with empty lists
-                for _ in range(len(self.key_cache), layer_idx):
-                    self.key_cache.append(torch.tensor([]))
-                    self.value_cache.append(torch.tensor([]))
-                self.key_cache.append(key_states)
-                self.value_cache.append(value_states)
-            elif (
-                not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
-            ):  # fills previously skipped layers; checking for tensor causes errors
-                self.key_cache[layer_idx] = key_states
-                self.value_cache[layer_idx] = value_states
-            else:
-                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+        if self.key_cache[layer_idx].numel() == 0:
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
+        else:
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

         return self.key_cache[layer_idx], self.value_cache[layer_idx]

     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-            device = self.conv_cache[layer_idx].device
-            self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
+            if self.key_cache[layer_idx].numel():
+                device = self.key_cache[layer_idx].device
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+                device = self.value_cache[layer_idx].device
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+            if self.conv_cache[layer_idx].numel():
+                device = self.conv_cache[layer_idx].device
+                self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
