@@ -1521,6 +1521,15 @@ void program_node::create_onednn_primitive_attributes(
1521
1521
memory_offset++;
1522
1522
};
1523
1523
1524
+ auto resize_layout_for_fc = [](const program_node *node, layout& in_layout) {
1525
+ if (node->is_type <fully_connected>()) {
1526
+ auto input_size = node->as <fully_connected>().get_primitive ()->input_size ;
1527
+ auto new_pshape = in_layout.get_partial_shape ();
1528
+ new_pshape.resize (input_size);
1529
+ in_layout.set_partial_shape (new_pshape);
1530
+ }
1531
+ };
1532
+
1524
1533
int32_t num_sum_post_ops = 0 ;
1525
1534
for (size_t idx = 0 ; idx < cldnn_post_ops.size (); idx++) {
1526
1535
auto & desc = cldnn_post_ops[idx];
@@ -1582,8 +1591,7 @@ void program_node::create_onednn_primitive_attributes(
1582
1591
new_layout.set_partial_shape (new_input_pshape);
1583
1592
in = new_layout;
1584
1593
}
1585
- size_t in_batched_size = in.count () / (in.spatial (0 ) * in.spatial (1 ));
1586
- dnnl::memory::dims dims = onednn::convert_gemm_tensor (in.get_tensor (), rank, in_batched_size == 1 );
1594
+ dnnl::memory::dims dims = onednn::convert_gemm_tensor (in.get_tensor (), rank, false );
1587
1595
dnnl::memory::data_type dt = onednn::convert_data_type (in.data_type );
1588
1596
dnnl::memory::format_tag fmt = onednn::convert_gemm_data_format (dims, in.format );
1589
1597
post_ops.append_binary (alg, dnnl::memory::desc (dims, dt, fmt));
@@ -1648,6 +1656,7 @@ void program_node::create_onednn_primitive_attributes(
1648
1656
update_onednn_post_op_list (onednn_post_op_type::eltwise_linear, empty_mem);
1649
1657
} else {
1650
1658
auto in_scale = get_input_layout (dep_idx++);
1659
+ resize_layout_for_fc (this , in_scale);
1651
1660
dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc (in_scale, onednn::get_default_data_format (in_scale));
1652
1661
post_ops.append_binary (dnnl::algorithm::binary_mul, in_scale_desc);
1653
1662
update_onednn_post_op_list (onednn_post_op_type::binary_mul, dep_idx - 1 , onednn::get_default_data_format (in_scale), false ,
@@ -1660,6 +1669,7 @@ void program_node::create_onednn_primitive_attributes(
1660
1669
update_onednn_post_op_list (onednn_post_op_type::eltwise_linear, empty_mem);
1661
1670
} else {
1662
1671
auto in_shift = get_input_layout (dep_idx++);
1672
+ resize_layout_for_fc (this , in_shift);
1663
1673
dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc (in_shift, onednn::get_default_data_format (in_shift));
1664
1674
post_ops.append_binary (dnnl::algorithm::binary_add, in_shift_desc);
1665
1675
update_onednn_post_op_list (onednn_post_op_type::binary_add, dep_idx - 1 , onednn::get_default_data_format (in_shift), false ,
@@ -1692,6 +1702,7 @@ void program_node::create_onednn_primitive_attributes(
1692
1702
update_onednn_post_op_list (onednn_post_op_type::eltwise_linear, empty_mem);
1693
1703
} else {
1694
1704
auto out_scale = get_input_layout (dep_idx++);
1705
+ resize_layout_for_fc (this , out_scale);
1695
1706
dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc (out_scale, onednn::get_default_data_format (out_scale));
1696
1707
post_ops.append_binary (dnnl::algorithm::binary_mul, out_scale_desc);
1697
1708
update_onednn_post_op_list (onednn_post_op_type::binary_mul, dep_idx - 1 , onednn::get_default_data_format (out_scale), false ,
@@ -1705,6 +1716,7 @@ void program_node::create_onednn_primitive_attributes(
1705
1716
update_onednn_post_op_list (onednn_post_op_type::eltwise_linear, empty_mem);
1706
1717
} else {
1707
1718
auto out_shift = get_input_layout (dep_idx++);
1719
+ resize_layout_for_fc (this , out_shift);
1708
1720
dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc (out_shift, onednn::get_default_data_format (out_shift));
1709
1721
post_ops.append_binary (dnnl::algorithm::binary_add, out_shift_desc);
1710
1722
update_onednn_post_op_list (onednn_post_op_type::binary_add, dep_idx - 1 , onednn::get_default_data_format (out_shift), false ,
0 commit comments