@@ -5208,6 +5208,90 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce(
       grad_output, input, mean, invstd, weight_opt, input_g, weight_g, bias_g);
 }
 
+#ifdef USE_OVERRIDE_OP
+// Rename the functions below because they have overloads with the same name
+// and therefore cannot be registered under the original name.
+std::tuple<Tensor, Tensor, Tensor> _native_batch_norm_legit_(
+    const Tensor& self,
+    const c10::optional<Tensor>& weight_opt,
+    const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean,
+    Tensor& running_var,
+    bool train,
+    double momentum,
+    double epsilon) {
+  return at::AtenIpexTypeXPU::_native_batch_norm_legit(
+      self,
+      weight_opt,
+      bias_opt,
+      running_mean,
+      running_var,
+      train,
+      momentum,
+      epsilon);
+}
+
+std::tuple<Tensor, Tensor, Tensor> _native_batch_norm_legit_no_state(
+    const Tensor& self,
+    const c10::optional<Tensor>& weight_opt,
+    const c10::optional<Tensor>& bias_opt,
+    bool train,
+    double momentum,
+    double epsilon) {
+  return at::AtenIpexTypeXPU::_native_batch_norm_legit(
+      self, weight_opt, bias_opt, train, momentum, epsilon);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_out_(
+    const Tensor& self,
+    const c10::optional<Tensor>& weight_opt,
+    const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean,
+    Tensor& running_var,
+    bool train,
+    double momentum,
+    double epsilon,
+    Tensor& output,
+    Tensor& save_mean,
+    Tensor& save_invstd) {
+  return at::AtenIpexTypeXPU::_native_batch_norm_legit_out(
+      self,
+      weight_opt,
+      bias_opt,
+      running_mean,
+      running_var,
+      train,
+      momentum,
+      epsilon,
+      output,
+      save_mean,
+      save_invstd);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_state_out(
+    const Tensor& self,
+    const c10::optional<Tensor>& weight_opt,
+    const c10::optional<Tensor>& bias_opt,
+    bool train,
+    double momentum,
+    double epsilon,
+    Tensor& output,
+    Tensor& save_mean,
+    Tensor& save_invstd) {
+  return at::AtenIpexTypeXPU::_native_batch_norm_legit_out(
+      self,
+      weight_opt,
+      bias_opt,
+      train,
+      momentum,
+      epsilon,
+      output,
+      save_mean,
+      save_invstd);
+}
+
+#endif
+
 } // namespace AtenIpexTypeXPU
 } // namespace at
 
@@ -5223,6 +5307,18 @@ IPEX_TORCH_LIBRARY_IMPL(aten, XPU, m) {
   m.impl(
       "native_batch_norm_backward",
       TORCH_FN((&at::AtenIpexTypeXPU::native_batch_norm_backward)));
+  m.impl(
+      "_native_batch_norm_legit",
+      TORCH_FN((&at::AtenIpexTypeXPU::_native_batch_norm_legit_)));
+  m.impl(
+      "_native_batch_norm_legit.out",
+      TORCH_FN((&at::AtenIpexTypeXPU::_native_batch_norm_legit_out_)));
+  m.impl(
+      "_native_batch_norm_legit.no_stats",
+      TORCH_FN((&at::AtenIpexTypeXPU::_native_batch_norm_legit_no_state)));
+  m.impl(
+      "_native_batch_norm_legit.no_stats_out",
+      TORCH_FN((&at::AtenIpexTypeXPU::_native_batch_norm_legit_no_state_out)));
 }
 
 } // namespace
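
For reference, a minimal usage sketch (not part of this commit) of how the newly registered overloads are reached from the generated ATen C++ API. It assumes an XPU device is available, IPEX is loaded and built with USE_OVERRIDE_OP, and at::_native_batch_norm_legit exists as generated by the op schema; the function name, shapes, and hyperparameters below are illustrative only.

#include <ATen/ATen.h>

// Hypothetical caller: dispatching on the XPU key routes this ATen call to the
// _native_batch_norm_legit_ wrapper registered above.
void batch_norm_legit_xpu_example() {
  auto opts = at::device(at::kXPU).dtype(at::kFloat);
  auto input = at::randn({8, 16, 32, 32}, opts);  // NCHW activations
  auto weight = at::ones({16}, opts);             // per-channel scale
  auto bias = at::zeros({16}, opts);              // per-channel shift
  auto running_mean = at::zeros({16}, opts);      // updated in training mode
  auto running_var = at::ones({16}, opts);

  auto out = at::_native_batch_norm_legit(
      input, weight, bias, running_mean, running_var,
      /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5);
  // std::get<0>(out) is the normalized output; the remaining two elements are
  // the saved mean/invstd consumed by the backward pass.
}

The trailing underscore and the _no_state/_no_state_out suffixes only give each schema overload a distinct C++ symbol; the schema strings passed to m.impl ("_native_batch_norm_legit", ".out", ".no_stats", ".no_stats_out") are what select the overload at dispatch time.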