
Commit b3f3111

Qualcomm AI Engine Direct - sliding attention lookahead support (#14412)
### Summary

- add lookahead decode to speed up prompt calibration
- sliding attention lookahead support

### Test plan

python examples/qualcomm/oss_scripts/llama/llama.py --decoder_model gemma3-1b -b build-android/ -m SM8750 -s 5f396958 --prompt "Could you tell me about Facebook?" --max_seq_len 1024 --kv_updater smart_mask --prefill_ar_len 16 --model_mode lookahead --compile_only --tasks wikitext --limit 1 --model_mode lookahead --window 4 --ngram 3 --gcap 4
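For reference on the lookahead flags used above: the position-offset table added in `lhd_token_generator.h` (see the diff below) lays out `(ngram - 1) * window` lookahead-branch slots followed by `gcap * (ngram - 1)` verification-branch slots, which together appear to define the lookahead decode step size. A minimal sketch of that arithmetic, assuming only this layout:

```python
# Hedged sketch: number of token slots a lookahead decode step processes,
# derived from the --window / --ngram / --gcap flags in the test plan.
# The layout follows the position_offset_ initialization added in
# lhd_token_generator.h (lookahead branches + verification branches).
def lookahead_ar_len(window: int, ngram: int, gcap: int) -> int:
    lookahead_slots = (ngram - 1) * window    # lookahead (Jacobi) branches
    verification_slots = gcap * (ngram - 1)   # candidate n-gram verification branches
    return lookahead_slots + verification_slots

print(lookahead_ar_len(window=4, ngram=3, gcap=4))  # 16
```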
1 parent 16ced4e commit b3f3111

File tree: 8 files changed (+430, -107 lines)

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 325 additions & 58 deletions
Large diffs are not rendered by default.

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 9 additions & 7 deletions
@@ -223,6 +223,7 @@ def quantize(
 custom_annotations=(),
 scales_state_dict=None,
 chat_template=None,
+lookahead_config=None,
 ):
 self.quant_dtype = quant_dtype
 quantizer = make_custom_quantizer(
@@ -290,6 +291,7 @@ def quantize(
 prompt=prompt,
 use_i64_token=args.embedding_quantize is not None,
 event_name="prepare_pt2e_prompt",
+lookahead_config=lookahead_config,
 )
 if scales_state_dict:
 set_scales(
@@ -336,6 +338,7 @@ def quantize(
 prompt=prompt,
 use_i64_token=args.embedding_quantize is not None,
 event_name="convert_pt2e_prompt",
+lookahead_config=lookahead_config,
 )

 def save_logits_quant_attrs(self):
@@ -497,13 +500,6 @@ def compile(
 )
 )
 elif args.model_mode == "lookahead":
-# TODO: Lookahead decoding is not yet supported for gemma3-1b.
-# This will be implemented once the model architecture and KV update logic are adapted.
-if args.decoder_model == "gemma3-1b":
-raise NotImplementedError(
-"gemma3-1b does not currently support lookahead decoding."
-)
-
 llama_instance_list.append(
 LLM_VARIANT_ARCHS.get(args.decoder_model, LlamaModel)(
 kv_config,
@@ -697,13 +693,19 @@ def permute(w, heads):
 custom_annotations = decoder_model_config.custom_annotation
 kv_quant_attrs = {}
 for i, llama_instance in enumerate(llama_instance_list):
+lookahead_config = (
+(args.window, args.ngram, args.gcap)
+if i == 0 and args.model_mode == "lookahead"
+else None
+)
 llama_instance.quantize(
 quant_dtype=quant_dtype,
 args=args,
 tokenizer=tokenizer,
 custom_annotations=custom_annotations,
 scales_state_dict=scales_state_dict,
 chat_template=chat_template,
+lookahead_config=lookahead_config,
 )
 # If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
 if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
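A short sketch of the selection logic introduced above: only the first instance in the list (the KV/decode graph, per the quant_attrs comment in the diff) receives a `(window, ngram, gcap)` tuple in lookahead mode, and the tuple is forwarded to the calibration helpers in `decoder_utils.py`, whose diff is not rendered here, so how it is consumed there is not shown. Values are illustrative:

```python
# Hedged sketch mirroring the lookahead_config selection added in llama.py:
# instance 0 (the KV/decode graph) gets the lookahead shape, others get None.
from argparse import Namespace

args = Namespace(model_mode="lookahead", window=4, ngram=3, gcap=4)

for i in range(2):  # illustrative: decode graph first, prefill graph second
    lookahead_config = (
        (args.window, args.ngram, args.gcap)
        if i == 0 and args.model_mode == "lookahead"
        else None
    )
    print(i, lookahead_config)
# 0 (4, 3, 4)
# 1 None
```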

examples/qualcomm/oss_scripts/llama/masking_utils.py

Lines changed: 27 additions & 20 deletions
@@ -93,24 +93,26 @@ def mask(self) -> torch.Tensor:
 pass

 @abstractmethod
-def smart_mask_update(self, pos, n_updates):
+def smart_mask_update(self, pos, n_updates, lade_pos_offset):
 """
 Update the attention mask by smart mask update method after model forward.

 Args:
 pos (int): Current position in the sequence.
 n_updates (int): Number of new tokens to update.
+lade_pos_offset (List[int]): Position offset of lookahead attention mask.
 """
 pass

 @abstractmethod
-def shift_pointer_update(self, pos, n_updates):
+def shift_pointer_update(self, pos, n_updates, lade_pos_offset):
 """
 Update the attention mask by shift pointer update method after model forward.

 Args:
 pos (int): Current position in the sequence.
 n_updates (int): Number of tokens to shift.
+lade_pos_offset (List[int]): Position offset of lookahead attention mask.
 """
 pass

@@ -124,7 +126,7 @@ def __init__(self, max_batch_size: int, ar_len: int, max_seq_len: int):
 def mask(self):
 return self._mask

-def smart_mask_update(self, pos, n_updates):
+def smart_mask_update(self, pos, n_updates, _):
 """
 Smart Mask mechanism for attention mask updating

@@ -159,7 +161,7 @@ def smart_mask_update(self, pos, n_updates):
 end_pos = pos + n_updates
 self.mask[:, :, start_pos:end_pos] = 0

-def shift_pointer_update(self, pos, n_updates):
+def shift_pointer_update(self, pos, n_updates, _):
 """
 Shift Pointer mechanism for attention mask updating

@@ -173,7 +175,7 @@ def shift_pointer_update(self, pos, n_updates):
 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ○
 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ●

-After 1st update (e.g., pos=0, n_updates=5, sliding_window=3):
+After 1st update (e.g., pos=0, n_updates=5):
 Newly added tokens are unmasked (set to 0).

 0 ○ ○ ○ ○ ○ ● ● ● ● ● ● ○ ○ ○ ○
@@ -213,7 +215,7 @@ def __init__(
 def mask(self):
 return self._mask

-def smart_mask_update(self, pos, n_updates):
+def smart_mask_update(self, pos, n_updates, lade_pos_offset):
 """
 Smart Mask mechanism for attention mask updating

@@ -237,7 +239,8 @@ def smart_mask_update(self, pos, n_updates):
 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○
 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ●

-After 2nd update (e.g., pos=5, n_updates=5):
+
+After 2nd update (e.g., pos=5, n_updates=5, sliding_window=3):
 Sliding window shifts again, masking older positions and activate new postion.

 0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○
@@ -252,16 +255,18 @@ def smart_mask_update(self, pos, n_updates):
 self.mask[:, :, start_pos:end_pos] = 0

 for i in range(self.ar_len):
-# Calculate how many cached tokens are still avalible for this row
-avalible_cache_len = self.sliding_window - (i + 1)
+# Calculate how many cached tokens are still available for this row
+available_cache_len = self.sliding_window - (
+(i + 1) if lade_pos_offset is None else (lade_pos_offset[i] + 1)
+)

 # If the current position exceeds available cache, mask the overflow
-if end_pos > avalible_cache_len:
+if end_pos > available_cache_len:
 # Mask tokens that are no longer within the sliding window
 # TODO: [Optional]: it can be optimized by computing the exact start index
-self.mask[:, i, : end_pos - avalible_cache_len] = -255.0
+self.mask[:, i, : end_pos - available_cache_len] = -255.0

-def shift_pointer_update(self, pos, n_updates):
+def shift_pointer_update(self, pos, n_updates, lade_pos_offset):
 """
 Shift Pointer mechanism for attention mask updating

@@ -283,7 +288,7 @@ def shift_pointer_update(self, pos, n_updates):
 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○
 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ●

-After 2nd update (e.g., pos=5, n_updates=5):
+After 2nd update (e.g., pos=5, n_updates=5, sliding_window=3):

 0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○
 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○
@@ -297,28 +302,30 @@ def shift_pointer_update(self, pos, n_updates):
 self.mask[:, :, start_pos:end_pos] = 0

 for i in range(self.ar_len):
-avalible_cache_len = self.sliding_window - (i + 1)
-if abs(start_pos + self.ar_len) > avalible_cache_len:
+available_cache_len = self.sliding_window - (
+(i + 1) if lade_pos_offset is None else (lade_pos_offset[i] + 1)
+)
+if abs(start_pos + self.ar_len) > available_cache_len:
 self.mask[
 :,
 i,
 start_pos : start_pos
 + abs(start_pos + self.ar_len)
-- avalible_cache_len,
+- available_cache_len,
 ] = -255.0


 class AttentionMask:
 def __init__(self, masks: Union[BaseAttentionMask, List[BaseAttentionMask]]):
 self.masks = masks if isinstance(masks, list) else [masks]

-def smart_mask_update(self, pos, n_updates):
+def smart_mask_update(self, pos, n_updates, lade_pos_offset=None):
 for mask in self.masks:
-mask.smart_mask_update(pos, n_updates)
+mask.smart_mask_update(pos, n_updates, lade_pos_offset)

-def shift_pointer_update(self, pos, n_updates):
+def shift_pointer_update(self, pos, n_updates, lade_pos_offset=None):
 for mask in self.masks:
-mask.shift_pointer_update(pos, n_updates)
+mask.shift_pointer_update(pos, n_updates, lade_pos_offset)

 def __iter__(self):
 return iter([mask.mask for mask in self.masks])
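The core of the masking change is the per-row cache budget: without `lade_pos_offset`, row `i` of the ar_len block is assumed to sit `i` positions ahead of the newest committed token; in lookahead mode the row's real distance comes from the offset table instead. A standalone sketch of that rule, with illustrative values:

```python
# Hedged sketch of the available_cache_len rule added to the sliding-window
# mask updates: lade_pos_offset replaces the implicit per-row offset (i + 1).
def available_cache_len(i, sliding_window, lade_pos_offset=None):
    offset = (i + 1) if lade_pos_offset is None else (lade_pos_offset[i] + 1)
    return sliding_window - offset

sliding_window = 8
lade_pos_offset = [0, 1, 2, 3, 1, 2, 1, 2]  # illustrative lookahead layout
for i in range(len(lade_pos_offset)):
    print(
        i,
        available_cache_len(i, sliding_window),                   # plain autoregressive rows
        available_cache_len(i, sliding_window, lade_pos_offset),  # lookahead rows
    )
```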

examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp

Lines changed: 17 additions & 12 deletions
@@ -122,7 +122,8 @@ void KVManager<T>::init_attention_mask(
 const std::vector<int32_t>& attention_map,
 int32_t ar_len,
 int32_t n_past,
-int32_t sliding_window) {
+int32_t sliding_window,
+const std::vector<int32_t>& position_offset) {
 ET_CHECK_MSG(
 attention_map.size() <= ar_len,
 "The size of attention_map (%zu) doesn't match with ar_len (%d)",
@@ -154,11 +155,12 @@ void KVManager<T>::init_attention_mask(
 }
 // Attend to itself
 new_ptr[i] = pos_val;
-
 // mask by limitation of sliding_window
-int32_t avalible_context_len = sliding_window - (i + 1) - n_past;
-if (n_past > avalible_context_len) {
-std::fill_n(past_ptr, n_past - avalible_context_len, neg_val);
+int32_t available_context_len = position_offset.empty()
+? sliding_window - (i + 1) - n_past
+: sliding_window - (position_offset[i] + 1) - n_past;
+if (n_past > available_context_len) {
+std::fill_n(past_ptr, n_past - available_context_len, neg_val);
 }

 past_ptr += metadata_.context_len;
@@ -219,7 +221,8 @@ void KVManager<T>::update_attention_mask(
 int32_t ar_len,
 int32_t n_past,
 int32_t n_update,
-int32_t sliding_window) {
+int32_t sliding_window,
+const std::vector<int32_t>& position_offset) {
 uint16_t pos_val = 65535;
 uint16_t neg_val = 0;
 uint16_t* cur_ptr = attention_mask;
@@ -230,17 +233,19 @@ void KVManager<T>::update_attention_mask(

 for (int i = 0; i < ar_len; i++) {
 std::fill_n(cur_ptr, n_update, pos_val);
-int32_t avalible_cache_len = sliding_window - (i + 1);
+int32_t available_cache_len = position_offset.empty()
+? sliding_window - (i + 1)
+: sliding_window - (position_offset[i] + 1);
 if (kv_updater_ == KVManagerMode::SMART_MASK) {
-if (n_past + n_update > avalible_cache_len) {
+if (n_past + n_update > available_cache_len) {
 std::fill_n(
-cur_ptr - n_past, n_past + n_update - avalible_cache_len, neg_val);
+cur_ptr - n_past, n_past + n_update - available_cache_len, neg_val);
 }
 } else if (kv_updater_ == KVManagerMode::SHIFT_POINTER) {
-if (std::abs(n_past + ar_len) > avalible_cache_len) {
-int32_t n_invalid = n_past - avalible_cache_len;
+if (std::abs(n_past + ar_len) > available_cache_len) {
+int32_t n_invalid = n_past - available_cache_len;
 std::fill_n(
-cur_ptr, std::abs(n_past + ar_len) - avalible_cache_len, neg_val);
+cur_ptr, std::abs(n_past + ar_len) - available_cache_len, neg_val);
 }
 }
 cur_ptr += metadata_.context_len;
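The C++ runner applies the same rule when the mask is first built: with a `position_offset` table (lookahead mode), the per-row term `(i + 1)` is replaced by `(position_offset[i] + 1)` when deciding how many of the oldest cached tokens fall outside the sliding window. A rough Python rendering of that branch of `init_attention_mask`, with illustrative numbers:

```python
# Hedged Python rendering of the sliding-window branch in
# KVManager::init_attention_mask: returns how many of the oldest past entries
# row i must mask out, given n_past cached tokens and a sliding window.
def masked_past_len(i, n_past, sliding_window, position_offset=()):
    offset = (i + 1) if not position_offset else (position_offset[i] + 1)
    available_context_len = sliding_window - offset - n_past
    return max(0, n_past - available_context_len)

print(masked_past_len(i=2, n_past=10, sliding_window=16))            # 7
print(masked_past_len(i=2, n_past=10, sliding_window=16,
                      position_offset=[0, 1, 1, 2]))                 # 6
```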

examples/qualcomm/oss_scripts/llama/runner/kv_manager.h

Lines changed: 9 additions & 2 deletions
@@ -95,13 +95,17 @@ class KVManager {
 * of attention map should be [ar_len].
 * @param ar_len Length of input tokens.
 * @param n_past Number of past elements in the cache.
+* @param sliding_window Length of sliding window for sliding window attention
+* mask
+* @param position_offset (optional) attention mask position offset of
 */
 void init_attention_mask(
 uint16_t* attention_mask,
 const std::vector<int32_t>& attention_map,
 int32_t ar_len,
 int32_t n_past,
-int32_t sliding_window);
+int32_t sliding_window,
+const std::vector<int32_t>& position_offset = {});

 /**
 * @brief Update attention mask based on kv manager mode, and n_update.
@@ -126,13 +130,16 @@ class KVManager {
 * @param n_update Number of elements to be updated.
 * @param sliding_window Length of sliding window for sliding window attention
 * mask
+* @param position_offset (optional) attention mask position offset of
+* lookahead decoder
 */
 void update_attention_mask(
 uint16_t* attention_mask,
 int32_t ar_len,
 int32_t n_past,
 int32_t n_update,
-int32_t sliding_window);
+int32_t sliding_window,
+const std::vector<int32_t>& position_offset = {});

 /**
 * @brief Reset the data pointer of the I/O cache tensor based on number of

examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp

Lines changed: 19 additions & 0 deletions
@@ -61,6 +61,16 @@ void LhdTokenGenerator<T>::init_attention_mask(int32_t n_past) {

 this->kv_manager_->init_attention_mask(
 this->attention_mask_.data, attention_map, metadata_.ar_len, n_past);
+// Initialize window attention mask with current position
+if (metadata_.cache_mode == CacheMode::HybridCache) {
+this->kv_manager_->init_attention_mask(
+this->window_attention_mask_.data,
+attention_map,
+metadata_.ar_len,
+n_past,
+metadata_.sliding_window,
+position_offset_);
+}
 }

 template <typename T>
@@ -378,6 +388,15 @@ Result<int64_t> LhdTokenGenerator<T>::generate(
 // Update attention mask with current position
 this->kv_manager_->update_attention_mask(
 this->attention_mask_.data, metadata_.ar_len, prev_pos, n_update);
+if (metadata_.cache_mode == CacheMode::HybridCache) {
+this->kv_manager_->update_attention_mask(
+this->window_attention_mask_.data,
+metadata_.ar_len,
+prev_pos,
+n_update,
+metadata_.sliding_window,
+position_offset_);
+}

 // data-dependent terminating condition: we have n_eos_ number of EOS
 if (this->eos_ids_->count(cur_token) > 0) {

examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h

Lines changed: 22 additions & 1 deletion
@@ -29,6 +29,7 @@ class LhdTokenGenerator : public TokenGenerator<T> {
 int32_t window;
 int32_t gcap;
 int sliding_window;
+CacheMode cache_mode;
 };
 LhdTokenGenerator(
 tokenizers::Tokenizer* tokenizer,
@@ -51,7 +52,8 @@ class LhdTokenGenerator : public TokenGenerator<T> {
 metadata.ar_len,
 metadata.vocab_size,
 metadata.use_int64_token,
-metadata.sliding_window},
+metadata.sliding_window,
+metadata.cache_mode},
 stats),
 metadata_(metadata),
 lhd_branch_(metadata.ngram - 1, std::vector<int32_t>(metadata.window)),
@@ -63,6 +65,22 @@ class LhdTokenGenerator : public TokenGenerator<T> {
 metadata.ngram,
 metadata.window,
 metadata.gcap);
+
+// initialize position offset
+position_offset_ = std::vector<int32_t>(metadata.ar_len);
+int idx = 0;
+// lookahead branches
+for (int i = 0; i < metadata.ngram - 1; ++i) {
+for (int j = 0; j < metadata.window; ++j) {
+position_offset_[idx++] = i + j;
+}
+}
+// verification branches
+for (int i = 0; i < metadata.gcap; ++i) {
+for (int j = 1; j < metadata.ngram; ++j) {
+position_offset_[idx++] = j;
+}
+}
 }

 ~LhdTokenGenerator() = default;
@@ -136,6 +154,9 @@ class LhdTokenGenerator : public TokenGenerator<T> {
 // verification branch
 std::vector<NgramData> v_branch_;

+// position offset in attention mask
+std::vector<int32_t> position_offset_;
+
 // n-gram pools
 NgramContainer ngrams_pool_;
 };
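The offset table makes explicit how far each slot of a lookahead step sits ahead of the committed position, so the window mask can age cache entries per row instead of assuming row `i` is `i` steps ahead. A quick reproduction of the constructor loops above, using the test-plan flags:

```python
# Hedged reproduction of the position_offset_ initialization added to
# LhdTokenGenerator: lookahead branches first, then gcap verification n-grams.
def build_position_offset(window: int, ngram: int, gcap: int) -> list:
    offsets = []
    # lookahead branches: (ngram - 1) rows of `window` slots each
    for i in range(ngram - 1):
        for j in range(window):
            offsets.append(i + j)
    # verification branches: gcap candidate n-grams of (ngram - 1) slots each
    for _ in range(gcap):
        for j in range(1, ngram):
            offsets.append(j)
    return offsets

offsets = build_position_offset(window=4, ngram=3, gcap=4)  # --window 4 --ngram 3 --gcap 4
print(len(offsets))  # 16 slots per lookahead step
print(offsets)       # [0, 1, 2, 3, 1, 2, 3, 4, 1, 2, 1, 2, 1, 2, 1, 2]
```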
