Skip to content

Commit db99d2b

Browse files
Avoid RoPE on NPU for Gemma Models (#32093)
Computing RoPE for Gemma-like models on the NPU has some accuracy issues. This patch adds a matcher pass in avoid.cpp of the NPUW pipeline that creates an NPU avoid pattern for the RoPE computation in Gemma-like models. It is added as an avoid pattern; to enable it, the following flag must be added to the config: "NPUW_ONLINE_AVOID": "P:GemmaRoPE/NPU". To recognize this pattern, a check is added to snapshot.cpp. Signed-off-by: Ghosh, Tamoghna <[email protected]> --------- Signed-off-by: Ghosh, Tamoghna <[email protected]> Signed-off-by: tamoghna <[email protected]> Co-authored-by: Alexey Smirnov <[email protected]>
1 parent 8ba9a4b commit db99d2b

File tree

3 files changed

+43
-4
lines changed

3 files changed

+43
-4
lines changed

src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -451,18 +451,22 @@ void Snapshot::earlyAvoids() {
451451
}
452452
case PatternType::PATTERN: {
453453
// FIXME: refactor as more patterns are supported
454-
if (avoid.pattern != "RMSNorm" && avoid.pattern != "SinCos") {
455-
LOG_WARN(
456-
"OPENVINO_NPUW_AVOID only supports RMSNorm and SinCos as patterns (don't confuse with operations)."
457-
<< " Avoid pattern " << avoid.pattern << " is skipped!");
454+
if (avoid.pattern != "RMSNorm" && avoid.pattern != "SinCos" && avoid.pattern != "GemmaRoPE") {
455+
LOG_WARN("OPENVINO_NPUW_AVOID only supports RMSNorm, SinCos and GemmaRoPE as patterns "
456+
"(don't confuse with operations). "
457+
"Avoid pattern "
458+
<< avoid.pattern << " is skipped!");
458459
break;
459460
}
460461
handle_patterns = true;
461462
if (avoid.pattern == "RMSNorm") {
462463
rewr.add_matcher<ov::npuw::patterns::avoid::RMSNorm>(shared_from_this(), avoid.device);
463464
} else if (avoid.pattern == "SinCos") {
464465
rewr.add_matcher<ov::npuw::patterns::avoid::SinCos>(shared_from_this(), avoid.device);
466+
} else if (avoid.pattern == "GemmaRoPE") {
467+
rewr.add_matcher<ov::npuw::patterns::avoid::GemmaRoPE>(shared_from_this(), avoid.device);
465468
}
469+
466470
break;
467471
}
468472
}

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/avoid.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,35 @@ SinCos::SinCos(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, cons
105105
};
106106
register_matcher(std::make_shared<opp::Matcher>(sin_cos, "TagSinCos"), std::move(callback));
107107
}
108+
// Registers a matcher that finds the RoPE computation chain used by Gemma-like models
// and asks the group of every matched node to avoid `avoid_device`.
//
// Matched chain (the Sin/Cos node is the matcher root):
//   Power(any, any) -> Unsqueeze(., Constant) -> Unsqueeze(., Constant)
//     -> Divide(Convert, .) -> Unsqueeze(., Constant) -> Sin | Cos
//
// snapshot     - partitioning snapshot that supplies the node-to-group map
// avoid_device - device name handed to avoid() on every affected group
GemmaRoPE::GemmaRoPE(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& avoid_device) {
    auto pow_p = opp::wrap_type<ov::op::v1::Power>({opp::any_input(), opp::any_input()});
    auto unsq_a_p = opp::wrap_type<ov::op::v0::Unsqueeze>({pow_p, opp::wrap_type<ov::op::v0::Constant>()});
    auto unsq_b_p = opp::wrap_type<ov::op::v0::Unsqueeze>({unsq_a_p, opp::wrap_type<ov::op::v0::Constant>()});
    auto div_p = opp::wrap_type<ov::op::v1::Divide>({opp::wrap_type<ov::op::v0::Convert>(), unsq_b_p});
    auto unsq_c_p = opp::wrap_type<ov::op::v0::Unsqueeze>({div_p, opp::wrap_type<ov::op::v0::Constant>()});
    auto trig_p = opp::wrap_type<ov::op::v0::Sin, ov::op::v0::Cos>({unsq_c_p});

    auto groups = snapshot->getNodeToGroupMap();

    // Everything the callback needs (pattern nodes, group map, device name) is captured by copy.
    auto on_match = [=](ov::pass::pattern::Matcher& m) {
        auto& matched = m.get_pattern_value_map();

        // Pattern node -> concrete node found in the model.
        auto resolve = [&matched](const auto& pattern_node) {
            return matched.at(pattern_node).get_node_shared_ptr();
        };
        // Concrete node -> its group, which is told to avoid the device.
        auto tag = [&](const auto& node) {
            groups->at(node)->avoid(avoid_device);
        };

        // Look up all six nodes before tagging any group, so a failed lookup
        // leaves every group untouched.
        auto pow_n = resolve(pow_p);
        auto unsq_a_n = resolve(unsq_a_p);
        auto unsq_b_n = resolve(unsq_b_p);
        auto div_n = resolve(div_p);
        auto unsq_c_n = resolve(unsq_c_p);
        auto trig_n = resolve(trig_p);

        tag(pow_n);
        tag(unsq_a_n);
        tag(unsq_b_n);
        tag(div_n);
        tag(unsq_c_n);
        tag(trig_n);

        // No node is replaced here, so report "no change" to the pass.
        return false;
    };
    register_matcher(std::make_shared<opp::Matcher>(trig_p, "TagGemmaRoPE"), std::move(on_match));
}
108137
} // namespace avoid
109138
} // namespace patterns
110139
} // namespace npuw

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/avoid.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ class SinCos : public ov::pass::MatcherPass {
3434
SinCos(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& avoid_device);
3535
};
3636

37+
// Matcher pass for the "GemmaRoPE" avoid pattern: it recognizes the RoPE computation
// chain of Gemma-like models and makes the groups holding the matched nodes avoid a device.
// Enabled via the avoid config, e.g. "NPUW_ONLINE_AVOID": "P:GemmaRoPE/NPU".
class GemmaRoPE : public ov::pass::MatcherPass {
38+
public:
39+
OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::avoid::GemmaRoPE");
40+
// snapshot: source of the node-to-group map; avoid_device: device name passed to avoid()
// on the group of each matched node.
GemmaRoPE(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& avoid_device);
41+
};
42+
3743
} // namespace avoid
3844
} // namespace patterns
3945
} // namespace npuw

0 commit comments

Comments
 (0)