Skip to content

Commit 3d219ae

Browse files
Expand CWAI to Keep the Weight scales as Constants (#32232)
### Details: A performance improvement of ~15ms per chunk (16 total chunks per inference) is seen, netting an E2E inference runtime reduction of ~240ms. This patch expands CWAI3 to include additional generalized pattern matching for keeping weight scales as const. The performance benefit is outlined above. A regression is introduced with this patch for gaussian_topk_sub, and general performance for some ops seemed less efficient when doing a FW Trace comparison. Additional savings can be brought in once this is resolved; it is tracked in the ticket below. ### Tickets: - [EISW-183592](https://jira.devtools.intel.com/browse/EISW-183592) - Bug this PR is related to. - [EISW-185933](https://jira.devtools.intel.com/browse/EISW-185933) - Bug that this PR introduces. A performance benefit is still seen, but a larger benefit will be seen once this issue is resolved.
1 parent 6c5a86f commit 3d219ae

File tree

1 file changed

+54
-31
lines changed
  • src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns

1 file changed

+54
-31
lines changed

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -914,50 +914,73 @@ CWAI2::CWAI2(CWAI2::Results scales) {
914914
register_matcher(std::make_shared<opp::Matcher>(mulply, "TagCWAI2"), std::move(matcher_callback));
915915
}
916916

917-
// Pattern: Phi-3 4SymW16A/GPTQ for CWAI
918-
//
919-
// FIXME: Think how it can be unified with the above
917+
// Keep the Weight scales as Constants in Graph
918+
// The pattern matching has been generalized for the following cases in the Graph: fp32 (non-compressed), fp16
919+
// (compressed), slice, non-slice:
920920
//
921921
// "tensor" "scale"
922922
// Const:A Const:C
923923
// i4 f16|f32
924-
// : :
925-
// V :
926-
// Convert :
927-
// f16|f32 :
928-
// : :
929-
// V V
930-
// Multiply
924+
// : :
925+
// V V
926+
// Slice Convert
927+
// (optional) f16|f32
928+
// : (optional)
929+
// V :
930+
// Convert V
931+
// f16|f32 Slice
932+
// : (optional)
933+
// : :
934+
// : :
935+
// : :
936+
// : :
937+
// V V
938+
// Multiply
939+
// f16|f32
940+
// :
941+
// V
942+
// MatMul
931943
// f16|f32
932-
944+
//
933945
CWAI3::CWAI3(CWAI3::Results scales) {
934946
auto constA = opp::wrap_type<ov::op::v0::Constant>();
935947
auto constC = opp::wrap_type<ov::op::v0::Constant>();
936-
auto cvtA = opp::wrap_type<ov::op::v0::Convert>({constA});
937-
auto mulply = opp::wrap_type<ov::op::v1::Multiply>({cvtA, constC});
948+
auto sliceA = opp::optional<ov::op::v8::Slice>(
949+
{constA->output(0), opp::any_input(), opp::any_input(), opp::any_input(), opp::any_input()});
950+
auto cvtA = opp::wrap_type<ov::op::v0::Convert>({sliceA});
951+
auto cvtC = opp::optional<ov::op::v0::Convert>({constC->output(0)});
952+
auto sliceC = opp::optional<ov::op::v8::Slice>(
953+
{cvtC->output(0), opp::any_input(), opp::any_input(), opp::any_input(), opp::any_input()});
954+
auto mulply = opp::wrap_type<ov::op::v1::Multiply>({cvtA, sliceC});
955+
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), mulply});
938956

939957
auto matcher_callback = [=](ov::pass::pattern::Matcher& m) {
940958
auto& node_to_output = m.get_pattern_value_map();
941-
auto matched_nodeA = node_to_output.at(constA).get_node_shared_ptr();
942-
auto matched_nodeC = node_to_output.at(constC).get_node_shared_ptr();
943-
944-
NPUW_ASSERT(ov::op::util::is_constant(matched_nodeA));
945-
NPUW_ASSERT(ov::op::util::is_constant(matched_nodeC));
946-
947-
auto matched_valueA = std::static_pointer_cast<ov::op::v0::Constant>(matched_nodeA);
948-
auto matched_valueC = std::static_pointer_cast<ov::op::v0::Constant>(matched_nodeC);
949959

950-
if ((ov::element::i4 == matched_valueA->get_element_type() ||
951-
ov::element::nf4 == matched_valueA->get_element_type()) &&
952-
(ov::element::f16 == matched_valueC->get_element_type() ||
953-
ov::element::f32 == matched_valueC->get_element_type())) {
954-
LOG_DEBUG("Matched: " << matched_valueC);
955-
scales.get().push_back(matched_valueC);
960+
auto matched_node_A = node_to_output.at(constA).get_node_shared_ptr();
961+
auto matched_node_C = node_to_output.at(constC).get_node_shared_ptr();
962+
auto matched_node_matmul = node_to_output.at(matmul).get_node_shared_ptr();
963+
964+
auto matched_A = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_A);
965+
auto matched_C = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_C);
966+
auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);
967+
968+
if ((ov::element::f16 == matched_C->get_element_type() || ov::element::f32 == matched_C->get_element_type()) &&
969+
(ov::element::f16 == matched_matmul->get_element_type() ||
970+
ov::element::f32 == matched_matmul->get_element_type()) &&
971+
(ov::element::i4 == matched_A->get_element_type() || ov::element::nf4 == matched_A->get_element_type() ||
972+
ov::element::i8 == matched_A->get_element_type())) {
973+
auto matched_C_shape = matched_C->output(0).get_shape();
974+
975+
if (matched_C_shape.size() == 2 && matched_matmul->get_transpose_b()) {
976+
scales.get().push_back(matched_C);
977+
LOG_DEBUG("Matched: " << matched_C->get_friendly_name());
978+
return false; // root hasn't changed
979+
}
956980
}
957-
return true;
958-
}; // matcher_callback
959-
960-
register_matcher(std::make_shared<opp::Matcher>(mulply, "TagCWAI3"), std::move(matcher_callback));
981+
return false; // root hasn't changed
982+
};
983+
register_matcher(std::make_shared<opp::Matcher>(matmul, "TagCWAI3"), std::move(matcher_callback));
961984
}
962985

963986
// As seen in LLaMa-v2-7b:

0 commit comments

Comments
 (0)