Skip to content

Commit 8db5d5a

Browse files
densamoilovliubo-intel
authored andcommitted
[FORK][x64] enable f16 dst for s8/u8 inputs in fc
1 parent 3d7a6f1 commit 8db5d5a

File tree

8 files changed

+26
-15
lines changed

8 files changed

+26
-15
lines changed

src/cpu/cpu_inner_product_list.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,16 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
398398
CPU_INSTANCE(ref_inner_product_int8_fwd_t)
399399
nullptr,
400400
}},
401+
{{forward, s8, s8, f16}, {
402+
CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
403+
CPU_INSTANCE(ref_inner_product_int8_fwd_t)
404+
nullptr,
405+
}},
406+
{{forward, u8, s8, f16}, {
407+
CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
408+
CPU_INSTANCE(ref_inner_product_int8_fwd_t)
409+
nullptr,
410+
}},
401411
{{forward, s8, s8, bf16}, {
402412
//CPU_INSTANCE_X64(matmul_inner_product_fwd_t)
403413
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx)

src/cpu/ref_inner_product_int8.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ struct ref_inner_product_int8_fwd_t : public primitive_t {
5151
VDISPATCH_INNER_PRODUCT(
5252
utils::one_of(src_type, s8, u8), VERBOSE_UNSUPPORTED_DT);
5353
VDISPATCH_INNER_PRODUCT(wei_type == s8, VERBOSE_UNSUPPORTED_DT);
54-
VDISPATCH_INNER_PRODUCT(
55-
IMPLICATION(with_bias(),
56-
utils::one_of(bia_type, f32, bf16, s32, s8, u8)),
54+
VDISPATCH_INNER_PRODUCT(IMPLICATION(with_bias(),
55+
utils::one_of(bia_type, f32, f16,
56+
bf16, s32, s8, u8)),
5757
VERBOSE_UNSUPPORTED_DT);
5858
VDISPATCH_INNER_PRODUCT(
59-
utils::one_of(dst_type, f32, bf16, s32, s8, u8),
59+
utils::one_of(dst_type, f32, f16, bf16, s32, s8, u8),
6060
VERBOSE_UNSUPPORTED_DT);
6161
VDISPATCH_INNER_PRODUCT(
6262
IMPLICATION(with_bias(),

src/cpu/x64/brgemm/brgemm.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,9 +415,10 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg,
415415
// check that combination of data types is allowed
416416
if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
417417
&& (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32,
418-
data_type::f32, data_type::bf16))
418+
data_type::f32, data_type::f16, data_type::bf16))
419419
&& (!one_of(dt_bias, data_type::undef, data_type::u8, data_type::s8,
420-
data_type::s32, data_type::f32, data_type::bf16)))
420+
data_type::s32, data_type::f32, data_type::f16,
421+
data_type::bf16)))
421422
return status::unimplemented;
422423
if ((brg->dt_a == data_type::bf16 && brg->dt_b == data_type::bf16)
423424
&& (!one_of(dt_d, data_type::bf16, data_type::f32))

src/cpu/x64/matmul/brgemm_matmul.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
8383
// ICE in GCC 7.4.
8484
const bool is_bia_dt_correct
8585
= IMPLICATION(is_int8 == true,
86-
one_of(bia_dt, f32, s32, s8, u8, bf16))
86+
one_of(bia_dt, f32, s32, s8, u8, f16, bf16))
8787
&& IMPLICATION(
8888
is_f8 == true, one_of(bia_dt, f32, f16, bf16, src_dt))
8989
&& IMPLICATION(

src/cpu/x64/matmul/brgemm_matmul_utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
240240
&& one_of(bgmmc.wei_dt, f8_e5m2, f8_e4m3)
241241
&& one_of(bgmmc.dst_dt, f16, f32, bf16, f8_e5m2, f8_e4m3))
242242
, int8_dt(utils::one_of(bgmmc.src_dt, u8, s8) && bgmmc.wei_dt == s8
243-
&& one_of(bgmmc.dst_dt, u8, s8, s32, f32, bf16))
243+
&& one_of(bgmmc.dst_dt, u8, s8, s32, f32, f16, bf16))
244244
, bf32_dt(f32_dt
245245
&& one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::any)
246246
&& isa == avx512_core_amx)

src/cpu/x64/matmul_inner_product.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct matmul_inner_product_fwd_t : public primitive_t {
5959
const auto wei_dt = invariant_wei_md()->data_type;
6060
const auto dst_dt = invariant_dst_md()->data_type;
6161
const bool is_int8 = utils::one_of(src_dt, u8, s8) && wei_dt == s8
62-
&& utils::one_of(dst_dt, u8, s8, s32, f32, bf16);
62+
&& utils::one_of(dst_dt, u8, s8, s32, f32, f16, bf16);
6363

6464
auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt
6565
| skip_mask_t::fpmath_mode;

tests/benchdnn/inputs/matmul/test_matmul_ci

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# Plain cases
66
--reset
7-
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,f8_e5m2:f8_e4m3:f32,u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16
7+
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,f8_e5m2:f8_e4m3:f32,u8:s8:s8,s8:s8:f32,s8:s8:f16,u8:s8:f16
88
--bia-dt=f32
99
--bia_mask=2
1010
--batch=shapes_2d_ci
@@ -14,7 +14,7 @@
1414

1515
# Post-ops check for different data types
1616
--reset
17-
--dt=f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16
17+
--dt=f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32,s8:s8:f16,u8:s8:f16
1818
--attr-post-ops=sum+relu:0.5+add:f32
1919
--batch=shapes_2d_ci
2020

@@ -35,7 +35,7 @@
3535

3636
# Different tags
3737
--reset
38-
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16
38+
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32,s8:s8:f16,u8:s8:f16
3939
--stag=ab,ba
4040
--wtag=ab,ba
4141
--dtag=ab,ba
@@ -58,7 +58,7 @@
5858

5959
# Arg scales check
6060
--reset
61-
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:u8,s8:s8:f32,s8:s8:f16,s8:u8:f16
61+
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:u8,s8:s8:f32,s8:s8:f16,u8:s8:f16
6262
--attr-scales=src:common:0.25+wei:common:0.5+dst:common:2,wei:per_oc
6363
--batch=shapes_2d_ci
6464

@@ -72,7 +72,7 @@
7272

7373
# Zero-points check
7474
--reset
75-
--dt=s8:s8:s8,u8:s8:f32,u8:s8:bf16,s8:s8:f16,s8:u8:f16
75+
--dt=s8:s8:s8,u8:s8:f32,u8:s8:bf16,s8:s8:f16,u8:s8:f16
7676
--attr-zero-points=src:common:1+wei:common:-1+dst:common:2
7777
--batch=shapes_2d_ci
7878

tests/benchdnn/inputs/matmul/test_matmul_int8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# int8
22
--reset
33

4-
--dt=u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16
4+
--dt=u8:s8:s8,s8:s8:f32,s8:s8:f16,u8:s8:f16
55
--stag=ab --wtag=ab,ba --dtag=ab
66
--runtime_dims_masks=0,2:1,1:0,3:1
77
--bia-dt=undef,f32 --bia_mask=2

0 commit comments

Comments
 (0)