Skip to content

Commit f7ee437

Browse files
[NPUW] Add new LiftGather pattern (#30393)
1 parent 4c666c8 commit f7ee437

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ void pre_load_transform(const std::shared_ptr<ov::Model>& model, const ov::AnyMa
107107
rewr.add_matcher<ov::npuw::patterns::opt::DQLiftGatherAsymCW>();
108108
rewr.add_matcher<ov::npuw::patterns::opt::DQLiftGatherSymCW>();
109109
rewr.add_matcher<ov::npuw::patterns::opt::DQLiftGatherSymGQ>();
110+
rewr.add_matcher<ov::npuw::patterns::opt::DQLiftGatherCW>();
110111
rewr.run_on_model(model);
111112
}
112113

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,39 @@ DQLiftGatherSymCW::DQLiftGatherSymCW() {
10451045
register_matcher(std::make_shared<opp::Matcher>(gather, "DQGatherSymCW"), std::move(callback));
10461046
}
10471047

1048+
// FIXME: this is mostly a workaround pattern for the partitioning
1049+
DQLiftGatherCW::DQLiftGatherCW() {
1050+
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
1051+
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
1052+
1053+
auto pids = opp::wrap_type<ov::op::v0::Parameter>();
1054+
auto cvtids = opp::optional<ov::op::v0::Convert>({pids->output(0)});
1055+
auto gather = opp::wrap_type<ov::op::v8::Gather>({qcvtw, cvtids, opp::any_input()});
1056+
1057+
// Note: Use [=] to make sure the above objects stay alive in the callback
1058+
auto callback = [=](ov::pass::pattern::Matcher& m) {
1059+
auto& node_to_output = m.get_pattern_value_map();
1060+
1061+
auto matched_out_w = node_to_output.at(qweight);
1062+
auto matched_out_ids = uat::_(node_to_output).at_or_at(cvtids, pids);
1063+
const auto& matched_out_gather = node_to_output.at(gather);
1064+
1065+
// Create new gathers on W, connect respectively
1066+
auto new_cvt_w = std::make_shared<ov::op::v0::Convert>(matched_out_w, ov::element::f16);
1067+
auto gather_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
1068+
auto new_g_w = std::make_shared<ov::op::v8::Gather>(new_cvt_w, matched_out_ids, gather_c);
1069+
1070+
auto new_out = std::make_shared<ov::op::v0::Convert>(new_g_w, ov::element::f32);
1071+
// Reconnect old gather readers to the new Convert
1072+
for (auto&& r : matched_out_gather.get_target_inputs()) {
1073+
r.replace_source_output(new_out);
1074+
}
1075+
1076+
return true; // root was changed
1077+
};
1078+
register_matcher(std::make_shared<opp::Matcher>(gather, "DQGatherCW"), std::move(callback));
1079+
}
1080+
10481081
// Identify a Gather+DQ Sym GQ MatMul pattern, lift Gather up
10491082
// Note(1): this pattern is applied on the full model before any partitioning
10501083
// Note(2): here's a difference, the new lifted Gathers stay behind Convert(W) & Convert(S)

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ class DQLiftGatherSymGQ : public ov::pass::MatcherPass {
120120
DQLiftGatherSymGQ();
121121
};
122122

123+
class DQLiftGatherCW : public ov::pass::MatcherPass {
124+
public:
125+
OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::opt::DQLiftGatherCW");
126+
DQLiftGatherCW();
127+
};
128+
123129
// Head vocab unpacks
124130

125131
class DQUnpackDictGatheru : public ov::pass::MatcherPass {

0 commit comments

Comments
 (0)