136 changes: 27 additions & 109 deletions in examples/models/llama/runner/static_attention_io_manager.h
@@ -19,16 +19,9 @@
namespace example {

enum class StaticAttentionUpdateStyle {
/**
* KV caches will have valid data at the end of the cache. New elements are
* added at the end and the start of the cache will slide forward to maintain
* this invariant. This potentially allows shorter caches to be passed into
* the model by adjusting the start pointer.
*/
SLIDING_CACHE,
/**
* I/O pointers do not change which can enable persistent memory mapping
* between AP and NPU.
* between AP and NPU. However, cache updates need to be copied.
*/
SMART_MASK,
};
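With SLIDING_CACHE removed, SMART_MASK is the only update style: the cache I/O pointers stay fixed (so they can stay persistently mapped between AP and NPU) and each step's new KV entries are copied into the cache right after the currently valid region. A minimal standalone sketch of that idea (hypothetical buffer layout, not code from this header):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch: the cache buffer address never changes; new rows are copied in
// right after the valid region instead of sliding the cache start pointer.
void append_to_cache(
    std::vector<float>& cache,        // cache_len * head_dim, fixed address
    const std::vector<float>& update, // update_len * head_dim
    size_t head_dim,
    size_t& valid_len) {
  std::copy(
      update.begin(), update.end(), cache.begin() + valid_len * head_dim);
  valid_len += update.size() / head_dim;
}
```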
@@ -49,35 +42,16 @@ class StaticKVCache {
size_t head_dim,
size_t max_input_len = 1,
size_t n_heads_per_cache = 1,
bool transpose = false,
StaticAttentionUpdateStyle style =
StaticAttentionUpdateStyle::SLIDING_CACHE)
StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
: n_caches_(n_caches),
cache_len_(cache_len),
max_input_len_(max_input_len),
n_heads_per_cache_(n_heads_per_cache),
head_dim_(head_dim),
transpose_(transpose),
style_(style),
input_ptrs_(n_caches_),
output_ptrs_(n_caches_) {
if (transpose_) {
throw std::runtime_error("Not implemented.");
}

if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
n_heads_per_cache_ > 1) {
throw std::runtime_error(
"Sliding cache update strategy doesn't support more than one head per cache tensor.");
}

if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
// Allocates one extra copy to accommodate caches sliding forward.
cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
} else {
cache_data_size_ =
n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
}
cache_data_size_ = n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
update_data_size_ =
n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;

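The constructor now sizes a single contiguous buffer of n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_ elements for the caches (no extra sliding copy), plus n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_ for the updates. A back-of-the-envelope sizing with hypothetical Llama-like dimensions (illustrative numbers only, not taken from this PR):

```cpp
#include <cstddef>

// Illustrative only: 32 caches, 8 KV heads packed per cache tensor,
// cache_len 2048, max_input_len 128, head_dim 128.
constexpr size_t n_caches = 32, n_heads_per_cache = 8;
constexpr size_t cache_len = 2048, max_input_len = 128, head_dim = 128;
constexpr size_t cache_elems =
    n_caches * n_heads_per_cache * cache_len * head_dim;      // 67,108,864
constexpr size_t update_elems =
    n_caches * n_heads_per_cache * max_input_len * head_dim;  // 4,194,304
// At 2 bytes per element (fp16) that is roughly 128 MiB of cache data and
// 8 MiB of update buffer.
```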
@@ -145,17 +119,10 @@
"Cache input tensor expected to have rank 3.");
}
auto ndim = inSizes.size();
if (transpose_) {
ET_CHECK_MSG(inSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
ET_CHECK_MSG(outSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
ET_CHECK_MSG(
inSizes[ndim - 1] == cache_len_, "Cache length dim mismatch.");
} else {
ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
ET_CHECK_MSG(
inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
}
ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
ET_CHECK_MSG(
inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");

auto impl = ::executorch::runtime::etensor::TensorImpl(
inMeta->scalar_type(),
@@ -189,11 +156,21 @@
throw std::runtime_error("Cache capacity exceeded.");
}

if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
update_sliding_cache(method, output_indices, update_len, update_pos);
} else {
update_smart_mask(method, output_indices, update_len, update_pos);
for (size_t i = 0; i < n_caches_; i++) {
const auto& updateTensor =
method.get_output(output_indices[i]).toTensor();
ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
for (size_t j = 0; j < n_heads_per_cache_; j++) {
auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
std::copy(
update_head + update_pos * head_dim_,
update_head + (update_pos + update_len) * head_dim_,
cache_head + valid_len_ * head_dim_);
}
}
valid_len_ += update_len;
}
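Each cache output is laid out head-major: output_ptrs_[i] holds n_heads_per_cache_ blocks of update_seq_len x head_dim_, and input_ptrs_[i] holds blocks of cache_len_ x head_dim_, so the loop copies rows [update_pos, update_pos + update_len) of every head to rows starting at valid_len_. A toy standalone rendering of that indexing (hypothetical sizes; plain arrays stand in for the ExecuTorch tensor pointers):

```cpp
#include <algorithm>
#include <cstddef>

// Toy rendering of the head-major copy performed in update().
void copy_update(
    float* cache,        // [n_heads][cache_len][head_dim]
    const float* update, // [n_heads][seq_len][head_dim]
    size_t n_heads, size_t cache_len, size_t seq_len, size_t head_dim,
    size_t update_pos, size_t update_len, size_t valid_len) {
  for (size_t j = 0; j < n_heads; ++j) {
    const float* update_head = update + seq_len * head_dim * j;
    float* cache_head = cache + cache_len * head_dim * j;
    std::copy(
        update_head + update_pos * head_dim,
        update_head + (update_pos + update_len) * head_dim,
        cache_head + valid_len * head_dim);
  }
}
```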

/**
@@ -202,9 +179,6 @@
*/
void reset() {
valid_len_ = 0;
if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
init_ptrs();
}
}

size_t size() {
@@ -223,53 +197,11 @@
}
}

void update_sliding_cache(
torch::executor::Method& method,
const std::vector<size_t>& output_indices,
size_t update_len,
size_t update_pos) {
ET_CHECK(n_caches_ == output_indices.size());
for (size_t i = 0; i < n_caches_; i++) {
const auto& updateTensor =
method.get_output(output_indices[i]).toTensor();
ET_CHECK(output_ptrs_[i] == updateTensor.const_data_ptr<T>());
std::copy(
output_ptrs_[i] + update_pos * head_dim_,
output_ptrs_[i] + (update_pos + update_len) * head_dim_,
input_ptrs_[i] + cache_len_ * head_dim_);
input_ptrs_[i] += update_len * head_dim_;
}
valid_len_ += update_len;
}

void update_smart_mask(
torch::executor::Method& method,
const std::vector<size_t>& output_indices,
size_t update_len,
size_t update_pos) {
for (size_t i = 0; i < n_caches_; i++) {
const auto& updateTensor =
method.get_output(output_indices[i]).toTensor();
ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
for (size_t j = 0; j < n_heads_per_cache_; j++) {
auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
std::copy(
update_head + update_pos * head_dim_,
update_head + (update_pos + update_len) * head_dim_,
cache_head + valid_len_ * head_dim_);
}
}
valid_len_ += update_len;
}

size_t n_caches_;
size_t cache_len_;
size_t max_input_len_;
size_t n_heads_per_cache_;
size_t head_dim_;
bool transpose_;
StaticAttentionUpdateStyle style_;
AllocatorT allocator_;
size_t cache_data_size_;
@@ -300,8 +232,7 @@ class StaticAttentionMask {
size_t head_dim,
T zero_val,
T mask_val,
StaticAttentionUpdateStyle style =
StaticAttentionUpdateStyle::SLIDING_CACHE)
StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
: cache_len_(cache_len),
input_len_(input_len),
head_dim_(head_dim),
@@ -341,20 +272,10 @@
* prefilling with padded inputs.
*/
void unmask(size_t update_len) {
if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
for (size_t i = 0; i < input_len_; i++) {
auto* p = data_ + (cache_len_ + input_len_) * i;
std::fill(
p + cache_len_ - cache_valid_len_ - update_len,
p + cache_len_ - cache_valid_len_,
zero_val_);
}
} else {
for (size_t i = 0; i < input_len_; i++) {
auto* p = data_ + (cache_len_ + input_len_) * i;
std::fill(
p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
}
for (size_t i = 0; i < input_len_; i++) {
auto* p = data_ + (cache_len_ + input_len_) * i;
std::fill(
p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
}
cache_valid_len_ += update_len;
}
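The mask buffer is input_len_ rows of cache_len_ + input_len_ columns; with SMART_MASK the valid cache positions grow from column 0, so unmask() writes zero_val_ into columns [cache_valid_len_, cache_valid_len_ + update_len) of every row. A minimal sketch of that behavior (hypothetical sizes and mask value):

```cpp
#include <algorithm>
#include <cstddef>

// Minimal sketch of unmask(): each of the input_len rows spans
// cache_len + input_len columns; newly valid cache columns get zero_val.
constexpr size_t cache_len = 6, input_len = 2;
constexpr float mask_val = -65504.0f, zero_val = 0.0f;  // illustrative values
float mask[input_len * (cache_len + input_len)];
size_t cache_valid_len = 0;

void reset_mask() {
  std::fill(mask, mask + input_len * (cache_len + input_len), mask_val);
  cache_valid_len = 0;
}

void unmask(size_t update_len) {
  for (size_t i = 0; i < input_len; ++i) {
    float* row = mask + (cache_len + input_len) * i;
    std::fill(
        row + cache_valid_len, row + cache_valid_len + update_len, zero_val);
  }
  cache_valid_len += update_len;
}
// After reset_mask(); unmask(3); the first 3 cache columns of every row are
// attendable while the remaining cache columns still hold mask_val.
```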
Expand Down Expand Up @@ -468,8 +389,7 @@ class StaticAttentionIOManager {
std::vector<size_t> v_cache_output_indices;
RopeT* rope_freqs_cos;
RopeT* rope_freqs_sin;
StaticAttentionUpdateStyle style =
StaticAttentionUpdateStyle::SLIDING_CACHE;
StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
};
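Because style now defaults to SMART_MASK, typical setup code no longer needs to mention the update style at all. A hedged sketch of filling the config, using only fields visible in this diff; the exact spelling of the config type, the manager's template parameters, the dimension values, and the RoPE table pointers below are assumptions for illustration, not taken from the PR:

```cpp
// Hypothetical setup, assuming a scope where the config type from this
// header is visible; all values are illustrative placeholders.
float rope_cos_table[2048 * 64] = {};  // placeholder RoPE tables (RopeT*)
float rope_sin_table[2048 * 64] = {};

StaticAttentionIOConfig config;
config.n_caches = 32;
config.cache_len = 2048;
config.head_dim = 128;
config.max_input_len = 128;
config.n_heads_per_cache = 8;
config.rope_freqs_cos = rope_cos_table;
config.rope_freqs_sin = rope_sin_table;
// config.style already defaults to StaticAttentionUpdateStyle::SMART_MASK.
```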

StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -480,15 +400,13 @@
config_.head_dim,
config_.max_input_len,
config_.n_heads_per_cache,
false,
config_.style),
vCaches_(
config_.n_caches,
config_.cache_len,
config_.head_dim,
config_.max_input_len,
config_.n_heads_per_cache,
false,
config_.style) {
ET_LOG(
Info,