Skip to content

Commit 486c416

Browse files
jiayisunxEikanWang
authored andcommitted
modify dil__softmax_backward_data
1 parent 6618ad5 commit 486c416

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

torch_ipex/csrc/cpu/dil/dil/operators/softmax.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,15 @@ struct softmax_backward : public dnnl::softmax_backward {
3333
int softmax_axis,
3434
const engine& aengine = engine::cpu_engine()) {
3535

36+
auto dst_desc = dst.get_desc();
37+
// align data type of diff_dst with dst
38+
auto diff_dst_desc = diff_dst.get_desc().to_type(dst_desc.get_data_type());
39+
3640
auto forward_hints = softmax_forward::primitive_desc(
37-
{prop_kind::forward_inference, dst.get_desc(), softmax_axis}, aengine);
41+
{prop_kind::forward_inference, dst_desc, softmax_axis}, aengine);
3842

3943
auto pd =
40-
primitive_desc({diff_dst.get_desc(), dst.get_desc(), softmax_axis},
44+
primitive_desc({diff_dst_desc, dst_desc, softmax_axis},
4145
aengine, forward_hints);
4246
auto expected_dst = dst.reorder_if_differ_in(pd.dst_desc());
4347
auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc());

0 commit comments

Comments
 (0)