
Commit a1e2271

chunhuanMeng, tye1, and fengyuan14 authored
BatchNormalization: SYCL: convert memory format to align with SYCL kernel assumption (#3857) (#3882)
* check and transform format
* Update BatchNorm.cpp
* Update BatchNorm
* add comments

Signed-off-by: Feng Yuan <[email protected]>
Co-authored-by: Ye Ting <[email protected]>
Co-authored-by: Feng Yuan <[email protected]>
1 parent 1eef60d commit a1e2271
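
In short, the change normalizes tensor memory layout before dispatching to the SYCL kernels: channels-first kernels now receive plainly contiguous tensors (.contiguous()), while channels-last kernels receive 4-D activations/gradients converted with contiguous(at::MemoryFormat::ChannelsLast). A minimal sketch of that pattern follows; the helper name to_kernel_layout is hypothetical and not part of this commit, and it assumes only standard ATen APIs.

// Hypothetical helper (not in the patch) illustrating the conversion rule.
#include <ATen/ATen.h>

at::Tensor to_kernel_layout(const at::Tensor& t, bool channels_last_kernel) {
  if (channels_last_kernel) {
    // The SYCL channels-last kernels assume NHWC-strided storage; only 4-D
    // (NCHW-shaped) tensors are converted, mirroring the WA in this commit.
    return t.dim() == 4 ? t.contiguous(at::MemoryFormat::ChannelsLast) : t;
  }
  // Channels-first kernels assume the default contiguous (NCHW) layout.
  return t.contiguous();
}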

File tree

1 file changed: +98 -16 lines


csrc/gpu/aten/operators/BatchNorm.cpp

Lines changed: 98 additions & 16 deletions
@@ -584,12 +584,14 @@ void batch_norm_elementwise(
         batch_norm_elemt_channels_first_template<
             scalar_t,
             accscalar_t,
-            int32_t>(out, self, *weight, *bias, mean_, invstd_);
+            int32_t>(
+            out, self.contiguous(), *weight, *bias, mean_, invstd_);
       } else {
         batch_norm_elemt_channels_first_template<
             scalar_t,
             scalar_t,
-            int32_t>(out, self, *weight, *bias, mean_, invstd_);
+            int32_t>(
+            out, self.contiguous(), *weight, *bias, mean_, invstd_);
       }
     });
     return;
@@ -607,7 +609,16 @@ void batch_norm_elementwise(
         (!mean_.defined() || mean_.is_contiguous()) &&
         (!invstd_.defined() || invstd_.is_contiguous())) {
       batch_norm_elemt_channels_last_template(
-          out, self, *weight, *bias, mean_, invstd_);
+          out,
+          // It is a WA to fix Mobile-SSD convergence issue.
+          // TODO: Fully support: Check and convert activations with any
+          // shapes to align with kernel required memory layout.
+          self.dim() == 4 ? self.contiguous(at::MemoryFormat::ChannelsLast)
+                          : self,
+          *weight,
+          *bias,
+          mean_,
+          invstd_);
       return;
     }
   }
@@ -2858,21 +2869,43 @@ Tensor batch_norm_elementwise_backward_train(
             scalar_t,
             accscalar_t,
             int32_t>(
-            grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
+            grad_out.contiguous(),
+            input.contiguous(),
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu);
       } else {
         return batch_norm_backward_elemt_channels_first_template<
             scalar_t,
             scalar_t,
             int32_t>(
-            grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
+            grad_out.contiguous(),
+            input.contiguous(),
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu);
       }
     });
   }
   case Impl::ChannelsLast: {
     if ((!weight.defined() || weight.is_contiguous()) &&
         mean.is_contiguous() && invstd.is_contiguous()) {
       return batch_norm_backward_elemt_channels_last_template(
-          grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
+          // It is a WA to fix Mobile-SSD convergence issue.
+          grad_out.dim() == 4
+              ? grad_out.contiguous(at::MemoryFormat::ChannelsLast)
+              : grad_out,
+          input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
+                           : input,
+          mean,
+          invstd,
+          weight,
+          sum_dy,
+          sum_dy_xmu);
     }
   }
   case Impl::General: {
@@ -3091,7 +3124,18 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_dispatch(
       (!weight.defined() || weight.is_contiguous()) && mean.is_contiguous() &&
       invstd.is_contiguous()) {
     return batch_norm_backward_reduce_channels_last_template(
-        grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
+        // It is a WA to fix Mobile-SSD convergence issue.
+        grad_output.dim() == 4
+            ? grad_output.contiguous(at::MemoryFormat::ChannelsLast)
+            : grad_output,
+        input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
+                         : input,
+        mean,
+        invstd,
+        weight,
+        input_g,
+        weight_g,
+        bias_g);
   }
   return IPEX_DISPATCH_FLOATING_TYPES_AND2(
       kHalf,
@@ -3282,8 +3326,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> native_batch_norm_backward(
             scalar_t,
             accscalar_t,
             int32_t>(
-            grad_output,
-            input,
+            grad_output.contiguous(),
+            input.contiguous(),
             *weight,
             *running_mean,
             *running_var,
@@ -3297,8 +3341,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> native_batch_norm_backward(
             scalar_t,
             scalar_t,
             int32_t>(
-            grad_output,
-            input,
+            grad_output.contiguous(),
+            input.contiguous(),
             *weight,
             *running_mean,
             *running_var,
@@ -3913,7 +3957,17 @@ Tensor batch_norm_backward_elemt_dispatch(
       batch_norm_use_channels_last_kernels(self) &&
       batch_norm_use_channels_last_kernels(input)) {
     return batch_norm_backward_elemt_channels_last_template(
-        self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+        // It is a WA to fix Mobile-SSD convergence issue.
+        self.dim() == 4 ? self.contiguous(at::MemoryFormat::ChannelsLast)
+                        : self,
+        input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
+                         : input,
+        mean,
+        invstd,
+        weight,
+        sum_dy,
+        sum_dy_xmu,
+        count);
   }

   return IPEX_DISPATCH_FLOATING_TYPES_AND2(
@@ -3938,27 +3992,55 @@ Tensor batch_norm_backward_elemt_dispatch(
             scalar_t,
             accscalar_t,
             int32_t>(
-            self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+            self.contiguous(),
+            input.contiguous(),
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count);
       } else {
         return batch_norm_backward_elemt_channels_first_template<
             scalar_t,
             scalar_t,
             int32_t>(
-            self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+            self.contiguous(),
+            input.contiguous(),
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count);
       }
     } else {
       if (is_half_float || is_bfloat16_float) {
         return batch_norm_backward_elemt_channels_first_template<
             scalar_t,
             accscalar_t,
             int64_t>(
-            self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+            self.contiguous(),
+            input.contiguous(),
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count);
       } else {
         return batch_norm_backward_elemt_channels_first_template<
             scalar_t,
             scalar_t,
             int64_t>(
-            self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+            self.contiguous(),
+            input.contiguous(),
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count);
       }
     }
   });

0 commit comments