Commit 239dff3

Fix paddle.dist & paddle.nn.functional.normalize (#74448)
* fix: fix normalization logic in p_norm kernels
* Code formatting
* Stash changes
* fix_dist_normalize
* back
* Restore comments
* pre-commit
* Fix code style
* Standardize code comments
1 parent 97fd314 commit 239dff3
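
The common thread across the four diffs below is precision: the kernels previously reduced and applied pow() in the input type T, which loses accuracy for low-precision dtypes such as float16, and some index arithmetic used 32-bit int, which can overflow on large tensors. The fix accumulates in the higher-precision type MT (via phi::dtype::MPTypeTrait) and casts back to T once at the end. A minimal standalone C++ analogy of the accumulation issue, with double standing in for MT and float for T (illustrative only, not Paddle code):

#include <cstdio>

int main() {
  // Summing 10 million copies of 0.1f: the float accumulator drifts away
  // from the true sum, while the double accumulator stays within rounding.
  float acc_low = 0.0f;
  double acc_high = 0.0;
  for (int i = 0; i < 10000000; ++i) {
    acc_low += 0.1f;                           // rounds at every step
    acc_high += static_cast<double>(0.1f);     // the "MT" analogue
  }
  std::printf("float accumulator:  %.1f\n", acc_low);
  std::printf("double accumulator: %.1f\n", acc_high);
  return 0;
}

On a typical build the float accumulator lands several percent off, which is the class of error the MT-typed accumulators in these kernels avoid.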

File tree

4 files changed: +79 -28 lines

paddle/phi/kernels/funcs/reduce_grad_functions.h
paddle/phi/kernels/gpu/dist_kernel.cu
paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
paddle/phi/kernels/gpu/p_norm_kernel.cu

paddle/phi/kernels/funcs/reduce_grad_functions.h

Lines changed: 3 additions & 3 deletions
@@ -38,10 +38,10 @@ void ReduceGradFunctor(const Context& dev_ctx,
   auto x_dims = input0.dims();
   auto reduced_dims_v = common::vectorize(x_dims);
   std::vector<int> dims_ref = dims;
-  Eigen::array<int, D> broadcast_dim;
+  Eigen::array<int64_t, D> broadcast_dim;
   for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;

-  int broad_cast_times = 1;
+  int64_t broad_cast_times = 1;
   for (size_t i = 0; i < dims_ref.size(); ++i) {
     if (dims_ref[i] < 0) {
       dims_ref[i] = x_rank + dims_ref[i];
@@ -142,7 +142,7 @@ void LaunchReduceGradKernel(const Context& dev_ctx,
     auto& place = *dev_ctx.eigen_device();
     // *dev_ctx.eigen_device();
     auto broadcast_dim =
-        Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
+        Eigen::array<int64_t, 1>({{static_cast<int64_t>(input0->numel())}});
     functor(place,
             &x,
             &x_reduce,
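
Why these extents widen to int64_t: input0->numel() returns a 64-bit element count, and for tensors with more than INT_MAX elements the old static_cast<int> would overflow before Eigen ever saw the value. A minimal sketch of the failure mode (the 3e9 element count is a hypothetical, not from the source):

#include <cstdint>
#include <cstdio>

int main() {
  // A hypothetical element count above INT_MAX, as with very large tensors:
  // narrowing to int wraps around, while int64_t preserves the value.
  int64_t numel = 3000000000LL;
  int narrowed = static_cast<int>(numel);
  std::printf("as int:     %d\n", narrowed);                  // wrapped/garbage
  std::printf("as int64_t: %lld\n", (long long)numel);        // 3000000000
  return 0;
}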

paddle/phi/kernels/gpu/dist_kernel.cu

Lines changed: 30 additions & 9 deletions
@@ -63,6 +63,18 @@ struct PowFunctor {
   Ty p_order_;
 };

+template <typename Tx,
+          typename Ty,
+          typename Tout>  // Tx is high precision, Tout is low/out precision
+struct PowFunctorHighPrecision {
+  HOSTDEVICE explicit inline PowFunctorHighPrecision(const Ty& p_order)
+      : p_order_(p_order) {}
+  HOSTDEVICE inline Tx operator()(const Tx x) const {
+    return static_cast<Tout>(pow(static_cast<Ty>(x), p_order_));
+  }
+  Ty p_order_;
+};
+
 template <typename T, typename Functor>
 __global__ void ReduceSumWithSubtract(
     const T* x, const T* y, T* out, int64_t N, Functor func) {
@@ -126,16 +138,17 @@ void DistKernel(const Context& dev_ctx,
   DenseTensor intermediate;
   const T* x_ptr = x.data<T>();
   const T* y_ptr = y.data<T>();
+
   T* o_ptr = dev_ctx.template Alloc<T>(out);
   auto stream = dev_ctx.stream();

   auto xdim = x.dims();
   if (xdim == y.dims()) {  // same shape
-    auto n = x.numel();
+    int64_t n = x.numel();
+
     auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
     intermediate.Resize(common::make_ddim({config.block_per_grid.x}));
     T* i_ptr = dev_ctx.template Alloc<T>(&intermediate);
-
     std::vector<int64_t> axis_dims = {static_cast<int64_t>(-1)};
     std::vector<int> reduce_axis =
         funcs::details::GetReduceDim(axis_dims, xdim.size(), true);
@@ -166,15 +179,23 @@ void DistKernel(const Context& dev_ctx,
       ReduceSumWithSubtract<T>
          <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
              x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T, MT>(p_order));
-      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
-          dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);
-
-      const DenseTensor* tmp_norm = out;
-      std::vector<const DenseTensor*> ins = {tmp_norm};
+      DenseTensor out_other;
+      out_other.Resize(out->dims());
+      dev_ctx.template Alloc<MT>(&out_other);
+
+      phi::funcs::
+          ReduceKernel<T, MT, kps::AddFunctor, kps::IdentityFunctor<MT>>(
+              dev_ctx,
+              intermediate,
+              &out_other,
+              kps::IdentityFunctor<MT>(),
+              reduce_axis);
+      std::vector<const DenseTensor*> ins = {&out_other};
       std::vector<DenseTensor*> outs = {out};
-      MT p_order_ = static_cast<MT>(static_cast<MT>(1.) / p_order);
+
+      MT p_order_ = static_cast<MT>(1.f / p_order);
       phi::funcs::ElementwiseKernel<T>(
-          dev_ctx, ins, &outs, PowFunctor<T, MT>(p_order_));
+          dev_ctx, ins, &outs, PowFunctorHighPrecision<MT, MT, T>(p_order_));
     }

   } else {
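
Taken together, the dist changes keep the whole reduce-then-root pipeline in MT: the partial sums land in the MT-typed out_other tensor, and PowFunctorHighPrecision applies the 1/p power in MT before the single final cast to T. A scalar CPU sketch of that computation, assuming MT is the accumulation type (dist_p is an illustrative helper, not a Paddle function; not the CUDA kernel):

#include <cmath>
#include <cstddef>
#include <vector>

// Reduce |x - y|^p into a high-precision accumulator MT, then take the
// 1/p power in MT and cast to the output type T only once, at the end.
// Assumes x and y have equal length, as in the same-shape branch above.
template <typename T, typename MT>
T dist_p(const std::vector<T>& x, const std::vector<T>& y, MT p) {
  MT acc = MT(0);
  for (std::size_t i = 0; i < x.size(); ++i) {
    MT d = std::abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i]));
    acc += std::pow(d, p);  // the sum stays in MT, like out_other
  }
  return static_cast<T>(std::pow(acc, MT(1) / p));  // final pow in MT
}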

paddle/phi/kernels/gpu/p_norm_grad_kernel.cu

Lines changed: 29 additions & 6 deletions
@@ -42,10 +42,12 @@ struct AbsMaxAndMinGradFunctor {

 template <typename T>
 struct PNormGradFunctor {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   HOSTDEVICE explicit inline PNormGradFunctor(float porder, float eps) {
-    this->porder = static_cast<T>(porder - 1.);
-    this->eps = static_cast<T>(eps);
+    this->porder = static_cast<MT>(porder - 1.);
+    this->eps = static_cast<MT>(eps);
   }
+
   template <typename Context,
             typename X,
             typename Y,
@@ -59,12 +61,33 @@ struct PNormGradFunctor {
                   DY* dy,
                   const Dim& dim,
                   int size) {
+    auto x_mt = x->template cast<MT>();
+    auto y_mt = y->template cast<MT>();
+    auto dy_mt = dy->template cast<MT>();
+
+    auto norm_pow = y_mt.pow(-this->porder);
+    auto mask_norm_nonzero = (y_mt != static_cast<MT>(0)).template cast<MT>();
+
+    // Set to 0 where porder < 0 and x == 0
+    MT zero = static_cast<MT>(0);
+    auto mask_x_zero = (x_mt == zero).template cast<MT>();
+
+    MT is_porder_negative =
+        this->porder < zero ? static_cast<MT>(1) : static_cast<MT>(0);
+    auto invalid_mask = (mask_x_zero * is_porder_negative);
+    auto safe_pow =
+        x_mt.abs().pow(this->porder) * (static_cast<MT>(1) - invalid_mask);
+
     dx->device(place) =
-        (*x).abs().pow(this->porder) * (*x).sign() * dy->broadcast(dim) *
-        (*y + y->constant(eps)).pow(-this->porder).broadcast(dim);
+        (safe_pow * x_mt.sign() * dy_mt.broadcast(dim) *
+         norm_pow.broadcast(dim) *
+         mask_norm_nonzero.broadcast(dim)  // Mask out positions where norm == 0
+         )
+            .template cast<T>();
   }
-  T porder;
-  T eps;
+
+  MT porder;
+  MT eps;
 };

 template <typename T, typename Context>
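
The gradient being computed is dx = |x|^(p-1) * sign(x) * dy * ||x||^-(p-1). Instead of nudging the norm with eps as before, the rewrite masks the two degenerate cases exactly: positions where the norm is zero, and 0^(negative) when p < 1 makes the power ill-defined at x = 0, all in MT. A scalar sketch of that masking logic, with pm1 standing for the stored porder - 1 (pnorm_grad is an illustrative helper, not from the source; the real functor works on Eigen tensor expressions):

#include <cmath>

template <typename MT>
MT pnorm_grad(MT x, MT norm, MT dy, MT pm1) {
  if (norm == MT(0)) return MT(0);              // mask positions where norm == 0
  if (x == MT(0) && pm1 < MT(0)) return MT(0);  // avoid 0^(negative) when p < 1
  MT sign = (x > MT(0)) ? MT(1) : ((x < MT(0)) ? MT(-1) : MT(0));
  return std::pow(std::abs(x), pm1) * sign * dy * std::pow(norm, -pm1);
}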

paddle/phi/kernels/gpu/p_norm_kernel.cu

Lines changed: 17 additions & 10 deletions
@@ -124,31 +124,38 @@ void PNormKernel(const Context& dev_ctx,
     phi::funcs::ElementwiseKernel<T>(
         dev_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
 #else
+    DenseTensor out_temp;
+    out_temp.Resize(out_norm->dims());
+    dev_ctx.template Alloc<MT>(&out_temp);
+
     if (porder == 1.0) {
       // fast 1-norm
       phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsFunctor<T>>(
           dev_ctx, *in_x, out_norm, FabsFunctor<T>(), reduce_axis);
     } else if (porder == 2.0) {
       // fast 2-norm
-      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, SquareFunctor<T>>(
-          dev_ctx, *in_x, out_norm, SquareFunctor<T>(), reduce_axis);
+      phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, SquareFunctor<MT>>(
+          dev_ctx, *in_x, &out_temp, SquareFunctor<MT>(), reduce_axis);
     } else if (porder == 3.0) {
       // fast 3-norm
-      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsCubicFunctor<T>>(
-          dev_ctx, *in_x, out_norm, FabsCubicFunctor<T>(), reduce_axis);
+      phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, FabsCubicFunctor<MT>>(
+          dev_ctx, *in_x, &out_temp, FabsCubicFunctor<MT>(), reduce_axis);
     } else {
       // vanilla norm
-      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, UnsignedPowFunctor<T>>(
-          dev_ctx, *in_x, out_norm, UnsignedPowFunctor<T>(porder), reduce_axis);
+      phi::funcs::ReduceKernel<T, MT, kps::AddFunctor, UnsignedPowFunctor<MT>>(
+          dev_ctx,
+          *in_x,
+          &out_temp,
+          UnsignedPowFunctor<MT>(porder),
+          reduce_axis);
     }

     if (porder != 1.0) {
-      // save computation when porder is 1.0
-      const DenseTensor* tmp_norm = out_norm;
-      std::vector<const DenseTensor*> ins = {tmp_norm};
+      std::vector<const DenseTensor*> ins = {&out_temp};
       std::vector<DenseTensor*> outs = {out_norm};
+      MT p_order_ = static_cast<MT>(1.f / porder);
       phi::funcs::ElementwiseKernel<T>(
-          dev_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
+          dev_ctx, ins, &outs, UnsignedPowFunctor<MT>(p_order_));
     }
 #endif
   }
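
The forward kernel follows the same recipe as dist: the p = 2, p = 3, and generic paths all reduce |x|^p into the MT-typed out_temp and apply the final 1/p power in MT, while the p = 1 path writes to out_norm directly since sum^(1/1) needs no final pow. A scalar sketch of the overall computation (p_norm here is an illustrative helper, not from the source):

#include <cmath>
#include <vector>

// Accumulate |x|^p in MT (as out_temp does), then apply the 1/p power
// in MT; skip the final pow when p == 1, matching the fast 1-norm path.
template <typename T, typename MT>
T p_norm(const std::vector<T>& x, MT porder) {
  MT acc = MT(0);
  for (T v : x) {
    acc += std::pow(std::abs(static_cast<MT>(v)), porder);
  }
  return static_cast<T>(porder == MT(1) ? acc : std::pow(acc, MT(1) / porder));
}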
