@@ -584,12 +584,14 @@ void batch_norm_elementwise(
       batch_norm_elemt_channels_first_template<
           scalar_t,
           accscalar_t,
-          int32_t>(out, self, *weight, *bias, mean_, invstd_);
+          int32_t>(
+          out, self.contiguous(), *weight, *bias, mean_, invstd_);
     } else {
       batch_norm_elemt_channels_first_template<
           scalar_t,
           scalar_t,
-          int32_t>(out, self, *weight, *bias, mean_, invstd_);
+          int32_t>(
+          out, self.contiguous(), *weight, *bias, mean_, invstd_);
     }
   });
   return;
@@ -607,7 +609,16 @@ void batch_norm_elementwise(
         (!mean_.defined() || mean_.is_contiguous()) &&
         (!invstd_.defined() || invstd_.is_contiguous())) {
       batch_norm_elemt_channels_last_template(
-          out, self, *weight, *bias, mean_, invstd_);
+          out,
+          // It is a WA to fix Mobile-SSD convergence issue.
+          // TODO: Fully support: Check and convert activations with any
+          // shapes to align with kernel required memory layout.
+          self.dim() == 4 ? self.contiguous(at::MemoryFormat::ChannelsLast)
+                          : self,
+          *weight,
+          *bias,
+          mean_,
+          invstd_);
       return;
     }
   }
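
Every channels-last call site in this patch gains the same guard: a 4-D activation is repacked as NHWC-contiguous before it reaches the kernel, while tensors of any other rank are passed through unchanged. A minimal standalone sketch of that pattern, assuming only the public ATen API (the helper name to_channels_last_if_4d is illustrative and not part of the commit):

#include <ATen/ATen.h>

// Hypothetical helper mirroring the inline guard used at the channels-last
// call sites: 4-D tensors are converted to the NHWC (ChannelsLast) layout
// the kernel expects; other ranks are returned as-is, since ChannelsLast is
// only defined for 4-D tensors.
static inline at::Tensor to_channels_last_if_4d(const at::Tensor& t) {
  return t.dim() == 4 ? t.contiguous(at::MemoryFormat::ChannelsLast) : t;
}
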
@@ -2858,21 +2869,43 @@ Tensor batch_norm_elementwise_backward_train(
             scalar_t,
             accscalar_t,
             int32_t>(
-            grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
+            grad_out.contiguous(),
+            input.contiguous(),
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu);
       } else {
         return batch_norm_backward_elemt_channels_first_template<
             scalar_t,
             scalar_t,
             int32_t>(
-            grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
+            grad_out.contiguous(),
+            input.contiguous(),
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu);
       }
     });
   }
   case Impl::ChannelsLast: {
     if ((!weight.defined() || weight.is_contiguous()) &&
         mean.is_contiguous() && invstd.is_contiguous()) {
       return batch_norm_backward_elemt_channels_last_template(
-          grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
+          // It is a WA to fix Mobile-SSD convergence issue.
+          grad_out.dim() == 4
+              ? grad_out.contiguous(at::MemoryFormat::ChannelsLast)
+              : grad_out,
+          input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
+                           : input,
+          mean,
+          invstd,
+          weight,
+          sum_dy,
+          sum_dy_xmu);
     }
   }
   case Impl::General: {
@@ -3091,7 +3124,18 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_dispatch(
       (!weight.defined() || weight.is_contiguous()) && mean.is_contiguous() &&
       invstd.is_contiguous()) {
     return batch_norm_backward_reduce_channels_last_template(
-        grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
+        // It is a WA to fix Mobile-SSD convergence issue.
+        grad_output.dim() == 4
+            ? grad_output.contiguous(at::MemoryFormat::ChannelsLast)
+            : grad_output,
+        input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
+                         : input,
+        mean,
+        invstd,
+        weight,
+        input_g,
+        weight_g,
+        bias_g);
   }
   return IPEX_DISPATCH_FLOATING_TYPES_AND2(
       kHalf,
@@ -3282,8 +3326,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> native_batch_norm_backward(
             scalar_t,
             accscalar_t,
             int32_t>(
-            grad_output,
-            input,
+            grad_output.contiguous(),
+            input.contiguous(),
             *weight,
             *running_mean,
             *running_var,
@@ -3297,8 +3341,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> native_batch_norm_backward(
             scalar_t,
             scalar_t,
             int32_t>(
-            grad_output,
-            input,
+            grad_output.contiguous(),
+            input.contiguous(),
             *weight,
             *running_mean,
             *running_var,
@@ -3913,7 +3957,17 @@ Tensor batch_norm_backward_elemt_dispatch(
       batch_norm_use_channels_last_kernels(self) &&
       batch_norm_use_channels_last_kernels(input)) {
     return batch_norm_backward_elemt_channels_last_template(
-        self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+        // It is a WA to fix Mobile-SSD convergence issue.
+        self.dim() == 4 ? self.contiguous(at::MemoryFormat::ChannelsLast)
+                        : self,
+        input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
+                         : input,
+        mean,
+        invstd,
+        weight,
+        sum_dy,
+        sum_dy_xmu,
+        count);
   }

   return IPEX_DISPATCH_FLOATING_TYPES_AND2(
@@ -3938,27 +3992,55 @@ Tensor batch_norm_backward_elemt_dispatch(
              scalar_t,
              accscalar_t,
              int32_t>(
-              self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+              self.contiguous(),
+              input.contiguous(),
+              mean,
+              invstd,
+              weight,
+              sum_dy,
+              sum_dy_xmu,
+              count);
        } else {
          return batch_norm_backward_elemt_channels_first_template<
              scalar_t,
              scalar_t,
              int32_t>(
-              self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+              self.contiguous(),
+              input.contiguous(),
+              mean,
+              invstd,
+              weight,
+              sum_dy,
+              sum_dy_xmu,
+              count);
        }
      } else {
        if (is_half_float || is_bfloat16_float) {
          return batch_norm_backward_elemt_channels_first_template<
              scalar_t,
              accscalar_t,
              int64_t>(
-              self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+              self.contiguous(),
+              input.contiguous(),
+              mean,
+              invstd,
+              weight,
+              sum_dy,
+              sum_dy_xmu,
+              count);
        } else {
          return batch_norm_backward_elemt_channels_first_template<
              scalar_t,
              scalar_t,
              int64_t>(
-              self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+              self.contiguous(),
+              input.contiguous(),
+              mean,
+              invstd,
+              weight,
+              sum_dy,
+              sum_dy_xmu,
+              count);
        }
      }
    });
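
The channels-first branches above only wrap their activation arguments in plain .contiguous() (default row-major NCHW layout). A small standalone check of what that buys, assuming nothing beyond the public ATen API: the call returns the original tensor when it is already contiguous, and materializes a packed copy only for strided inputs such as permuted views, so the extra calls are cheap on the common path.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::randn({2, 3, 4, 5});  // freshly allocated: contiguous
  at::Tensor y = x.permute({0, 2, 3, 1});  // strided view: not contiguous
  // .contiguous() is a no-op (same tensor) for x, a packed copy for y.
  std::cout << x.is_contiguous() << " "
            << y.is_contiguous() << " "
            << y.contiguous().is_contiguous() << std::endl;  // prints "1 0 1"
  return 0;
}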