Merged
Changes from 3 commits
12 changes: 10 additions & 2 deletions examples/models/llama/main.cpp
@@ -100,12 +100,20 @@ int32_t main(int32_t argc, char** argv) {
}

if (warmup) {
runner->warmup(prompt, /*max_new_tokens=*/seq_len);
auto error = runner->warmup(prompt, /*max_new_tokens=*/seq_len);
if (error != executorch::runtime::Error::Ok) {
ET_LOG(Error, "Failed to warmup llama runner");
return 1;
}
}
// generate
executorch::extension::llm::GenerationConfig config{
.seq_len = seq_len, .temperature = temperature};
runner->generate(prompt, config);
auto error = runner->generate(prompt, config);
if (error != executorch::runtime::Error::Ok) {
ET_LOG(Error, "Failed to warmup llama runner");
return 1;
}

return 0;
}
14 changes: 13 additions & 1 deletion examples/models/llama/source_transformation/custom_kv_cache.py
@@ -43,6 +43,7 @@ def __init__(
head_dim,
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
use_custom_update_cache_op: bool = False,
return_float_values: bool = True,
):
super().__init__()
if cache_type not in (
@@ -57,7 +58,7 @@ def __init__(
self.use_custom_update_cache_op = use_custom_update_cache_op
self.quantized_cache_dtype = torch.int8
self.cache_fp_type = torch.float32
self.return_float_values = True
self.return_float_values = return_float_values
self.max_context_length = max_context_length
cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_context_length, n_heads, 1)
@@ -400,6 +401,7 @@ def __init__(
head_dim,
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
use_custom_update_cache_op: bool = False,
return_float_values: bool = True,
):
# Look at attention.py for explanation on why max_context_length * 2
super().__init__(
@@ -409,6 +411,7 @@ def __init__(
head_dim,
cache_type,
use_custom_update_cache_op,
return_float_values,
)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True
@@ -459,6 +462,7 @@ def from_quantized_kv_cache(
head_dim,
kv_cache.cache_type,
kv_cache.use_custom_update_cache_op,
kv_cache.return_float_values,
)


@@ -583,4 +587,12 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
# it is not doing causal attention
if "SDPACustom" in attention.SDPA.__class__.__name__:
attention.SDPA.use_attention_mask = True
# QuantizedSDPA has to store kv_cache in order to obtain
# scales and zero points for the k and v caches.
# So if we replaced the attention module's quantized kv cache with
# QuantizedRingKVCache, then we also have to replace the attention's
# SDPA module's kv_cache so that it refers to the same kv_cache.
if "QuantizedSDPA" in attention.SDPA.__class__.__name__:
attention.SDPA.use_attention_mask = True
attention.SDPA.kv_cache = attention.kv_cache
return module
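
For context, here is a minimal, self-contained sketch of the invariant the new QuantizedSDPA branch enforces. The classes below are toy stand-ins, not the ExecuTorch API: the quantized SDPA path reads scales and zero points off the cache object it holds, so after swapping the attention module's cache for a ring-buffer cache, the SDPA wrapper must be re-pointed at that same object or it would dequantize with stale parameters.

class ToyQuantizedKVCache:
    """Stand-in for a quantized KV cache; carries its own quantization params."""
    def __init__(self, k_scale: float):
        self.k_scale = k_scale

class ToyQuantizedSDPA:
    """Stand-in for QuantizedSDPA: reads scales/zero points from self.kv_cache."""
    def __init__(self, kv_cache: ToyQuantizedKVCache):
        self.kv_cache = kv_cache

class ToyAttention:
    def __init__(self):
        self.kv_cache = ToyQuantizedKVCache(k_scale=1.0)
        self.SDPA = ToyQuantizedSDPA(self.kv_cache)

attn = ToyAttention()
attn.kv_cache = ToyQuantizedKVCache(k_scale=0.5)  # swap in the ring-buffer cache
attn.SDPA.kv_cache = attn.kv_cache                # re-point SDPA, as this PR does
assert attn.SDPA.kv_cache is attn.kv_cache        # SDPA now sees the current scales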
4 changes: 2 additions & 2 deletions extension/llm/custom_ops/op_sdpa.cpp
@@ -59,8 +59,8 @@ bool validate_flash_attention_args(

ET_CHECK_OR_RETURN_FALSE(
!attn_mask.has_value() ||
attn_mask.value().scalar_type() == query.scalar_type(),
"Attention mask must be a 2D tensor");
attn_mask.value().scalar_type() == ScalarType::Float,
"Attention mask must be a Float tensor");

ET_CHECK_OR_RETURN_FALSE(
is_contiguous_dim_order(query.dim_order().data(), query.dim()),