Skip to content

Commit d3231d9

Browse files
authored
[Snippets][CPU] Make EnforcePrecision pass arch independent (#32315)
### Tickets: - N/A
1 parent 5cb69a3 commit d3231d9

File tree

4 files changed

+20
-25
lines changed

4 files changed

+20
-25
lines changed

src/plugins/intel_cpu/src/nodes/subgraph.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@
4646

4747
# include "emitters/snippets/x64/cpu_generator.hpp"
4848
# include "executors/x64/subgraph.hpp"
49+
# include "snippets/op/brgemm.hpp"
4950
# include "snippets/pass/matmul_to_brgemm.hpp"
51+
# include "transformations/snippets/x64/op/brgemm_utils.hpp"
5052
#elif defined(OPENVINO_ARCH_ARM64)
5153
# include <cpu/aarch64/cpu_isa_traits.hpp>
5254

@@ -74,9 +76,9 @@
7476
# include "snippets/lowered/pass/init_loops.hpp"
7577
# include "snippets/lowered/pass/insert_buffers.hpp"
7678
# include "snippets/lowered/pass/insert_loops.hpp"
79+
# include "transformations/snippets/common/pass/enforce_precision.hpp"
7780
# include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
7881
# include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp"
79-
# include "transformations/snippets/x64/pass/enforce_precision.hpp"
8082
# include "transformations/snippets/x64/pass/fuse_brgemm_cpu_postops.hpp"
8183
# include "transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.hpp"
8284
# include "transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.hpp"
@@ -551,7 +553,19 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
551553
ov::snippets::pass::MatMulToBrgemm,
552554
pass::EnforcePrecision,
553555
element::f32,
554-
context->getConfig().inferencePrecision);
556+
context->getConfig().inferencePrecision,
557+
[](const std::shared_ptr<ov::Node>& op) {
558+
std::set<std::vector<ov::element::Type>> types;
559+
if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
560+
if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
561+
types.insert({ov::element::f16, ov::element::f16});
562+
}
563+
if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
564+
types.insert({ov::element::bf16, ov::element::bf16});
565+
}
566+
}
567+
return types;
568+
});
555569
}
556570

557571
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,

src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp renamed to src/plugins/intel_cpu/src/transformations/snippets/common/pass/enforce_precision.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@
2121
#include "openvino/itt.hpp"
2222
#include "ov_ops/type_relaxed.hpp"
2323
#include "snippets/itt.hpp"
24-
#include "snippets/op/brgemm.hpp"
2524
#include "snippets/op/convert_saturation.hpp"
2625
#include "snippets/pass/propagate_precision.hpp"
27-
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
2826
#include "transformations/utils/utils.hpp"
2927
#include "utils/general_utils.h"
3028

@@ -37,8 +35,8 @@ EnforcePrecision::EnforcePrecision(
3735
get_supported_precisions)
3836
: source(source),
3937
target(target),
40-
get_supported_precisions(get_supported_precisions == nullptr ? get_supported_precisions_default
41-
: get_supported_precisions) {
38+
get_supported_precisions(get_supported_precisions) {
39+
OPENVINO_ASSERT(get_supported_precisions != nullptr, "get_supported_precisions callback is not set");
4240
OPENVINO_ASSERT(source != target, "source and target precisions have to be different");
4341
}
4442

@@ -132,17 +130,3 @@ bool EnforcePrecision::run_on_model(const std::shared_ptr<ov::Model>& m) {
132130

133131
return was_updated;
134132
}
135-
136-
std::set<std::vector<ov::element::Type>> EnforcePrecision::get_supported_precisions_default(
137-
const std::shared_ptr<ov::Node>& op) noexcept {
138-
std::set<std::vector<ov::element::Type>> types;
139-
if (ov::is_type<snippets::op::Brgemm>(op)) {
140-
if (brgemm_utils::is_fp16_supported()) {
141-
types.insert({element::f16, element::f16});
142-
}
143-
if (brgemm_utils::is_bf16_supported()) {
144-
types.insert({element::bf16, element::bf16});
145-
}
146-
}
147-
return types;
148-
}

src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.hpp renamed to src/plugins/intel_cpu/src/transformations/snippets/common/pass/enforce_precision.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,11 @@ class EnforcePrecision : public ov::pass::ModelPass {
2323
EnforcePrecision(element::Type source,
2424
element::Type target,
2525
const std::function<std::set<std::vector<element::Type>>(const std::shared_ptr<ov::Node>& op)>&
26-
get_supported_precisions = nullptr);
26+
get_supported_precisions);
2727

2828
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
2929

3030
private:
31-
static std::set<std::vector<element::Type>> get_supported_precisions_default(
32-
const std::shared_ptr<ov::Node>& op) noexcept;
33-
3431
const element::Type source;
3532
const element::Type target;
3633
const std::function<std::set<std::vector<element::Type>>(const std::shared_ptr<ov::Node>& op)>

src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/enforce_precision.cpp renamed to src/plugins/intel_cpu/tests/unit/snippets_transformations/common/enforce_precision.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
#include <gtest/gtest.h>
1111

1212
#include "openvino/core/type/element_type.hpp"
13-
#include "transformations/snippets/x64/pass/enforce_precision.hpp"
1413
#include "common_test_utils/common_utils.hpp"
14+
#include "transformations/snippets/common/pass/enforce_precision.hpp"
1515
#include "two_binary_ops.hpp"
1616

1717
namespace ov {

0 commit comments

Comments
 (0)