Commit 8de90bf

Merge branch 'main' into addScript
2 parents 3c2cbd2 + 137163f commit 8de90bf

File tree

8 files changed (+496, -81 lines)

examples/models/llama/static_attention.py

Lines changed: 178 additions & 10 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -47,29 +47,39 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str:
         return f"l{layer_id},h{head_id}"

     @staticmethod
-    def apply_update(cache, update, pos, style, transpose=False):
+    def apply_update(
+        cache, update, pos, style, transpose=False, update_pos=0, update_len=None
+    ):
         """
         After inference, update the cache state for next iteration. The runtime needs to
         implement the same operation.
         """
         if style == "shift_pointer":
             if transpose:
-                update_len = update.size(-1)
+                update_len = update_len or update.size(-1)
                 updated = torch.roll(cache, -update_len, -1)
-                updated[:, :, -update_len:] = update
+                updated[:, :, -update_len:] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
+                update_len = update_len or update.size(-2)
                 updated = torch.roll(cache, -update_len, -2)
-                updated[:, -update_len:, :] = update
+                updated[:, -update_len:, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]

         if style == "smart_mask":
             updated = torch.clone(cache)
             if transpose:
-                update_len = update.size(-1)
-                updated[:, :, pos : pos + update_len] = update
+                update_len = update_len or update.size(-1)
+                updated[:, :, pos : pos + update_len] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
-                updated[:, pos : pos + update_len, :] = update
+                update_len = update_len or update.size(-2)
+                updated[:, pos : pos + update_len, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]

         return updated
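
For reference, a minimal sketch (not part of this commit) of how the new update_pos / update_len arguments behave: they select a slice of the incoming update before it is written into the cache. The tensor shapes are made up for illustration; the import path matches the test file below.

import torch

from executorch.examples.models.llama.static_attention import StaticKVCache

cache = torch.zeros(1, 4, 2)                 # (batch, cache_len, head_dim)
update = torch.arange(4.0).reshape(1, 2, 2)  # two new positions of K or V

# "smart_mask": copy update[:, 0:1, :] into cache rows [pos, pos + update_len) on a clone.
out = StaticKVCache.apply_update(
    cache, update, pos=2, style="smart_mask", update_pos=0, update_len=1
)
# out[:, 2, :] now equals update[:, 0, :]; every other row is untouched.

# "shift_pointer": roll the cache left by update_len, then write the slice at the end.
out = StaticKVCache.apply_update(
    cache, update, pos=2, style="shift_pointer", update_pos=1, update_len=1
)
# out[:, -1, :] now equals update[:, 1, :].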

@@ -163,6 +173,164 @@ def unmask(self, new_unmasked_len):
         self.unmasked_len += new_unmasked_len


+class StaticAttentionIOManager:
+    def __init__(
+        self,
+        config: ModelArgs,
+        input_len: int,
+        cache_len: int,
+        style: str = "shift_pointer",
+        mask_val: float = float("-inf"),
+    ):
+        self.mask = StaticAttentionMask(
+            input_len, cache_len, style=style, mask_val=mask_val
+        )
+
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_seq_len)
+        self.freqs_cos = freqs[0]
+        self.freqs_sin = freqs[1]
+
+        self.k_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+        self.v_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+
+        self.config = config
+        self.input_len = input_len
+        self.cache_len = cache_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def reset(self):
+        self.pos = 0
+        self.cache_full = False
+        self.mask.reset()
+
+    def prefill(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+    ) -> torch.Tensor:
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        logits = None
+        all_logits = None
+        for i in range(0, len(tokens), self.input_len):
+            logits = self._run_once(model, tokens[i : i + self.input_len])[0]
+            if self.config.generate_full_logits:
+                if all_logits is None:
+                    all_logits = logits
+                else:
+                    all_logits = torch.cat([all_logits, logits], dim=1)
+
+        if self.config.generate_full_logits:
+            return all_logits[:, : len(tokens), :]
+
+        return logits
+
+    def decode(
+        self,
+        model: Callable[..., Any],
+        init_token: int,
+        n: int,
+        stop_tokens: Optional[List[int]] = None,
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        stop_tokens = stop_tokens or []
+        new_tokens = [init_token]
+        for _ in range(n):
+            y = self._run_once(model, new_tokens[-1:])[0]
+            new_tokens.append(y[:, :1, :].argmax().item())
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        return new_tokens
+
+    def _run_once(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+        non_padded_len: Optional[int] = None,
+        freqs_cos_override: Optional[torch.Tensor] = None,
+        freqs_sin_override: Optional[torch.Tensor] = None,
+    ):
+        n_tokens = len(tokens)
+        if n_tokens < self.input_len:
+            tokens += [0] * (self.input_len - n_tokens)
+        tokens = torch.tensor([tokens], dtype=torch.int32)
+        if freqs_cos_override is None:
+            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
+        if freqs_sin_override is None:
+            freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        y, attn_updates = model(
+            tokens,
+            {
+                "mask": self.mask.tensor,
+                "freqs_cos_override": freqs_cos_override,
+                "freqs_sin_override": freqs_sin_override,
+                "in_cache_state": (self.k_caches, self.v_caches),
+            },
+        )
+        non_padded_len = non_padded_len or n_tokens
+        if self.pos + non_padded_len <= self.cache_len:
+            self._update_states(attn_updates, 0, non_padded_len)
+        else:
+            self.cache_full = True
+
+        return y, attn_updates
+
+    def _update_states(self, attn_updates, update_pos, update_len):
+        assert self.pos + update_len <= self.cache_len
+
+        self.mask.unmask(update_len)
+        k_cache_updates, v_cache_updates = attn_updates["out_cache_state"]
+        for cache_id, update in k_cache_updates.items():
+            self.k_caches[cache_id] = StaticKVCache.apply_update(
+                self.k_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        for cache_id, update in v_cache_updates.items():
+            self.v_caches[cache_id] = StaticKVCache.apply_update(
+                self.v_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        self.pos += update_len
+
+
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
         super().__init__()
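
Taken together, the new StaticAttentionIOManager owns the attention mask, RoPE frequencies, and per-head K/V caches, so a caller only feeds prompt chunks through prefill() and then steps decode() one token at a time. A hedged usage sketch, not part of the commit — the model (static_transformer), the config field values, and the token ids below are placeholders:

from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.static_attention import StaticAttentionIOManager

# Assumed config values for illustration; any config exported for static attention works.
config = ModelArgs(
    dim=64, n_layers=4, n_heads=8, n_kv_heads=8, head_dim=8,
    max_seq_len=128, generate_full_logits=True,
)
mgr = StaticAttentionIOManager(config, input_len=32, cache_len=96, style="smart_mask")

prompt_tokens = [1, 2, 3, 4]  # placeholder token ids
# static_transformer: placeholder for a model with the static-attention I/O signature
# (e.g. built via construct_transformer, as in the test below).
logits = mgr.prefill(static_transformer, prompt_tokens)
next_token = logits[:, len(prompt_tokens) - 1, :].argmax().item()  # valid with generate_full_logits
generated = mgr.decode(static_transformer, next_token, n=64, stop_tokens=[2])
mgr.reset()  # rewind the position, reset the mask, and clear the cache-full flag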

examples/models/llama/tests/test_static_attention.py

Lines changed: 6 additions & 46 deletions
@@ -1,12 +1,13 @@
 import unittest

 import torch
-from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
+from executorch.examples.models.llama.attention import AttentionMHA
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
     StaticAttention,
+    StaticAttentionIOManager,
     StaticAttentionMask,
     StaticKVCache,
 )
@@ -171,62 +172,21 @@ def test_within_transformer(self):
         static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)

         x = torch.randint(config.vocab_size, (1, config.max_seq_len))
-        rope = Rope(config)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
         expected = mha_transformer(x)

         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len

         def test_with_style(style):
-            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
-            mask.tensor[:, :, cache_len:] = torch.triu(
-                torch.full((1, chunk_len, chunk_len), float("-inf")),
-                diagonal=1,
-            )
-            k_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
-            v_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
+            mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
             ys = []
             for i in range(n_chunks):
-                y_i, attn_update = static_transformer(
-                    x[:, i * chunk_len : (i + 1) * chunk_len],
-                    attn_options=ForwardOptions(
-                        mask=mask.tensor,
-                        freqs_cos_override=freqs_cos[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        freqs_sin_override=freqs_sin[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        in_cache_state=(k_caches, v_caches),
-                        out_cache_state=({}, {}),
-                    ),
+                y_i = mgr.prefill(
+                    static_transformer,
+                    x[0][i * chunk_len : (i + 1) * chunk_len].tolist(),
                 )
                 ys.append(y_i)
-                mask.unmask(chunk_len)
-                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-                if i < n_chunks - 1:
-                    for cache_id, update in k_cache_updates.items():
-                        k_caches[cache_id] = StaticKVCache.apply_update(
-                            k_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )
-                    for cache_id, update in v_cache_updates.items():
-                        v_caches[cache_id] = StaticKVCache.apply_update(
-                            v_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )

         self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())

examples/selective_build/CMakeLists.txt

Lines changed: 16 additions & 4 deletions
@@ -33,7 +33,7 @@ if(NOT CMAKE_CXX_STANDARD)
   # Can't set to 11 due to executor_runner.cpp make_unique
 endif()

-set(_common_compile_options -Wno-deprecated-declarations -fPIC)
+set(_common_compile_options -Wno-deprecated-declarations -fPIC -ffunction-sections -fdata-sections)

 # Let files say "include <executorch/path/to/header.h>".
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)
@@ -123,13 +123,25 @@ gen_selected_ops(
 )

 generate_bindings_for_kernels(
-  LIB_NAME "select_build_lib" FUNCTIONS_YAML
-  ${EXECUTORCH_ROOT}/kernels/portable/functions.yaml CUSTOM_OPS_YAML
+  LIB_NAME
+  "select_build_lib"
+  FUNCTIONS_YAML
+  ${EXECUTORCH_ROOT}/kernels/portable/functions.yaml
+  CUSTOM_OPS_YAML
   "${_custom_ops_yaml}"
+  DTYPE_SELECTIVE_BUILD
+  "${EXECUTORCH_DTYPE_SELECTIVE_BUILD}"
 )

 gen_operators_lib(
-  LIB_NAME "select_build_lib" KERNEL_LIBS ${_kernel_lib} DEPS executorch_core
+  LIB_NAME
+  "select_build_lib"
+  KERNEL_LIBS
+  ${_kernel_lib}
+  DEPS
+  executorch_core
+  DTYPE_SELECTIVE_BUILD
+  "${EXECUTORCH_DTYPE_SELECTIVE_BUILD}"
 )

 list(TRANSFORM _executor_runner__srcs PREPEND "${EXECUTORCH_ROOT}/")

examples/selective_build/test_selective_build.sh

Lines changed: 11 additions & 7 deletions
@@ -162,13 +162,17 @@ test_cmake_select_ops_in_yaml() {
 }

 test_cmake_select_ops_in_model() {
-    echo "Exporting MobilenetV2"
-    ${PYTHON_EXECUTABLE} -m examples.portable.scripts.export --model_name="mv2"
+    local model_name="add_mul"
+    local model_export_name="${model_name}.pte"
+    echo "Exporting ${model_name}"
+    ${PYTHON_EXECUTABLE} -m examples.portable.scripts.export --model_name="${model_name}"
     local example_dir=examples/selective_build
     local build_dir=cmake-out/${example_dir}
     rm -rf ${build_dir}
-    retry cmake -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE \
-        -DEXECUTORCH_SELECT_OPS_FROM_MODEL="./mv2.pte" \
+    retry cmake -DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
+        -DEXECUTORCH_SELECT_OPS_FROM_MODEL="./${model_export_name}" \
+        -DEXECUTORCH_DTYPE_SELECTIVE_BUILD=ON \
+        -DEXECUTORCH_OPTIMIZE_SIZE=ON \
         -DCMAKE_INSTALL_PREFIX=cmake-out \
         -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
         -B${build_dir} \
@@ -178,10 +182,10 @@ test_cmake_select_ops_in_model() {
     cmake --build ${build_dir} -j9 --config $CMAKE_BUILD_TYPE

     echo 'Running selective build test'
-    ${build_dir}/selective_build_test --model_path="./mv2.pte"
+    ${build_dir}/selective_build_test --model_path="./${model_export_name}"

-    echo "Removing mv2.pte"
-    rm "./mv2.pte"
+    echo "Removing ${model_export_name}"
+    rm "./${model_export_name}"
 }

 if [[ -z $BUCK ]];

extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift

Lines changed: 7 additions & 0 deletions
@@ -1095,3 +1095,10 @@ public extension Tensor {
   ))
 }
 }
+
+@available(*, deprecated, message: "This API is experimental.")
+extension Tensor: CustomStringConvertible {
+  public var description: String {
+    self.anyTensor.description
+  }
+}
