
Commit dddc5d9

Authored by zhangkaihuo
[cherry-pick]BatchNorm use inplace (#49529)
att, cherry-pick#48254, and resolve conflict
1 parent 34fafb1 commit dddc5d9

5 files changed (+25, −23 lines)


paddle/phi/api/yaml/generator/generate_sparse_op.py

Lines changed: 2 additions & 0 deletions
@@ -84,6 +84,8 @@ def main(
     backward_api_dict = to_named_dict(backward_apis)
 
     for api in apis:
+        if api['name'][-1] == '_':
+            api['name'] = api['name'][:-1]
         api['op_name'] = SPARSE_OP_PREFIX + api['name']
         api['name'] = api['op_name']
         if api["backward"] is not None:
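
The two added lines normalize in-place op names before the sparse prefix is attached: the trailing underscore that marks an in-place op is stripped, so batch_norm_ becomes batch_norm and is then prefixed to sparse_batch_norm. A minimal standalone sketch of just that string handling, assuming SPARSE_OP_PREFIX is 'sparse_' as in the generator (the real script does far more, rendering kernels, infer_meta, and so on):

# Sketch only: mirrors the name normalization added to the generator loop.
SPARSE_OP_PREFIX = 'sparse_'  # assumed value of the generator's constant

def normalize_sparse_op_name(api):
    # Strip the trailing '_' that marks an in-place op before prefixing.
    if api['name'][-1] == '_':
        api['name'] = api['name'][:-1]
    api['op_name'] = SPARSE_OP_PREFIX + api['name']
    api['name'] = api['op_name']
    return api

print(normalize_sparse_op_name({'name': 'batch_norm_'})['op_name'])
# prints: sparse_batch_norm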

paddle/phi/api/yaml/sparse_backward.yaml

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@
            atanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
 
 - backward_op : batch_norm_grad
-  forward : batch_norm (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
+  forward : batch_norm_ (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
   args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
   output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
   infer_meta :

paddle/phi/api/yaml/sparse_ops.yaml

Lines changed: 2 additions & 2 deletions
@@ -87,15 +87,15 @@
     layout : x
   backward : atanh_grad
 
-- op : batch_norm
+- op : batch_norm_
   args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
   output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
   infer_meta :
     func : BatchNormInferMeta
   kernel :
     func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
     data_type : x
-  view : (mean -> mean_out), (variance -> variance_out)
+  inplace : (mean -> mean_out), (variance -> variance_out)
   backward : batch_norm_grad
 
 - op : cast
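
The substantive change in this file is view → inplace on the running statistics: (mean -> mean_out), (variance -> variance_out) now declares that the op writes the updated running mean and variance back into the input tensors instead of exposing them as read-only views, and the op is renamed with a trailing underscore per the in-place naming convention. A rough way to observe the effect from Python is sketched below; it assumes a Paddle build with sparse support and that the layer's running-statistics buffers are exposed as _mean and _variance (attribute names taken from the BatchNorm base layer and possibly version-dependent):

import paddle

# Sketch: in training mode, a forward pass through the in-place sparse batch
# norm is expected to update the layer's running-statistics buffers in place.
bn = paddle.sparse.nn.BatchNorm(3)
x = paddle.randn((1, 6, 6, 6, 3)).to_sparse_coo(4)  # NDHWC, channels kept dense

before = bn._mean.clone()  # assumed buffer name for the running mean
bn(x)                      # the layer defaults to training mode
after = bn._mean

print(bool((before != after).any()))  # expected: True once the stats are updated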

paddle/phi/kernels/sparse/batch_norm_kernel.h

Lines changed: 19 additions & 19 deletions
@@ -23,25 +23,25 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
-void BatchNormKernel(const Context& dev_ctx,
-                     const SparseCooTensor& x,
-                     const DenseTensor& scale,
-                     const DenseTensor& bias,
-                     const DenseTensor& mean,
-                     const DenseTensor& variance,
-                     float momentum,
-                     float epsilon,
-                     const std::string& data_layout,
-                     bool is_test,
-                     bool use_global_stats,
-                     bool trainable_statistics,
-                     bool fuse_with_relu,
-                     SparseCooTensor* y,
-                     DenseTensor* mean_out,
-                     DenseTensor* variance_out,
-                     DenseTensor* saved_mean,
-                     DenseTensor* saved_variance,
-                     DenseTensor* reserve_space);
+void BatchNormCooKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        const DenseTensor& scale,
+                        const DenseTensor& bias,
+                        const DenseTensor& mean,
+                        const DenseTensor& variance,
+                        float momentum,
+                        float epsilon,
+                        const std::string& data_layout,
+                        bool is_test,
+                        bool use_global_stats,
+                        bool trainable_statistics,
+                        bool fuse_with_relu,
+                        SparseCooTensor* y,
+                        DenseTensor* mean_out,
+                        DenseTensor* variance_out,
+                        DenseTensor* saved_mean,
+                        DenseTensor* saved_variance,
+                        DenseTensor* reserve_space);
 
 }  // namespace sparse
 }  // namespace phi

python/paddle/sparse/nn/layer/norm.py

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ def forward(self, input):
         data_format = 'NCHW' if self._data_format[1] == 'C' else 'NHWC'
 
         if in_dynamic_mode():
-            batch_norm_out, _, _, _, _, _ = _C_ops.sparse_batch_norm(
+            batch_norm_out, _, _, _, _, _ = _C_ops.sparse_batch_norm_(
                 input,
                 self.weight,
                 self.bias,
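
With the dynamic-graph branch now dispatching to _C_ops.sparse_batch_norm_, the public layer API is unchanged. A usage sketch in the spirit of the layer's docstring (shapes and the fixed seed are illustrative, not required):

import paddle

paddle.seed(123)
channels = 3
# NDHWC input; the trailing (dense) dimension carries the channels.
dense_x = paddle.randn((1, 6, 6, 6, channels), dtype='float32')
sparse_x = dense_x.to_sparse_coo(4)  # keep the channel dimension dense

batch_norm = paddle.sparse.nn.BatchNorm(channels)
out = batch_norm(sparse_x)
print(out.shape)  # [1, 6, 6, 6, 3]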
