Skip to content

Commit 81b1843

Browse files
[NPUW] Initial support for Gemma3 on NPU (#32102)
* Fixed the logic of switching between the prefill and generate stages for gemma3 * Added support for a new input - `token_type_ids` * Added GemmaSlidingMask, which searches for the sliding mask pattern, replaces it with an input, and extracts the sliding window size from it. Gemma3 explanation: https://developers.googleblog.com/en/gemma-explained-whats-new-in-gemma-3/ Targeted sliding mask subgraph: <img width="325" height="456" alt="sliding_mask_subgraph" src="https://github.com/user-attachments/assets/86c6fb7c-7bab-407e-b277-73ecdc07fcf2" /> It basically implements the following check: y - sliding_window_size < x <= y Related OpenVINO GenAI PR: [[NPUW] Disable chunking and F16IC for gemma3 as they are not supported currently](openvinotoolkit/openvino.genai#2800)
1 parent f0e1e64 commit 81b1843

File tree

5 files changed

+229
-15
lines changed

5 files changed

+229
-15
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,95 @@ class GroupQueryAttentionDecomposition : public ov::pass::MatcherPass {
303303
}
304304
};
305305

306+
class GemmaSlidingMask : public ov::pass::MatcherPass {
307+
public:
308+
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::GemmaSlidingMask");
309+
310+
struct Result {
311+
Result() = default;
312+
313+
bool found = false;
314+
int32_t window_size = 0;
315+
std::shared_ptr<ov::op::v0::Parameter> mask_input;
316+
};
317+
318+
explicit GemmaSlidingMask(Result* result) {
319+
// Searching for gemma sliding mask pattern and replace it's output
320+
// with Paramater of the same size and type.
321+
/* -\
322+
range_w -> unsqueeze -> unsqueeze -> unsqueeze1 -> convert -\/ => LessEqual -\
323+
\ /\-/ \
324+
/-----\----/ =>BWAnd_res
325+
/ \ /
326+
range_h -> unsqueeze -> unsqueeze -> unsqueeze2 -> add -=> Greater -> BWAnd --/
327+
const (-window_size) ----/
328+
*/
329+
// Basically this subgrapgh is doing the following:
330+
// range_w is range (0, ..., width - 1) (probably + something)
331+
// renge_h is range (0, ..., height - 1) (probably + something)
332+
// And then doing the following check:
333+
// y - window_size < x <= y
334+
// Producing squared sliding mask:
335+
// 1 0 0 0 0 0
336+
// 1 1 0 0 0 0
337+
// 1 1 1 0 0 0
338+
// 0 1 1 1 0 0
339+
// 0 0 1 1 1 0
340+
// 0 0 0 1 1 1
341+
//
342+
// Please also note, that sliding windows size is stored as negative value and the
343+
// subgraph is actually doing:
344+
// y + (negative)window_size < x <= y
345+
346+
auto range_sequence = [&]() {
347+
auto range = opp::wrap_type<ov::op::v4::Range>({opp::any_input(), opp::any_input(), opp::any_input()});
348+
auto unsqueeze1 = opp::wrap_type<ov::op::v0::Unsqueeze>({range, opp::any_input()});
349+
auto unsqueeze2 = opp::wrap_type<ov::op::v0::Unsqueeze>({unsqueeze1, opp::any_input()});
350+
auto unsqueeze3 = opp::wrap_type<ov::op::v0::Unsqueeze>({unsqueeze2, opp::any_input()});
351+
352+
return unsqueeze3;
353+
};
354+
355+
auto unsqueeze1 = range_sequence();
356+
auto convert = opp::wrap_type<ov::op::v0::Convert>({unsqueeze1});
357+
auto unsqueeze2 = range_sequence();
358+
auto window_size = opp::wrap_type<ov::op::v0::Constant>();
359+
auto add = opp::wrap_type<ov::op::v1::Add>({unsqueeze2, window_size});
360+
auto greater = opp::wrap_type<ov::op::v1::Greater>({convert, add});
361+
auto bwand = opp::wrap_type<ov::op::v13::BitwiseAnd>({opp::any_input(), greater});
362+
auto less_equal = opp::wrap_type<ov::op::v1::LessEqual>({convert, unsqueeze2});
363+
auto bwand_res = opp::wrap_type<ov::op::v13::BitwiseAnd>({bwand, less_equal});
364+
365+
auto callback = [=](ov::pass::pattern::Matcher& m) {
366+
auto& node_to_output = m.get_pattern_value_map();
367+
auto* bwand_matched_node = node_to_output.at(bwand_res).get_node();
368+
auto* window_size_node = node_to_output.at(window_size).get_node();
369+
auto output = bwand_matched_node->output(0);
370+
auto target_inputs = output.get_target_inputs();
371+
372+
auto* window_size_constant = static_cast<ov::op::v0::Constant*>(window_size_node);
373+
OPENVINO_ASSERT(window_size_constant->get_output_size() == 1,
374+
"Sliding window size constant must be of size 1, but got " +
375+
std::to_string(window_size_constant->get_output_size()));
376+
OPENVINO_ASSERT(!result->found, "Second gemma sliding mask pattern found, what is unexpected!");
377+
378+
auto input = std::make_shared<ov::op::v0::Parameter>(output.get_element_type(), output.get_partial_shape());
379+
input->set_friendly_name(ov::npuw::LLMInferRequest::layer_names::gemma_sliding_mask);
380+
output.replace(input->output(0));
381+
382+
auto window_size_vec = window_size_constant->cast_vector<int32_t>(1);
383+
384+
result->found = true;
385+
// since we are doing Add and need to do subtract window size is stored as negative value
386+
result->window_size = -window_size_vec[0];
387+
result->mask_input = input;
388+
389+
return true;
390+
};
391+
register_matcher(std::make_shared<opp::Matcher>(bwand_res, "GemmaSlidingMask"), std::move(callback));
392+
}
393+
};
394+
306395
namespace {
307396
uint32_t align_to(uint32_t value, uint32_t alignment) {
308397
return (value + alignment - 1) & ~(alignment - 1);
@@ -468,6 +557,8 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
468557
ov::PartialShape new_shape;
469558
if (input_name.find("input_ids") != std::string::npos) {
470559
new_shape = ov::PartialShape({1, input_size});
560+
} else if (input_name.find("token_type_ids") != std::string::npos) {
561+
new_shape = ov::PartialShape({1, input_size});
471562
} else if (input_name.find("inputs_embeds") != std::string::npos) {
472563
// NB: VLMs case, model accepts inputs_embeds[BATCH, SEQ_LEN, EMB_SIZE]
473564
NPUW_ASSERT(input.get_partial_shape().size() == 3u);
@@ -785,6 +876,41 @@ void ov::npuw::LLMCompiledModel::convert_stateful_lora_to_stateless(std::shared_
785876
model->add_parameters(new_parameters);
786877
}
787878

879+
// Applies Gemma3-specific model transformations: detects the sliding-mask
// subgraph, replaces it with a dedicated Parameter input and records the
// extracted sliding window size in m_gemma_sliding_window_size.
void ov::npuw::LLMCompiledModel::gemma_transformations(const std::shared_ptr<ov::Model>& model) {
    // For now only do transformations for gemma3 which has token_type_ids input.
    bool token_type_ids_found = false;
    for (const auto& input : model->inputs()) {
        const auto& input_name = input.get_any_name();
        if (input_name.find("token_type_ids") != std::string::npos) {
            token_type_ids_found = true;
            break;
        }
    }

    if (!token_type_ids_found) {
        return;
    }

    // FIX: no need to heap-allocate the small Result struct — keep it on the
    // stack; also renamed to follow the file's snake_case local naming.
    GemmaSlidingMask::Result rewrite_result;
    ov::pass::GraphRewrite rewr;
    rewr.add_matcher<GemmaSlidingMask>(&rewrite_result);
    rewr.run_on_model(model);

    if (!rewrite_result.found) {
        return;
    }

    OPENVINO_ASSERT(rewrite_result.window_size > 0,
                    "Gemma sliding window size must be strictly positive, but got " +
                        std::to_string(rewrite_result.window_size));

    m_gemma_sliding_window_size = rewrite_result.window_size;
    auto mask_input = rewrite_result.mask_input;
    model->add_parameters({mask_input});
    // The freshly added Parameter has no tensor names yet — give it one so the
    // infer request can address it by layer name.
    for (auto&& input : model->inputs()) {
        if (input.get_node() == mask_input.get()) {
            input.set_names({mask_input->get_friendly_name()});
        }
    }
    model->validate_nodes_and_infer_types();
}
913+
788914
ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& model,
789915
const std::shared_ptr<const ov::IPlugin>& plugin,
790916
const ov::AnyMap& properties)
@@ -910,6 +1036,8 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
9101036
m_kvcache_desc.total_size,
9111037
axes,
9121038
m_max_lora_rank);
1039+
gemma_transformations(kvcache_model);
1040+
9131041
if (lm_head_model) {
9141042
LOG_DEBUG("Shared LM head: slice the prefill output");
9151043
// KVCache model is already reshaped to [1, max_generation_token_len, embed size],
@@ -1147,6 +1275,7 @@ void ov::npuw::LLMCompiledModel::serialize(std::ostream& stream, const ov::npuw:
11471275
write(model_stream, m_prefill_chunk_size);
11481276
write(model_stream, m_use_chunk_prefill);
11491277
write(model_stream, m_max_lora_rank);
1278+
write(model_stream, m_gemma_sliding_window_size);
11501279

11511280
// Write config
11521281
write(model_stream, m_cfg);
@@ -1357,6 +1486,7 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::deserial
13571486
read(model_stream, compiled->m_prefill_chunk_size);
13581487
read(model_stream, compiled->m_use_chunk_prefill);
13591488
read(model_stream, compiled->m_max_lora_rank);
1489+
read(model_stream, compiled->m_gemma_sliding_window_size);
13601490

13611491
// Deserialize config
13621492
read(model_stream, compiled->m_cfg);

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ class LLMCompiledModel : public ov::npuw::ICompiledModel {
8080
// Support LoRA
8181
void convert_stateful_lora_to_stateless(std::shared_ptr<ov::Model>& model);
8282
uint32_t m_max_lora_rank = 32;
83+
84+
void gemma_transformations(const std::shared_ptr<ov::Model>& model);
85+
int32_t m_gemma_sliding_window_size = 0;
8386
};
8487

8588
} // namespace npuw

src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,27 @@ std::pair<uint32_t, uint32_t> get_lora_dims_by_name(const std::string& state_nam
274274
return std::make_pair(low_rank_dim, full_rank_dim);
275275
}
276276

277+
// Copies the whole content of `src` into the tail (rightmost bytes) of `dst`,
// leaving the leading bytes of `dst` untouched. `dst` must be at least as
// large as `src`.
void copy_to_right(const ov::SoPtr<ov::ITensor>& src, const ov::SoPtr<ov::ITensor>& dst) {
    const auto src_bytes = src->get_byte_size();
    const auto dst_bytes = dst->get_byte_size();
    OPENVINO_ASSERT(src_bytes <= dst_bytes);

    auto* src_begin = reinterpret_cast<uint8_t*>(src->data());
    auto* dst_tail = reinterpret_cast<uint8_t*>(dst->data()) + (dst_bytes - src_bytes);
    std::copy_n(src_begin, src_bytes, dst_tail);
}
283+
284+
// Fills `mask` with the sliding-window visibility pattern for the token at
// `curr_pos`: positions i with curr_pos - window_size < i < curr_pos are set
// to true, everything else to false. The last element of the buffer holds the
// freshly generated token and is always set to true.
void fill_sliding_mask(const ov::SoPtr<ov::ITensor>& mask, int64_t curr_pos, int64_t window_size) {
    const auto mask_size = static_cast<int64_t>(mask->get_size());
    // FIX: guard against an empty tensor — the unconditional write to the
    // last element below would otherwise index out of bounds.
    OPENVINO_ASSERT(mask_size > 0, "Sliding mask tensor must not be empty");

    const auto start = curr_pos - window_size;
    const auto end = curr_pos;

    auto* mask_data = mask->data<bool>();
    for (int64_t i = 0; i < mask_size; ++i) {
        // Unlike the original subgraph which does i <= end, we are excluding end
        // as it is a new token and is located in the last position of the mask buffer.
        mask_data[i] = i > start && i < end;
    }

    // The token currently being generated always attends to itself.
    mask_data[mask_size - 1] = true;
}
297+
277298
constexpr uint32_t INPUT_IDS_SEQ_LEN_DIM = 1;
278299

279300
constexpr std::size_t kStartOutputKVCacheLayers = 1;
@@ -380,6 +401,7 @@ ov::npuw::LLMInferRequest::LLMInferRequest(const std::shared_ptr<ov::npuw::LLMCo
380401
}
381402

382403
m_generate_initialized = false;
404+
m_gemma_sliding_window_size = compiled_model->m_gemma_sliding_window_size;
383405
}
384406

385407
void ov::npuw::LLMInferRequest::init_tensor(const ov::Output<const ov::Node>& port) {
@@ -498,6 +520,10 @@ void ov::npuw::LLMInferRequest::apply_lora() {
498520

499521
void ov::npuw::LLMInferRequest::prepare_for_new_conversation() {
500522
fill_tensor_bytes(m_prefill_request->get_tensor(m_prefill_in_ports.at(m_input_ids_name)), 0u);
523+
if (auto type_ids_port = m_prefill_in_ports.find(layer_names::token_type_ids);
524+
type_ids_port != m_prefill_in_ports.end()) {
525+
fill_tensor_bytes(m_prefill_request->get_tensor(type_ids_port->second), 0u);
526+
}
501527
fill_tensor<int64_t>(m_prefill_request->get_tensor(m_prefill_in_ports.at(layer_names::attention_mask)), 0);
502528
fill_tensor<int64_t>(m_prefill_request->get_tensor(m_prefill_in_ports.at(layer_names::position_ids)), 0);
503529
m_npuw_llm_compiled_model->m_kvcache_desc.num_stored_tokens = 0u;
@@ -586,8 +612,8 @@ void ov::npuw::LLMInferRequest::copy_kvcache() {
586612

587613
void ov::npuw::LLMInferRequest::update_kvcache_for(
588614
std::shared_ptr<ov::IAsyncInferRequest> request,
589-
std::unordered_map<std::string, ov::Output<const ov::Node>> in_ports,
590-
std::unordered_map<std::string, ov::Output<const ov::Node>> out_ports,
615+
const std::unordered_map<std::string, ov::Output<const ov::Node>>& in_ports,
616+
const std::unordered_map<std::string, ov::Output<const ov::Node>>& out_ports,
591617
uint32_t num_tokens,
592618
bool v_transposed) {
593619
LOG_DEBUG("Store computed key and values for passed number of tokens in the input kv-cache"
@@ -750,7 +776,8 @@ void ov::npuw::LLMInferRequest::infer_chunked_prefill(ov::SoPtr<ov::ITensor> inp
750776

751777
void ov::npuw::LLMInferRequest::infer_whole_prefill(ov::SoPtr<ov::ITensor> input_ids,
752778
ov::SoPtr<ov::ITensor> attention_mask,
753-
ov::SoPtr<ov::ITensor> position_ids) {
779+
ov::SoPtr<ov::ITensor> position_ids,
780+
ov::SoPtr<ov::ITensor> token_type_ids) {
754781
LOG_DEBUG("Calling inference for prefill model in a single launch.");
755782
LOG_BLOCK();
756783

@@ -767,6 +794,13 @@ void ov::npuw::LLMInferRequest::infer_whole_prefill(ov::SoPtr<ov::ITensor> input
767794
attention_mask->get_size(),
768795
padded_attention_mask->data<int64_t>() + padded_attention_mask->get_size() - attention_mask->get_size());
769796

797+
if (token_type_ids) {
798+
auto padded_token_type_ids = m_prefill_request->get_tensor(m_prefill_in_ports.at(layer_names::token_type_ids));
799+
800+
std::fill_n(reinterpret_cast<uint8_t*>(padded_token_type_ids->data()), token_type_ids->get_byte_size(), 0);
801+
copy_to_right(token_type_ids, padded_token_type_ids);
802+
}
803+
770804
auto padded_position_ids = m_prefill_request->get_tensor(m_prefill_in_ports.at(layer_names::position_ids));
771805
pad_position_ids(padded_position_ids, position_ids);
772806

@@ -779,7 +813,8 @@ void ov::npuw::LLMInferRequest::infer_whole_prefill(ov::SoPtr<ov::ITensor> input
779813

780814
void ov::npuw::LLMInferRequest::infer_prefill(ov::SoPtr<ov::ITensor> input_ids,
781815
ov::SoPtr<ov::ITensor> attention_mask,
782-
ov::SoPtr<ov::ITensor> position_ids) {
816+
ov::SoPtr<ov::ITensor> position_ids,
817+
ov::SoPtr<ov::ITensor> token_type_ids) {
783818
LOG_DEBUG("Calling inference for prefill model...");
784819
LOG_BLOCK();
785820

@@ -795,9 +830,12 @@ void ov::npuw::LLMInferRequest::infer_prefill(ov::SoPtr<ov::ITensor> input_ids,
795830

796831
const bool use_chunk_prefill = m_npuw_llm_compiled_model->m_use_chunk_prefill;
797832
if (use_chunk_prefill) {
833+
OPENVINO_ASSERT(m_gemma_sliding_window_size == 0,
834+
"Chunking is not implemented for Gemma model family yet. "
835+
"Please use set NPUW_LLM_PREFILL_HINT to 'STATIC'");
798836
infer_chunked_prefill(input_ids, attention_mask, position_ids);
799837
} else {
800-
infer_whole_prefill(input_ids, attention_mask, position_ids);
838+
infer_whole_prefill(input_ids, attention_mask, position_ids, token_type_ids);
801839
}
802840

803841
if (m_lm_head_request) {
@@ -815,7 +853,8 @@ void ov::npuw::LLMInferRequest::infer_prefill(ov::SoPtr<ov::ITensor> input_ids,
815853

816854
void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,
817855
ov::SoPtr<ov::ITensor> attention_mask,
818-
ov::SoPtr<ov::ITensor> position_ids) {
856+
ov::SoPtr<ov::ITensor> position_ids,
857+
ov::SoPtr<ov::ITensor> token_type_ids) {
819858
LOG_DEBUG("Calling inference for generate model...");
820859
LOG_BLOCK();
821860
auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc;
@@ -834,6 +873,9 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,
834873
fill_tensor_bytes(m_kvcache_request->get_tensor(m_kvcache_in_ports.at(m_input_ids_name)), 0u);
835874
fill_tensor<int64_t>(m_kvcache_request->get_tensor(m_kvcache_in_ports.at(layer_names::attention_mask)), 0);
836875
fill_tensor<int64_t>(m_kvcache_request->get_tensor(m_kvcache_in_ports.at(layer_names::position_ids)), 0);
876+
if (token_type_ids) {
877+
fill_tensor<int64_t>(m_kvcache_request->get_tensor(m_kvcache_in_ports.at(layer_names::token_type_ids)), 0);
878+
}
837879
m_generate_initialized = true;
838880
}
839881

@@ -842,6 +884,14 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,
842884
OPENVINO_THROW("KV-Cache is full.");
843885
}
844886

887+
if (auto sliding_mask_port = m_kvcache_in_ports.find(layer_names::gemma_sliding_mask);
888+
sliding_mask_port != m_kvcache_in_ports.end()) {
889+
// TODO: Fill once and update on each iteration instead
890+
fill_sliding_mask(m_kvcache_request->get_tensor(sliding_mask_port->second),
891+
kvcache_desc.num_stored_tokens + input_tokens_len,
892+
m_gemma_sliding_window_size);
893+
}
894+
845895
// FIXME: these tensors should be shared between the parent & child models
846896
// NB: input_ids can be either fp32(VLM) or i64(LLM)
847897
auto kv_input_ids = m_kvcache_request->get_tensor(m_kvcache_in_ports.at(m_input_ids_name));
@@ -854,6 +904,11 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,
854904
input_ids->get_byte_size(),
855905
reinterpret_cast<uint8_t*>(kv_input_ids->data()) + kv_input_ids->get_byte_size() - input_ids->get_byte_size());
856906

907+
if (token_type_ids) {
908+
auto kv_token_type_ids = m_kvcache_request->get_tensor(m_kvcache_in_ports.at(layer_names::token_type_ids));
909+
copy_to_right(token_type_ids, kv_token_type_ids);
910+
}
911+
857912
// NOTE: Attention mask pattern for generate model requires the set of "1"
858913
// units of length of the current prompt on the right (for present
859914
// kv layers) and the set of "1" units of number of previously calculated
@@ -912,12 +967,28 @@ void ov::npuw::LLMInferRequest::infer() {
912967
// FIXME: position_ids might be optional for some models!
913968
auto position_ids = get_tensor(find_port_by_name(inputs, layer_names::position_ids).value());
914969

970+
auto token_type_ids = ov::npuw::util::TensorPtr();
971+
972+
if (auto type_ids_port = find_port_by_name(inputs, layer_names::token_type_ids); type_ids_port.has_value()) {
973+
token_type_ids = get_tensor(type_ids_port.value());
974+
}
975+
915976
// NB: For VLM, the "inputs_embeds" contains float values (embeddings)
916977
OPENVINO_ASSERT(ov::element::f32 == input_ids->get_element_type() ||
917978
ov::element::i64 == input_ids->get_element_type());
918979
OPENVINO_ASSERT(ov::element::i64 == attention_mask->get_element_type());
919980
OPENVINO_ASSERT(ov::element::i64 == position_ids->get_element_type());
920981

982+
if (m_first_run) {
983+
// Most of the models have position_ids->data<int64_t>()[0] == 0 for the first infer
984+
// But gemma3 has it == 1
985+
// We need to store original first position id in order to distinguish between prefill and generate stage
986+
// While in most of the cases we need to do prefill only once, it is not true for chat mode
987+
// where we need to do prefill on each user input.
988+
m_first_position_id = position_ids->data<int64_t>()[0];
989+
m_first_run = false;
990+
}
991+
921992
// NB: Check the sequence length provided for input_ids
922993
// and start position idx in order to distinguish prefill
923994
// and generate stages.
@@ -940,11 +1011,11 @@ void ov::npuw::LLMInferRequest::infer() {
9401011
// The outcome of two items is that prefill and generate stages
9411012
// can be safely differentiated by start position id for
9421013
// both main and draft models.
943-
if (input_ids->get_shape()[INPUT_IDS_SEQ_LEN_DIM] > 1 && position_ids->data<int64_t>()[0] == 0) {
944-
infer_prefill(input_ids, attention_mask, position_ids);
1014+
if (input_ids->get_shape()[INPUT_IDS_SEQ_LEN_DIM] > 1 && position_ids->data<int64_t>()[0] == m_first_position_id) {
1015+
infer_prefill(input_ids, attention_mask, position_ids, token_type_ids);
9451016
} else {
9461017
trim_kvcache_for_speculative_decoding(position_ids);
947-
infer_generate(input_ids, attention_mask, position_ids);
1018+
infer_generate(input_ids, attention_mask, position_ids, token_type_ids);
9481019
}
9491020
}
9501021

0 commit comments

Comments
 (0)