@@ -58,34 +58,9 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
5858 dil_bias = dbl::comm::try_gen_dil_tensor (bias);
5959 }
6060
61- // Prepack weight tensor if it's either a *cpu tensor* or a *plain dil tensor*
62- //
63- // Note: weight tensor will not be re-packed unless user has implicitly
64- // triggered `to_public` by accessing its data
65- // One caveat is when the input size has changed and prepacked weight
66- // might not be the best fit for new input size, the weight will not
67- // be re-packed in such cases, but it still ensures the correctness
68- //
69- // TODO: once semantics of "own shade context" is equivalent to
70- // "is dil tensor", we could remove the first check below
7161 dbl::comm::reorder_to_bf16_for_mix_prec (weight);
72- if (!check_tensor_own_shade_context (weight) ||
73- !cpu::ShadeDataContext::isDilOwnTheTensor (weight) ||
74- cpu::ShadeDataContext::getDilTensor (weight).is_public_format ()) {
75- auto packed_desc = dil::convolution_forward::expected_weights_desc (
76- weight.sizes ().vec (),
77- dil_input.get_data_type (),
78- stride.vec (),
79- padding.vec (),
80- padding.vec (),
81- dilation.vec (),
82- groups,
83- dil::algorithm::convolution_direct,
84- dil::prop_kind::forward,
85- dil_input.get_data_type (),
86- input.sizes ().vec ());
87- dbl::comm::reorder_to_desc (weight, packed_desc);
88- }
62+ dbl::conv::prepack_conv_weights (input, dil_input,
63+ weight, stride, padding, dilation, groups);
8964 dil_weight = dbl::comm::try_gen_dil_tensor (weight);
9065
9166 dil::tensor dil_output = dbl::conv::conv2d_impl (
@@ -133,7 +108,8 @@ std::tuple<at::Tensor, at::Tensor> dil_convolution_backward_weights(
133108 const dil::tensor dil_input = dbl::comm::try_gen_dil_tensor (input);
134109
135110 dil::tensor dil_grad_weight, dil_grad_bias;
136- auto diff_weight_type = get_dil_data_type (weight.scalar_type ());
111+ dil::tensor w = dbl::comm::try_gen_dil_tensor (weight);
112+ auto diff_weight_type = w.get_data_type ();
137113 auto weight_size = weight.sizes ();
138114
139115 if (bias_defined) {
@@ -176,7 +152,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::dil_convolution_bac
176152{
177153 DEBUG (" AtenIpexCPUDev::dil_convolution_backward\n " );
178154 at::Tensor grad_output = grad_output_t .is_contiguous () ? grad_output_t : grad_output_t .contiguous ();
179-
155+ CHECK_DNNL_OP_PRE_COND (input);
156+ CHECK_DNNL_OP_PRE_COND (weight);
180157 dbl::comm::reorder_to_bf16_for_mix_prec (input);
181158 dbl::comm::reorder_to_bf16_for_mix_prec (grad_output);
182159 dbl::comm::reorder_to_bf16_for_mix_prec (weight);
@@ -667,7 +644,7 @@ at::Tensor AtenIpexCPUDev::dil_linear(
667644 dbl::comm::reorder_to_bf16_for_mix_prec (weight);
668645
669646 // reshape first if input dim is greater than 2 and the reshape will cost a memory copy.
670- auto self_reshaped = self.dim () > 2 ? self. reshape ( {-1 , self.size (self.dim () - 1 )}) : self;
647+ auto self_reshaped = self.dim () > 2 ? dil_reshape (self, {-1 , self.size (self.dim () - 1 )}) : self;
671648 const dil::tensor x = dbl::comm::try_gen_dil_tensor (self_reshaped);
672649 const dil::tensor w = dbl::comm::try_gen_dil_tensor (weight);
673650
@@ -704,7 +681,7 @@ at::Tensor AtenIpexCPUDev::dil_linear_fuse_relu(
704681 dbl::comm::reorder_to_bf16_for_mix_prec (weight);
705682
706683 // reshape first if input dim is greater than 2 and the reshape will cost a memory copy.
707- auto self_reshaped = self.dim () > 2 ? self. reshape ( {-1 , self.size (self.dim () - 1 )}) : self;
684+ auto self_reshaped = self.dim () > 2 ? dil_reshape (self, {-1 , self.size (self.dim () - 1 )}) : self;
708685 const dil::tensor x = dbl::comm::try_gen_dil_tensor (self_reshaped);
709686 const dil::tensor w = dbl::comm::try_gen_dil_tensor (weight);
710687
@@ -740,11 +717,13 @@ at::Tensor AtenIpexCPUDev::dil_linear_backward_input(
740717 at::IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight){
741718 DEBUG (" AtenIpexCPUDev::dil_linear_backward_input\n " );
742719
720+ CHECK_DNNL_OP_PRE_COND (grad_output);
721+ CHECK_DNNL_OP_PRE_COND (weight);
743722 dbl::comm::reorder_to_bf16_for_mix_prec (grad_output);
744723 dbl::comm::reorder_to_bf16_for_mix_prec (weight);
745724
746725 auto grad_output_reshaped = grad_output.dim () > 2 ?
747- grad_output. reshape ( {-1 , grad_output.size (grad_output.dim () - 1 )}) : grad_output;
726+ dil_reshape (grad_output, {-1 , grad_output.size (grad_output.dim () - 1 )}) : grad_output;
748727 dil::tensor grady = dbl::comm::try_gen_dil_tensor (grad_output_reshaped);
749728 const dil::tensor w = dbl::comm::try_gen_dil_tensor (weight);
750729
@@ -766,17 +745,22 @@ std::tuple<at::Tensor, at::Tensor> AtenIpexCPUDev::dil_linear_backward_weights(
766745 const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& weight, bool bias_defined) {
767746 DEBUG (" AtenIpexCPUDev::dil_linear_backward_weights\n " );
768747
748+ CHECK_DNNL_OP_PRE_COND (input);
749+ CHECK_DNNL_OP_PRE_COND (grad_output);
750+ CHECK_DNNL_OP_PRE_COND (weight);
769751 dbl::comm::reorder_to_bf16_for_mix_prec (grad_output);
770752 dbl::comm::reorder_to_bf16_for_mix_prec (input);
771753 dbl::comm::reorder_to_bf16_for_mix_prec (weight);
772754
773755 auto grad_output_reshaped = grad_output.dim () > 2 ?
774- grad_output. reshape ( {-1 , grad_output.size (grad_output.dim () - 1 )}) : grad_output;
775- auto input_reshaped = input.dim () > 2 ? input. reshape ( {-1 , input.size (input.dim () - 1 )}) : input;
756+ dil_reshape (grad_output, {-1 , grad_output.size (grad_output.dim () - 1 )}) : grad_output;
757+ auto input_reshaped = input.dim () > 2 ? dil_reshape (input, {-1 , input.size (input.dim () - 1 )}) : input;
776758
777759 dil::tensor grady = dbl::comm::try_gen_dil_tensor (grad_output_reshaped);
778760 dil::tensor x = dbl::comm::try_gen_dil_tensor (input_reshaped);
779- auto diff_weight_type = get_dil_data_type (weight.scalar_type ());
761+ dil::tensor w = dbl::comm::try_gen_dil_tensor (weight);
762+ auto diff_weight_type = w.get_data_type ();
763+
780764 dil::tensor gradw, gradb;
781765 if (bias_defined) {
782766 dil::inner_product_backward_weights::compute (x, grady, gradw, gradb, diff_weight_type);
@@ -795,13 +779,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_linear_backwa
795779 const at::Tensor& input, const at::Tensor& grad_output,
796780 const at::Tensor& weight, std::array<bool ,3 > output_mask) {
797781 DEBUG (" AtenIpexCPUDev::dil_linear_backward\n " );
798- CHECK_DNNL_OP_PRE_COND (input);
799- CHECK_DNNL_OP_PRE_COND (grad_output);
800- CHECK_DNNL_OP_PRE_COND (weight);
801-
802- dbl::comm::reorder_to_bf16_for_mix_prec (grad_output);
803- dbl::comm::reorder_to_bf16_for_mix_prec (input);
804- dbl::comm::reorder_to_bf16_for_mix_prec (weight);
805782
806783 at::Tensor grad_input, grad_weight, grad_bias;
807784 if (output_mask[0 ]) {
@@ -1304,10 +1281,9 @@ at::Tensor AtenIpexCPUDev::dil__softmax_backward_data(
13041281
13051282at::Tensor AtenIpexCPUDev::dil_sigmoid (const at::Tensor& self) {
13061283 DEBUG (" AtenIpexCPUDev::dil_sigmoid\n " );
1307-
1284+ CHECK_DNNL_OP_PRE_COND (self);
13081285 dbl::comm::reorder_to_bf16_for_mix_prec (self);
13091286
1310- CHECK_DNNL_OP_PRE_COND (self);
13111287 dil::tensor x = dbl::comm::try_gen_dil_tensor (self);
13121288 dil::tensor y;
13131289 dil::eltwise_forward::compute (
0 commit comments