diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp index 51f754c8450..798f22d9d5a 100644 --- a/src/cpu/cpu_inner_product_list.cpp +++ b/src/cpu/cpu_inner_product_list.cpp @@ -398,6 +398,16 @@ const std::map> &impl_list_map() CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, + {{forward, s8, s8, f16}, { + CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE(ref_inner_product_int8_fwd_t) + nullptr, + }}, + {{forward, u8, s8, f16}, { + CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE(ref_inner_product_int8_fwd_t) + nullptr, + }}, {{forward, s8, s8, bf16}, { //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) diff --git a/src/cpu/ref_inner_product_int8.hpp b/src/cpu/ref_inner_product_int8.hpp index 4f16d2e368d..d09903e3ef5 100644 --- a/src/cpu/ref_inner_product_int8.hpp +++ b/src/cpu/ref_inner_product_int8.hpp @@ -51,12 +51,12 @@ struct ref_inner_product_int8_fwd_t : public primitive_t { VDISPATCH_INNER_PRODUCT( utils::one_of(src_type, s8, u8), VERBOSE_UNSUPPORTED_DT); VDISPATCH_INNER_PRODUCT(wei_type == s8, VERBOSE_UNSUPPORTED_DT); - VDISPATCH_INNER_PRODUCT( - IMPLICATION(with_bias(), - utils::one_of(bia_type, f32, bf16, s32, s8, u8)), + VDISPATCH_INNER_PRODUCT(IMPLICATION(with_bias(), + utils::one_of(bia_type, f32, f16, + bf16, s32, s8, u8)), VERBOSE_UNSUPPORTED_DT); VDISPATCH_INNER_PRODUCT( - utils::one_of(dst_type, f32, bf16, s32, s8, u8), + utils::one_of(dst_type, f32, f16, bf16, s32, s8, u8), VERBOSE_UNSUPPORTED_DT); VDISPATCH_INNER_PRODUCT( IMPLICATION(with_bias(), diff --git a/src/cpu/x64/brgemm/brgemm.cpp b/src/cpu/x64/brgemm/brgemm.cpp index 036902d0104..6b0326b28c8 100644 --- a/src/cpu/x64/brgemm/brgemm.cpp +++ b/src/cpu/x64/brgemm/brgemm.cpp @@ -415,9 +415,10 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg, // check that combination of data types is allowed if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8) && (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32, - data_type::f32, data_type::bf16)) + data_type::f32, data_type::f16, data_type::bf16)) && (!one_of(dt_bias, data_type::undef, data_type::u8, data_type::s8, - data_type::s32, data_type::f32, data_type::bf16))) + data_type::s32, data_type::f32, data_type::f16, + data_type::bf16))) return status::unimplemented; if ((brg->dt_a == data_type::bf16 && brg->dt_b == data_type::bf16) && (!one_of(dt_d, data_type::bf16, data_type::f32)) diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp index 9b739fc7b61..9a94c5be51a 100644 --- a/src/cpu/x64/matmul/brgemm_matmul.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul.cpp @@ -83,7 +83,7 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { // ICE in GCC 7.4. const bool is_bia_dt_correct = IMPLICATION(is_int8 == true, - one_of(bia_dt, f32, s32, s8, u8, bf16)) + one_of(bia_dt, f32, s32, s8, u8, f16, bf16)) && IMPLICATION( is_f8 == true, one_of(bia_dt, f32, f16, bf16, src_dt)) && IMPLICATION( diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp index 1d4fe2d44e8..884bebf4edf 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp @@ -240,7 +240,7 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t( && one_of(bgmmc.wei_dt, f8_e5m2, f8_e4m3) && one_of(bgmmc.dst_dt, f16, f32, bf16, f8_e5m2, f8_e4m3)) , int8_dt(utils::one_of(bgmmc.src_dt, u8, s8) && bgmmc.wei_dt == s8 - && one_of(bgmmc.dst_dt, u8, s8, s32, f32, bf16)) + && one_of(bgmmc.dst_dt, u8, s8, s32, f32, f16, bf16)) , bf32_dt(f32_dt && one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::any) && isa == avx512_core_amx) diff --git a/src/cpu/x64/matmul_inner_product.hpp b/src/cpu/x64/matmul_inner_product.hpp index a055a90c64c..82d09dd5cd9 100644 --- a/src/cpu/x64/matmul_inner_product.hpp +++ b/src/cpu/x64/matmul_inner_product.hpp @@ -59,7 +59,7 @@ struct matmul_inner_product_fwd_t : public primitive_t { const auto wei_dt = invariant_wei_md()->data_type; const auto dst_dt = invariant_dst_md()->data_type; const bool is_int8 = utils::one_of(src_dt, u8, s8) && wei_dt == s8 - && utils::one_of(dst_dt, u8, s8, s32, f32, bf16); + && utils::one_of(dst_dt, u8, s8, s32, f32, f16, bf16); auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt | skip_mask_t::fpmath_mode; diff --git a/tests/benchdnn/inputs/matmul/test_matmul_ci b/tests/benchdnn/inputs/matmul/test_matmul_ci index c359a9f0987..5aa28081048 100644 --- a/tests/benchdnn/inputs/matmul/test_matmul_ci +++ b/tests/benchdnn/inputs/matmul/test_matmul_ci @@ -4,7 +4,7 @@ # Plain cases --reset ---dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,f8_e5m2:f8_e4m3:f32,u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16 +--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,f8_e5m2:f8_e4m3:f32,u8:s8:s8,s8:s8:f32,s8:s8:f16,u8:s8:f16 --bia-dt=f32 --bia_mask=2 --batch=shapes_2d_ci @@ -14,7 +14,7 @@ # Post-ops check for different data types --reset ---dt=f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16 +--dt=f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32,s8:s8:f16,u8:s8:f16 --attr-post-ops=sum+relu:0.5+add:f32 --batch=shapes_2d_ci @@ -35,7 +35,7 @@ # Different tags --reset ---dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16 +--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32,s8:s8:f16,u8:s8:f16 --stag=ab,ba --wtag=ab,ba --dtag=ab,ba @@ -58,7 +58,7 @@ # Arg scales check --reset ---dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:u8,s8:s8:f32,s8:s8:f16,s8:u8:f16 +--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:u8,s8:s8:f32,s8:s8:f16,u8:s8:f16 --attr-scales=src:common:0.25+wei:common:0.5+dst:common:2,wei:per_oc --batch=shapes_2d_ci @@ -72,7 +72,7 @@ # Zero-points check --reset ---dt=s8:s8:s8,u8:s8:f32,u8:s8:bf16,s8:s8:f16,s8:u8:f16 +--dt=s8:s8:s8,u8:s8:f32,u8:s8:bf16,s8:s8:f16,u8:s8:f16 --attr-zero-points=src:common:1+wei:common:-1+dst:common:2 --batch=shapes_2d_ci diff --git a/tests/benchdnn/inputs/matmul/test_matmul_int8 b/tests/benchdnn/inputs/matmul/test_matmul_int8 index 44d165e4caa..8a4141f3fe0 100644 --- a/tests/benchdnn/inputs/matmul/test_matmul_int8 +++ b/tests/benchdnn/inputs/matmul/test_matmul_int8 @@ -1,7 +1,7 @@ # int8 --reset ---dt=u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16 +--dt=u8:s8:s8,s8:s8:f32,s8:s8:f16,u8:s8:f16 --stag=ab --wtag=ab,ba --dtag=ab --runtime_dims_masks=0,2:1,1:0,3:1 --bia-dt=undef,f32 --bia_mask=2