Skip to content

Commit 370ed78

Browse files
Prefix caching for sequences with embeddings. (#1841)
Ticket: CVS-163347 --------- Co-authored-by: Ilya Lavrenov <[email protected]>
1 parent 85263fd commit 370ed78

File tree

6 files changed

+257
-54
lines changed

6 files changed

+257
-54
lines changed

src/cpp/src/block_manager.hpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,7 @@ class BlockManager {
10731073
// When add_request() is executed in multiple threads accessing to cached_blocks causes segfault.
10741074
// The mutex is needed to prevent such segfaults.
10751075
const std::lock_guard<std::mutex> lock(m_cached_blocks_map_mutex);
1076-
auto prompt_ids = group->get_prompt_ids();
1076+
auto prompt_len = group->get_prompt_len();
10771077
auto sequences = group->get_not_finished_sequences();
10781078
OPENVINO_ASSERT(sequences.size() == 1);
10791079
auto sequence = sequences[0];
@@ -1085,11 +1085,11 @@ class BlockManager {
10851085
auto& block_table = m_block_table[seq_id];
10861086

10871087
size_t content_len = 0;
1088-
while (content_len < prompt_ids.size()) {
1088+
while (content_len < prompt_len) {
10891089
size_t prev_iteration_content_len = content_len;
10901090
content_len += m_block_size;
1091-
if (content_len > prompt_ids.size()) {
1092-
content_len = prompt_ids.size();
1091+
if (content_len > prompt_len) {
1092+
content_len = prompt_len;
10931093
}
10941094
// restore fully filled blocks
10951095
auto full_block_hash = sequence->get_hash(content_len);
@@ -1101,11 +1101,11 @@ class BlockManager {
11011101
block->set_timestamp(timestamp);
11021102
block_table[layer_idx].push_back(block);
11031103
}
1104-
group->update_processed_tokens_num(content_len == prompt_ids.size() ? content_len - 1 : content_len);
1104+
group->update_processed_tokens_num(content_len == prompt_len ? content_len - 1 : content_len);
11051105
} else {
11061106
// restore partially filled block
11071107
for (size_t i = 1; i < m_block_size; i++) {
1108-
if (prev_iteration_content_len + i > prompt_ids.size()) {
1108+
if (prev_iteration_content_len + i > prompt_len) {
11091109
break;
11101110
}
11111111
auto hash = sequence->get_hash(prev_iteration_content_len + i);
@@ -1118,8 +1118,7 @@ class BlockManager {
11181118
block->set_timestamp(timestamp);
11191119
block_table[layer_idx].push_back(block);
11201120
}
1121-
1122-
group->update_processed_tokens_num(prev_iteration_content_len + i == prompt_ids.size() ? prev_iteration_content_len + i - 1 : prev_iteration_content_len + i);
1121+
group->update_processed_tokens_num(prev_iteration_content_len + i == prompt_len ? prev_iteration_content_len + i - 1 : prev_iteration_content_len + i);
11231122

11241123
break;
11251124
}

src/cpp/src/continuous_batching_impl.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,6 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request
266266
SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, input_ids, sampling_params, m_block_size);
267267

268268
if (m_scheduler->get_config().enable_prefix_caching) {
269-
if (m_model_input_type == ModelInputType::EMBEDDINGS) {
270-
OPENVINO_THROW("Prefix caching is not supported for VLM models.");
271-
}
272269
m_scheduler->restore_cached_blocks(sequence_group);
273270
}
274271

@@ -402,6 +399,10 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
402399

403400
free_fork_timer.end();
404401
}
402+
403+
// append embeddings for generated tokens
404+
if (m_model_input_type == ModelInputType::EMBEDDINGS)
405+
m_model_runner->append_embeddings(m_requests, scheduler_output);
405406

406407
// notify requests dropped by handle
407408
{

src/cpp/src/model_runner.hpp

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ class ModelRunner {
119119
size_t total_num_tokens = 0, total_num_blocks = 0;
120120
size_t max_context_len_val = 0;
121121
size_t hidden_size = 0;
122-
size_t num_generated_ids = 0;
123122
OPENVINO_ASSERT(sequence_groups.size() > 0);
124123
auto sequence_group_type = sequence_groups[0]->get_sequence_group_type();
125124
if (sequence_group_type == SequenceGroupType::EMBEDDINGS) {
@@ -135,9 +134,6 @@ class ModelRunner {
135134
total_num_tokens += sequence_group->get_num_scheduled_tokens() * num_sequences;
136135
total_num_blocks += sequence_group->get_num_blocks() * num_sequences;
137136
max_context_len_val = std::max(max_context_len_val, sequence_group->get_context_len());
138-
for (auto seq: sequence_group->get_running_sequences()) {
139-
num_generated_ids += seq->get_generated_len();
140-
}
141137
}
142138

143139
ov::Tensor
@@ -163,27 +159,6 @@ class ModelRunner {
163159
if (sequence_group_type == SequenceGroupType::EMBEDDINGS) {
164160
OPENVINO_ASSERT(m_embedding.get_request(), "Got sequence group with embeddings, but embeddings model wasn't set.");
165161
inputs_embeds_data = inputs_embeds.data<float>();
166-
167-
ov::Tensor generated_ids = ov::Tensor(ov::element::i64, {1, num_generated_ids});
168-
int64_t *generated_ids_data = generated_ids.data<int64_t>();
169-
size_t pos = 0;
170-
for (size_t i = 0; i < num_sequence_groups; ++i) {
171-
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
172-
SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id];
173-
for (auto seq: sequence_group->get_running_sequences()) {
174-
auto generated_ids = seq->get_generated_ids();
175-
for (size_t token_idx = 0; token_idx < generated_ids.size(); token_idx++) {
176-
generated_ids_data[pos] = generated_ids[token_idx];
177-
pos++;
178-
}
179-
}
180-
}
181-
if (pos > 0) {
182-
// TODO: Compute embeddings only for last generated token, while previously generated embeddings save in SequenceGroup
183-
generated_ids_embeds = m_embedding.infer(generated_ids);
184-
generated_ids_embeds_data = generated_ids_embeds.data<float>();
185-
}
186-
187162
} else if (sequence_group_type == SequenceGroupType::TOKENS) {
188163
input_ids_data = input_ids.data<int64_t>();
189164
}
@@ -234,8 +209,8 @@ class ModelRunner {
234209
sequence_group->get_prompt_ids()[position_id] :
235210
sequence->get_generated_ids()[position_id - prompt_len];
236211
} else if (sequence_group_type == SequenceGroupType::EMBEDDINGS) {
237-
auto embeds_pos = position_id < prompt_len ? 0 : hidden_size * (position_id - prompt_len);
238-
const float* src = position_id < prompt_len ? sequence_group->get_input_embeds()[position_id].data() : generated_ids_embeds_data + embeds_pos;
212+
const auto& generated_embeds = sequence->get_generated_ids_embeds();
213+
const float* src = position_id < prompt_len ? sequence_group->get_input_embeds()[position_id].data() : generated_embeds[position_id - prompt_len].data();
239214
std::copy_n(src, hidden_size, inputs_embeds_data + token_id * hidden_size);
240215
} else {
241216
OPENVINO_THROW("Unknown model inputs type.");
@@ -271,7 +246,6 @@ class ModelRunner {
271246
input_ids_data += num_scheduled_tokens;
272247
} else if (sequence_group_type == SequenceGroupType::EMBEDDINGS) {
273248
inputs_embeds_data += num_scheduled_tokens * hidden_size;
274-
generated_ids_embeds_data += sequence->get_generated_len() * hidden_size;
275249
}
276250

277251
position_ids_data += num_scheduled_tokens;
@@ -337,6 +311,63 @@ class ModelRunner {
337311
return m_request.get_tensor("logits");
338312
}
339313

314+
void append_embeddings(const std::vector<SequenceGroup::Ptr> & sequence_groups, const Scheduler::Output& scheduler_output) {
315+
size_t num_sequence_groups = scheduler_output.m_scheduled_sequence_groups_ids.size();
316+
size_t num_generated_ids_without_embeddings = 0;
317+
OPENVINO_ASSERT(sequence_groups.size() > 0);
318+
319+
// compute aggregated values
320+
for (size_t i = 0; i < num_sequence_groups; ++i) {
321+
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
322+
SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id];
323+
size_t num_sequences = sequence_group->num_running_seqs();
324+
OPENVINO_ASSERT(sequence_group->get_sequence_group_type() == SequenceGroupType::EMBEDDINGS);
325+
for (auto seq: sequence_group->get_running_sequences()) {
326+
num_generated_ids_without_embeddings += seq->get_generated_len() - seq->get_generated_ids_embeds().size();
327+
}
328+
}
329+
size_t hidden_size = sequence_groups[0]->get_hidden_size();
330+
331+
ov::Tensor generated_ids_embeds;
332+
float *generated_ids_embeds_data = nullptr;
333+
334+
OPENVINO_ASSERT(m_embedding.get_request(), "Got sequence group with embeddings, but embeddings model wasn't set.");
335+
336+
ov::Tensor generated_ids = ov::Tensor(ov::element::i64, {1, num_generated_ids_without_embeddings});
337+
int64_t *generated_ids_data = generated_ids.data<int64_t>();
338+
size_t pos = 0;
339+
for (size_t i = 0; i < num_sequence_groups; ++i) {
340+
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
341+
SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id];
342+
for (auto seq: sequence_group->get_running_sequences()) {
343+
const auto& generated_ids = seq->get_generated_ids();
344+
for (size_t token_idx = seq->get_generated_ids_embeds().size(); token_idx < generated_ids.size(); token_idx++) {
345+
generated_ids_data[pos] = generated_ids[token_idx];
346+
pos++;
347+
}
348+
}
349+
}
350+
if (pos > 0) {
351+
generated_ids_embeds = m_embedding.infer(generated_ids);
352+
generated_ids_embeds_data = generated_ids_embeds.data<float>();
353+
354+
for (size_t i = 0; i < num_sequence_groups; ++i) {
355+
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
356+
size_t embeds_pos = 0;
357+
SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
358+
for (auto seq: sequence_group->get_running_sequences()) {
359+
auto generated_ids = seq->get_generated_ids();
360+
size_t new_embeds_count = seq->get_generated_len() - seq->get_generated_ids_embeds().size();
361+
ov::Coordinate start{0, embeds_pos, 0};
362+
ov::Coordinate end{1, embeds_pos + new_embeds_count, hidden_size};
363+
ov::Tensor embedding(generated_ids_embeds, start, end);
364+
seq->append_generated_ids_embeds(embedding);
365+
embeds_pos += new_embeds_count;
366+
}
367+
}
368+
}
369+
}
370+
340371
private:
341372
void _fill_indices_from_block_tables(
342373
const std::vector<std::string>& dst_tensor_names,

src/cpp/src/sequence_group.cpp

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,56 @@ size_t Sequence::_make_hash(size_t content_length) {
2424
content.insert(content.end(), m_prefix_hashes.begin(), m_prefix_hashes.begin() + prefix_hashes_needed_count);
2525

2626
// get tokens corresponding to current block
27-
const auto prompt_ids = sequence_group->get_prompt_ids();
28-
OPENVINO_ASSERT(content_length <= prompt_ids.size() + m_generated_ids.size());
29-
if (block_start_idx < prompt_ids.size()) {
30-
content.insert(content.end(), prompt_ids.begin() + block_start_idx, prompt_ids.begin() + std::min(prompt_ids.size(), content_length));
27+
if (sequence_group->get_sequence_group_type() == SequenceGroupType::TOKENS) {
28+
const auto prompt_ids = sequence_group->get_prompt_ids();
29+
OPENVINO_ASSERT(content_length <= prompt_ids.size() + m_generated_ids.size());
30+
if (block_start_idx < prompt_ids.size()) {
31+
content.insert(content.end(), prompt_ids.begin() + block_start_idx, prompt_ids.begin() + std::min(prompt_ids.size(), content_length));
32+
}
33+
if (content_length > prompt_ids.size()) {
34+
size_t start = block_start_idx < prompt_ids.size() ? 0 : block_start_idx - prompt_ids.size();
35+
content.insert(content.end(), m_generated_ids.begin() + start, m_generated_ids.begin() + content_length - prompt_ids.size());
36+
}
3137
}
32-
if (content_length > prompt_ids.size()) {
33-
size_t start = block_start_idx < prompt_ids.size() ? 0 : block_start_idx - prompt_ids.size();
34-
content.insert(content.end(), m_generated_ids.begin() + start, m_generated_ids.begin() + content_length - prompt_ids.size());
38+
else if (sequence_group->get_sequence_group_type() == SequenceGroupType::EMBEDDINGS) {
39+
const auto& input_embeds = sequence_group->get_input_embeds();
40+
const auto generated_embeds = m_generated_ids_embeds;
41+
OPENVINO_ASSERT(content_length <= input_embeds.size() + generated_embeds.size());
42+
43+
// get inputs embeddings
44+
if (block_start_idx < input_embeds.size()) {
45+
for (size_t idx = block_start_idx; idx < std::min(input_embeds.size(), content_length); idx++) {
46+
auto embed = _reduce_embedding(input_embeds[idx]);
47+
content.insert(content.end(), embed.begin(), embed.end());
48+
}
49+
}
50+
51+
// get generated ids embeddings
52+
if (content_length > input_embeds.size()) {
53+
size_t start = block_start_idx < input_embeds.size() ? 0 : block_start_idx - input_embeds.size();
54+
for (size_t idx = start; idx < content_length - input_embeds.size(); idx++) {
55+
auto embed = _reduce_embedding(generated_embeds[idx]);
56+
content.insert(content.end(), embed.begin(), embed.end());
57+
}
58+
}
59+
}
60+
else {
61+
OPENVINO_THROW("Hash calculation is not supported for this sequence type.");
3562
}
3663
const char* data = reinterpret_cast<const char*>(content.data());
3764
std::size_t size = content.size() * sizeof(content[0]);
3865
return std::hash<std::string_view>{}(std::string_view(data, size));
3966
}
4067

68+
std::vector<int64_t> Sequence::_reduce_embedding(const std::vector<float>& embedding) {
69+
size_t res_size = std::min((size_t)ceil(float(embedding.size()) / m_embeddings_hash_calculation_stride), m_embeddings_hash_max_num_values);
70+
std::vector<int64_t> res(res_size);
71+
for (size_t i = 0, idx=0; idx < res_size; i+= m_embeddings_hash_calculation_stride, idx++) {
72+
res[idx] = std::round(embedding[i] * m_multiplier);
73+
}
74+
return res;
75+
}
76+
4177
// Each KV block can be uniquely identified by
4278
// the tokens within the block and the tokens in the prefix before the block.
4379
// hash(prefix tokens + block tokens) <--> KV Block

src/cpp/src/sequence_group.hpp

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,26 +49,38 @@ class Sequence {
4949
std::vector<int64_t> m_prefix_hashes;
5050
SequenceGroup* m_sequence_group = nullptr;
5151
static std::mutex m_counter_mutex;
52+
std::vector<std::vector<float>> m_generated_ids_embeds;
53+
SequenceGroupType m_type;
54+
size_t m_hidden_size;
55+
56+
// Embeddings hash calculation params
57+
static constexpr size_t m_embeddings_hash_max_num_values = 10; // max number of values used for embeddings hash calculation
58+
static constexpr size_t m_embeddings_hash_calculation_stride = 50; // the stride with which values are taken from embeddings vector
59+
static constexpr size_t m_multiplier = 10000; // multiplier by which float values are multiplied before conversion to size_t
5260

5361
size_t _make_hash(size_t content_length);
5462

55-
explicit Sequence(const uint64_t id) : m_grouped_id(id) {}
63+
static std::vector<int64_t> _reduce_embedding(const std::vector<float>& embedding);
64+
65+
explicit Sequence(const uint64_t id, const SequenceGroupType type, const size_t hidden_size) : m_grouped_id(id), m_type(type), m_hidden_size(hidden_size) {}
5666

5767
Sequence(const Sequence& seq, const uint64_t id) :
5868
m_generated_ids(seq.m_generated_ids),
5969
m_grouped_id(id),
6070
m_status(seq.m_status),
6171
m_cumulative_log_prob(seq.m_cumulative_log_prob),
62-
m_sequence_group(seq.m_sequence_group) {
72+
m_sequence_group(seq.m_sequence_group),
73+
m_type(seq.m_type),
74+
m_hidden_size(seq.m_hidden_size) {
6375
OPENVINO_ASSERT(seq.m_id != m_id);
6476
}
6577

6678
public:
6779
using Ptr = std::shared_ptr<Sequence>;
6880
using CPtr = std::shared_ptr<const Sequence>;
6981

70-
static Sequence::Ptr create(const uint64_t id) {
71-
return Sequence::Ptr(new Sequence(id));
82+
static Sequence::Ptr create(const uint64_t id, const SequenceGroupType type = SequenceGroupType::TOKENS, const size_t hidden_size = 0) {
83+
return Sequence::Ptr(new Sequence(id, type, hidden_size));
7284
}
7385

7486
static Sequence::Ptr fork(Sequence::CPtr sequence, const uint64_t id) {
@@ -191,6 +203,25 @@ class Sequence {
191203
m_sequence_group = sequence_group;
192204
}
193205

206+
const std::vector<std::vector<float>>& get_generated_ids_embeds() const {
207+
OPENVINO_ASSERT(m_type == ov::genai::SequenceGroupType::EMBEDDINGS);
208+
return m_generated_ids_embeds;
209+
}
210+
211+
void append_generated_ids_embeds(ov::Tensor generated_ids_embeds) {
212+
OPENVINO_ASSERT(m_type == SequenceGroupType::EMBEDDINGS);
213+
auto embeds_count = generated_ids_embeds.get_shape()[1];
214+
OPENVINO_ASSERT(m_hidden_size == generated_ids_embeds.get_shape()[2]);
215+
216+
auto current_embeds_size = m_generated_ids_embeds.size();
217+
for (size_t i = current_embeds_size, idx = 0; i < current_embeds_size + embeds_count; i++, idx++) {
218+
m_generated_ids_embeds.emplace_back(std::vector<float>());
219+
m_generated_ids_embeds[i].resize(m_hidden_size);
220+
std::copy_n(generated_ids_embeds.data<float>() + idx * m_hidden_size, m_hidden_size, m_generated_ids_embeds[i].begin());
221+
222+
}
223+
}
224+
194225
std::shared_ptr<SequenceGroup> get_sequence_group_ptr() const;
195226

196227
// Each KV block can be uniquely identified by
@@ -261,6 +292,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
261292
: SequenceGroup(request_id, sampling_params, block_size) {
262293

263294
size_t prompt_len;
295+
size_t hidden_size = 0;
264296
if (input_ids.get_shape().size() > 1) {
265297
prompt_len = input_ids.get_shape()[1];
266298
} else {
@@ -273,11 +305,11 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
273305
std::copy_n(input_ids.data<int64_t>(), prompt_len, m_prompt_ids.begin());
274306
m_sequence_group_type = SequenceGroupType::TOKENS;
275307
} else if (input_ids.get_element_type() == ov::element::f32) {
276-
auto embeds_len = input_ids.get_shape()[2];
308+
hidden_size = input_ids.get_shape()[2];
277309
m_input_embeds.resize(prompt_len);
278310
for (size_t i = 0; i < prompt_len; i++) {
279-
m_input_embeds[i].resize(embeds_len);
280-
std::copy_n(input_ids.data<float>() + i * embeds_len, embeds_len, m_input_embeds[i].begin());
311+
m_input_embeds[i].resize(hidden_size);
312+
std::copy_n(input_ids.data<float>() + i * hidden_size, hidden_size, m_input_embeds[i].begin());
281313
}
282314
m_sequence_group_type = SequenceGroupType::EMBEDDINGS;
283315
}
@@ -287,7 +319,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
287319
m_prompt_log_probs.reserve(prompt_len);
288320

289321
// create a single sequence
290-
add_sequence(Sequence::create(m_next_sequence_id++));
322+
add_sequence(Sequence::create(m_next_sequence_id++, m_sequence_group_type, hidden_size));
291323
}
292324

293325
void add_sequence(const Sequence::Ptr & sequence) {

0 commit comments

Comments
 (0)