File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed
torch_ipex/csrc/cpu/dil/dil/operators Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -33,11 +33,15 @@ struct softmax_backward : public dnnl::softmax_backward {
3333 int softmax_axis,
3434 const engine& aengine = engine::cpu_engine()) {
3535
36+ auto dst_desc = dst.get_desc ();
37+ // align data type of diff_dst with dst
38+ auto diff_dst_desc = diff_dst.get_desc ().to_type (dst_desc.get_data_type ());
39+
3640 auto forward_hints = softmax_forward::primitive_desc (
37- {prop_kind::forward_inference, dst. get_desc () , softmax_axis}, aengine);
41+ {prop_kind::forward_inference, dst_desc , softmax_axis}, aengine);
3842
3943 auto pd =
40- primitive_desc ({diff_dst. get_desc (), dst. get_desc () , softmax_axis},
44+ primitive_desc ({diff_dst_desc, dst_desc , softmax_axis},
4145 aengine, forward_hints);
4246 auto expected_dst = dst.reorder_if_differ_in (pd.dst_desc ());
4347 auto expected_diff_dst = diff_dst.reorder_if_differ_in (pd.diff_dst_desc ());
You can’t perform that action at this time.
0 commit comments