
Commit 3b6410a

[CPU] FullyConnected acceleration with u2 weights decompression (#31467)
### Details:
- *FullyConnected acceleration with u2 weights decompression.*
- *OneDNN PR: openvinotoolkit/oneDNN#289*

### Tickets:
- *[CVS-169357](https://jira.devtools.intel.com/browse/CVS-169357)*
1 parent 69b0a7c commit 3b6410a

File tree

14 files changed: +175 −25 lines changed


src/plugins/intel_cpu/src/dnnl_extension_utils.cpp

Lines changed: 4 additions & 0 deletions
@@ -90,6 +90,8 @@ std::optional<dnnl::memory::data_type> DnnlExtensionUtils::ElementTypeToDataType
         return memory::data_type::s4;
     case ov::element::u4:
         return memory::data_type::u4;
+    case ov::element::u2:
+        return memory::data_type::u2;
     case ov::element::f8e8m0:
         return memory::data_type::e8m0;
     case ov::element::f8e4m3:
@@ -137,6 +139,8 @@ ov::element::Type DnnlExtensionUtils::DataTypeToElementType(const dnnl::memory::
         return ov::element::i4;
     case memory::data_type::u4:
         return ov::element::u4;
+    case memory::data_type::u2:
+        return ov::element::u2;
     case memory::data_type::e8m0:
         return ov::element::f8e8m0;
     case memory::data_type::f8_e4m3:

src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp

Lines changed: 35 additions & 0 deletions
@@ -845,6 +845,36 @@ struct ConvertFromBinPrecision<std::tuple<src_t, dst_t>> {
     }
 };
 
+#define INTEL_CPU_CVT_FROM_2BIT_LIST                                                                  \
+    INTEL_CPU_CVT(u2, f32), INTEL_CPU_CVT(u2, f16), INTEL_CPU_CVT(u2, bf16), INTEL_CPU_CVT(u2, i32), \
+        INTEL_CPU_CVT(u2, u8), INTEL_CPU_CVT(u2, i8)
+
+struct ConvertFrom2BitContext {
+    const void* srcPtr;
+    void* dstPtr;
+    size_t size;
+    bool converted;
+};
+
+template <typename T>
+struct ConvertFrom2BitPrecision;
+
+[[maybe_unused]] static uint8_t get_u2(uint8_t val, uint8_t shift) {
+    return static_cast<uint8_t>((val & (0x3 << shift)) >> shift);
+}
+
+template <typename src_t, typename dst_t>
+struct ConvertFrom2BitPrecision<std::tuple<src_t, dst_t>> {
+    void operator()(ConvertFrom2BitContext& ctx) {
+        const auto* src = static_cast<const uint8_t*>(ctx.srcPtr);
+        auto dst = static_cast<dst_t*>(ctx.dstPtr);
+        parallel_for(ctx.size, [&](size_t i) {
+            dst[i] = static_cast<dst_t>(get_u2(src[i / 4], (i % 4) * 2));
+        });
+        ctx.converted = true;
+    }
+};
+
 #define INTEL_CPU_CVT_FROM_4BIT_LIST                                                                  \
     INTEL_CPU_CVT(u4, f32), INTEL_CPU_CVT(u4, i32), INTEL_CPU_CVT(u4, bf16), INTEL_CPU_CVT(u4, f16), \
         INTEL_CPU_CVT(u4, i8), INTEL_CPU_CVT(u4, u8), INTEL_CPU_CVT(i4, f32), INTEL_CPU_CVT(i4, i32), \
@@ -1069,6 +1099,10 @@ void cpu_convert(const void* srcPtr,
                    srcPrc.bitwidth(),
                    "> precision to: ",
                    dstPrc);
+    } else if (srcPrc == ov::element::u2) {
+        ConvertFrom2BitContext ctx{srcPtr, dstPtr, size, false};
+        OV_SWITCH(intel_cpu, ConvertFrom2BitPrecision, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_2BIT_LIST);
+        OPENVINO_ASSERT(ctx.converted, "cpu_convert can't convert from: ", srcPrc, " precision to: ", dstPrc);
     } else if (srcPrc.bitwidth() == 4U) {
         ConvertFrom4BitContext ctx{srcPrc, srcPtr, dstPtr, size, false};
         OV_SWITCH(intel_cpu, ConvertFrom4BitPrecision, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_4BIT_LIST);
@@ -1115,6 +1149,7 @@ bool is_supported_convert([[maybe_unused]] ov::element::Type srcPrc, [[maybe_unu
     isSupportedContext ctx;
     OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_LIST);
     OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BIN_LIST);
+    OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_2BIT_LIST);
     OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_4BIT_LIST);
     OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BYTE_FP_LIST);
     OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_TO_4BIT_LIST);
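The new ConvertFrom2BitPrecision specialization above treats each source byte as four 2-bit lanes, LSB first, and widens every lane to the destination type. Below is a minimal standalone sketch of that unpacking scheme, assuming LSB-first packing as in get_u2; the plain loop and the unpack_u2_to_f32 helper name are stand-ins for the parallel_for and OV_SWITCH dispatch used in cpu_convert.cpp.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Same arithmetic as get_u2 above: mask out one 2-bit field and shift it down.
static uint8_t get_u2(uint8_t val, uint8_t shift) {
    return static_cast<uint8_t>((val & (0x3 << shift)) >> shift);
}

// Unpack `size` u2 elements (four per byte, LSB-first) into float.
static void unpack_u2_to_f32(const uint8_t* src, float* dst, std::size_t size) {
    for (std::size_t i = 0; i < size; ++i) {
        dst[i] = static_cast<float>(get_u2(src[i / 4], (i % 4) * 2));
    }
}

int main() {
    const uint8_t packed[] = {0xE4};  // 0b11100100 packs {0, 1, 2, 3} at shifts 0, 2, 4, 6
    std::vector<float> out(4);
    unpack_u2_to_f32(packed, out.data(), out.size());
    for (float v : out) {
        std::printf("%g ", v);  // prints: 0 1 2 3
    }
    std::printf("\n");
    return 0;
}
```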

src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected_primitive.cpp

Lines changed: 10 additions & 3 deletions
@@ -138,7 +138,7 @@ bool DnnlFCPrimitive::useWeightsDecompressionImpl(const ov::element::Type inputT
                                                   const ov::element::Type weightsType,
                                                   const ov::intel_cpu::Config::ModelType modelType) {
     if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
-        if (any_of(inputType, f32, bf16) && any_of(weightsType, u8, i8, nf4, u4, i4, f4e2m1)) {
+        if (any_of(inputType, f32, bf16) && any_of(weightsType, u8, i8, nf4, u4, i4, f4e2m1, u2)) {
             return true;
         }
 
@@ -176,11 +176,15 @@ static bool useDynamicQuantizationImpl(size_t dqGroupSize,
     // For dynamic quantization, VNNI accumulation requires weight to be unsigned.
     // To support dynamic quantization with weights symmetrically quantized as i8/i4
     // w/o zero-point, we will transform weight to u8/u4 weight with zp 128/8.
-    if (none_of(weightsDesc->getPrecision(), ov::element::u8, ov::element::u4) &&
+    if (none_of(weightsDesc->getPrecision(), ov::element::u8, ov::element::u4, ov::element::u2) &&
         !((any_of(weightsDesc->getPrecision(), ov::element::i8, ov::element::i4) && !zpPtr))) {
         return false;
     }
-    if (zpPtr && none_of(zpPtr->getDesc().getPrecision(), ov::element::u8, ov::element::u4, ov::element::dynamic)) {
+    if (zpPtr && none_of(zpPtr->getDesc().getPrecision(),
+                         ov::element::u8,
+                         ov::element::u4,
+                         ov::element::u2,
+                         ov::element::dynamic)) {
         return false;
     }
 
@@ -255,6 +259,9 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const FCAttrs& attrs,
 
     if (auto it = memory.find(ARG_WEI | ARG_ATTR_ZERO_POINTS); it != memory.end()) {
         auto dstPrc = useDynamicQuantization ? ov::element::u8 : ov::element::f32;
+        if (weiDesc->getPrecision() == ov::element::u2) {
+            dstPrc = ov::element::u2;
+        }
         dnnlpoc.appendDecompressionZeroPointsLegacy(it->second, !attrs.weightsNonTransposed, dstPrc);
     }
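The comment in useDynamicQuantizationImpl spells out the constraint that motivates this precision check: VNNI accumulation needs unsigned weights, so symmetrically quantized signed weights are re-expressed with an explicit zero-point (i8 with zp 128, i4 with zp 8). A small arithmetic sketch of that shift follows; it is illustrative only, not the plugin's actual weight transformation.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Symmetric i8 weight (no zero-point) re-expressed as u8 with zero-point 128.
    const float scale = 0.05f;
    const int8_t w_i8 = -37;
    const uint8_t w_u8 = static_cast<uint8_t>(w_i8 + 128);  // shifted storage: 91
    const int zp = 128;

    // Dequantization is unchanged, so unsigned (VNNI-friendly) accumulation can be used.
    const float deq_signed = static_cast<float>(w_i8) * scale;
    const float deq_unsigned = static_cast<float>(w_u8 - zp) * scale;
    std::printf("%f == %f\n", deq_signed, deq_unsigned);  // both -1.85
    return 0;
}
```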

src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp

Lines changed: 2 additions & 2 deletions
@@ -95,10 +95,10 @@ static const TypeMapping dnnlFCTypeMapping {
     {{_u8 | _i8, _i8, _f16, _u8 | _i8 | _i32 | _bf16 | _f32}, {bypass(), bypass(), just<f32>(), bypass()}},
     {{_u8 | _i8, _i8, _any, _any}, {bypass(), bypass(), just<f32>(), just<f32>()}},
     // compresses int weights (@todo more strict requrements for output precision?)
-    {{_bf16, _u8 | _i8 | _nf4 | _u4 | _i4 | _f4e2m1, _any, _any}, {bypass(), bypass(), use<0>(), use<0>()},
+    {{_bf16, _u8 | _i8 | _nf4 | _u4 | _i4 | _f4e2m1 | _u2, _any, _any}, {bypass(), bypass(), use<0>(), use<0>()},
      Require<dnnl::impl::cpu::x64::avx512_core_bf16>()}, // Ticket 122347
     {{_bf16, _u8 | _i8 | _nf4 | _u4 | _i4 | _f4e2m1, _any, _any}, {just<f32>(), bypass(), just<f32>(), just<f32>()}},
-    {{_f32, _u8 | _i8 | _nf4 | _u4 | _i4 | _f4e2m1, _any, _any}, {bypass(), bypass(), use<0>(), use<0>()}},
+    {{_f32, _u8 | _i8 | _nf4 | _u4 | _i4 | _f4e2m1 | _u2, _any, _any}, {bypass(), bypass(), use<0>(), use<0>()}},
     // @todo should we fallback to FPXX instead of _f32?
     {{_any, _any, _any, _any}, {just<f32>(), just<f32>(), just<f32>(), just<f32>()}},
     // @todo explicitly cover configuration limitations for oneDNN on ARM

src/plugins/intel_cpu/src/nodes/executors/type_mask.hpp

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,7 @@ struct TypeMask {
         _string = 1 << 20,
         _f4e2m1 = 1 << 21,
         _f8e8m0 = 1 << 22,
+        _u2 = 1 << 23,
     };
 
     explicit TypeMask(const ov::element::Type precision) : value(generateMask(precision)), precision(precision) {}
@@ -82,6 +83,7 @@ struct TypeMask {
             CASE(string)
             CASE(f4e2m1)
             CASE(f8e8m0)
+            CASE(u2)
         default:
             return _dynamic;
         }
@@ -116,6 +118,7 @@ DEFINE_TYPE_ALIAS(_f8e5m2);
 DEFINE_TYPE_ALIAS(_string);
 DEFINE_TYPE_ALIAS(_f4e2m1);
 DEFINE_TYPE_ALIAS(_f8e8m0);
+DEFINE_TYPE_ALIAS(_u2);
 constexpr auto _any_float = _f64 | _f32 | _f16 | _bf16;
 constexpr auto _hw_float = _f32 | _f16 | _bf16;
 constexpr auto _half_float = _f16 | _bf16;
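Each TypeMask entry such as _u2 = 1 << 23 gives a precision its own bit, so the OR-ed masks used in tables like dnnlFCTypeMapping stay single integers and membership checks are a single AND. A simplified sketch of that idea, with made-up bit positions rather than the real TypeMask layout:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Simplified stand-in for TypeMask: each precision owns one bit, so a set of
    // precisions is the bitwise OR of its members. Bit positions are illustrative.
    constexpr uint64_t _u8     = 1ULL << 0;
    constexpr uint64_t _i8     = 1ULL << 1;
    constexpr uint64_t _u4     = 1ULL << 2;
    constexpr uint64_t _i4     = 1ULL << 3;
    constexpr uint64_t _nf4    = 1ULL << 4;
    constexpr uint64_t _f4e2m1 = 1ULL << 5;
    constexpr uint64_t _u2     = 1ULL << 23;  // the alias this commit adds

    // A mapping row like "_u8 | _i8 | _nf4 | _u4 | _i4 | _f4e2m1 | _u2" is one integer.
    constexpr uint64_t compressed_weight_types = _u8 | _i8 | _nf4 | _u4 | _i4 | _f4e2m1 | _u2;

    std::printf("u2 accepted: %s\n", (compressed_weight_types & _u2) ? "yes" : "no");
    return 0;
}
```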

src/plugins/intel_cpu/src/nodes/fullyconnected.cpp

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ ov::element::TypeVector FullyConnected::getSupportedCompressedWeightsTypes([[may
     }
 #if defined(OPENVINO_ARCH_X86_64)
     ov::element::TypeVector supportedDataTypes =
-        {Type_t::u8, Type_t::i8, Type_t::u4, Type_t::i4, Type_t::nf4, Type_t::f4e2m1};
+        {Type_t::u8, Type_t::i8, Type_t::u4, Type_t::i4, Type_t::nf4, Type_t::f4e2m1, Type_t::u2};
     if (apply_fp8) {
         supportedDataTypes.insert(supportedDataTypes.end(), {Type_t::f8e4m3, Type_t::f8e5m2});
     }

src/plugins/intel_cpu/src/plugin.cpp

Lines changed: 8 additions & 8 deletions
@@ -318,14 +318,14 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
     for (const auto& ii : model->inputs()) {
         auto input_precision = ii.get_element_type();
         static const std::set<ov::element::Type_t> supported_precisions = {
-            ov::element::Type_t::u4,   ov::element::Type_t::i4,      ov::element::Type_t::u8,
-            ov::element::Type_t::i8,   ov::element::Type_t::f8e4m3,  ov::element::Type_t::f8e5m2,
-            ov::element::Type_t::u16,  ov::element::Type_t::i16,     ov::element::Type_t::u32,
-            ov::element::Type_t::i32,  ov::element::Type_t::u64,     ov::element::Type_t::i64,
-            ov::element::Type_t::bf16, ov::element::Type_t::f16,     ov::element::Type_t::f32,
-            ov::element::Type_t::f64,  ov::element::Type_t::boolean, ov::element::Type_t::string,
-            ov::element::Type_t::nf4,  ov::element::Type_t::f4e2m1,  ov::element::Type_t::f8e8m0,
-            ov::element::Type_t::dynamic};
+            ov::element::Type_t::u4,   ov::element::Type_t::i4,      ov::element::Type_t::u8,
+            ov::element::Type_t::i8,   ov::element::Type_t::f8e4m3,  ov::element::Type_t::f8e5m2,
+            ov::element::Type_t::u16,  ov::element::Type_t::i16,     ov::element::Type_t::u32,
+            ov::element::Type_t::i32,  ov::element::Type_t::u64,     ov::element::Type_t::i64,
+            ov::element::Type_t::bf16, ov::element::Type_t::f16,     ov::element::Type_t::f32,
+            ov::element::Type_t::f64,  ov::element::Type_t::boolean, ov::element::Type_t::string,
+            ov::element::Type_t::nf4,  ov::element::Type_t::f4e2m1,  ov::element::Type_t::f8e8m0,
+            ov::element::Type_t::u2,   ov::element::Type_t::dynamic};
 
         if (supported_precisions.find(input_precision) == supported_precisions.end()) {
             OPENVINO_THROW_NOT_IMPLEMENTED("CPU plugin: Input image format ",

src/plugins/intel_cpu/src/utils/plain_tensor.hpp

Lines changed: 12 additions & 1 deletion
@@ -405,6 +405,9 @@ struct PlainTensor {
         if (any_of(m_dt, ov::element::i4, ov::element::u4)) {
             return 2;
         }
+        if (m_dt == ov::element::u2) {
+            return 4;
+        }
         return 1;
     }
 
@@ -423,7 +426,15 @@
 
     template <typename DT, ov::element::Type_t SRC_PREC = ov::element::u8, typename... Is>
     [[nodiscard]] DT* ptr(Is... indices) const {
-        constexpr size_t stride_div = SRC_PREC == ov::element::u4 ? 2 : 1;
+        constexpr size_t stride_div = [] {
+            if (SRC_PREC == ov::element::u2) {
+                return 4;
+            }
+            if (SRC_PREC == ov::element::u4) {
+                return 2;
+            }
+            return 1;
+        }();
         const size_t off = offset<0>(indices...) / stride_div;
         return reinterpret_cast<DT*>(m_ptr.get()) + off;
     }
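The PlainTensor changes generalize the element-index-to-byte mapping for packed sub-byte types: u4 stores two elements per byte and u2 stores four, so the element offset is divided by 2 or 4 (the stride_div above). A short illustrative sketch with a hypothetical byte_offset helper, not part of PlainTensor:

```cpp
#include <cstddef>
#include <cstdio>

// Maps an element index of a packed sub-byte tensor to the byte that holds it,
// mirroring the stride_div logic added to PlainTensor::ptr.
// `bits` is 2 for u2, 4 for u4, 8 for u8.
static std::size_t byte_offset(std::size_t element_index, std::size_t bits) {
    const std::size_t elems_per_byte = 8 / bits;  // 4 for u2, 2 for u4, 1 for u8
    return element_index / elems_per_byte;
}

int main() {
    // Element 10 of a u2 tensor lives in byte 2, of a u4 tensor in byte 5, of u8 in byte 10.
    std::printf("u2: byte %zu, u4: byte %zu, u8: byte %zu\n",
                byte_offset(10, 2), byte_offset(10, 4), byte_offset(10, 8));
    return 0;
}
```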

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ void ConvertCPULayerTest::SetUp() {
     if (primitive.empty())
         primitive = getPrimitiveType();
 #if defined(OPENVINO_ARCH_ARM64)
-    if (inPrc == ov::element::u4 || inPrc == ov::element::i4 ||
+    if (inPrc == ov::element::u2 || inPrc == ov::element::u4 || inPrc == ov::element::i4 ||
         inPrc == ov::element::f4e2m1 || inPrc == ov::element::f8e8m0 ||
         inPrc == ov::element::f8e4m3 || inPrc == ov::element::f8e5m2 ||
         outPrc == ov::element::f8e4m3 || outPrc == ov::element::f8e5m2 ||

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/common/conversion.cpp

Lines changed: 18 additions & 0 deletions
@@ -91,6 +91,24 @@ INSTANTIATE_TEST_SUITE_P(smoke_ConvertCPULayerTest_from_f8e8m0, ConvertCPULayerT
                              ::testing::Values(CPUSpecificParams({nchw}, {nchw}, {}, {"ref"}))),
                          ConvertCPULayerTest::getTestCaseName);
 
+const std::vector<ov::element::Type> common_precisions = {
+    ov::element::f32,
+    ov::element::i32,
+    ov::element::f16,
+    ov::element::bf16,
+    ov::element::u8,
+    ov::element::i8,
+};
+
+INSTANTIATE_TEST_SUITE_P(smoke_ConvertCPULayerTest_from_u2, ConvertCPULayerTest,
+                         ::testing::Combine(
+                                 ::testing::ValuesIn(inShapes_4D_dynamic()),
+                                 ::testing::Values(ov::element::u2),
+                                 ::testing::ValuesIn(common_precisions),
+                                 ::testing::Values(ov::test::SpecialValue::none),
+                                 ::testing::Values(CPUSpecificParams({}, {}, {}, {"ref"}))),
+                         ConvertCPULayerTest::getTestCaseName);
+
 const std::vector<ov::element::Type> f8_precisions = {
     ov::element::f8e4m3,
     ov::element::f8e5m2,
