Skip to content

Commit 0135ca0

Browse files
committed
[API-Compat] Correct min/max_with index gradient behavior
1 parent 89f2d92 commit 0135ca0

File tree

9 files changed

+195
-166
lines changed

9 files changed

+195
-166
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/backends/gpu/gpu_context.h"
16+
#include "paddle/phi/common/place.h"
17+
#include "paddle/phi/core/kernel_registry.h"
18+
#include "paddle/phi/core/utils/data_type.h"
19+
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
20+
#include "paddle/phi/kernels/funcs/math_function.h"
21+
22+
namespace phi {
23+
24+
template <typename T>
25+
using EnableIfInteger =
26+
typename std::enable_if<std::is_integral<T>::value, int>::type;
27+
28+
template <typename T>
29+
using EnableIfNonInteger =
30+
typename std::enable_if<!std::is_integral<T>::value, int>::type;
31+
32+
// Here if keepdim=True, this will fallback to a simplified version of
33+
// take_along_axis. However, if keepdim=False (by default), indices will
34+
// not have equal rank will the input values (and values_grad), therefore
35+
// needs an unsqueeze operation by shallow copying indices and Resize
36+
#define DEFINE_WITH_INDEX_GRAD_KERNEL(OpType) \
37+
template <typename T, typename Context, EnableIfNonInteger<T> = 0> \
38+
void OpType##WithIndexGradKernel(const Context& dev_ctx, \
39+
const DenseTensor& x, \
40+
const DenseTensor& values, \
41+
const DenseTensor& indices, \
42+
const DenseTensor& values_grad, \
43+
const Scalar& dim, \
44+
bool keepdim, \
45+
DenseTensor* x_grad) { \
46+
x_grad->Resize(x.dims()); \
47+
dev_ctx.template Alloc<T>(x_grad); \
48+
if (x_grad->numel() == 0) { \
49+
return; \
50+
} \
51+
int64_t dim_val = dim.to<int64_t>(); \
52+
if (dim_val < 0) { \
53+
dim_val += x.dims().size(); \
54+
} \
55+
DenseTensor shallow_copied_inds(indices); \
56+
if (!keepdim) { \
57+
auto indices_dim = x.dims(); \
58+
indices_dim[dim_val] = 1; \
59+
shallow_copied_inds.Resize(indices_dim); \
60+
} \
61+
phi::funcs::SetConstant<Context, T> functor; \
62+
functor(dev_ctx, x_grad, static_cast<T>(0)); \
63+
phi::funcs::gpu_scatter_add_kernel<T, int64_t>( \
64+
*x_grad, dim_val, shallow_copied_inds, values_grad, true, dev_ctx); \
65+
} \
66+
template <typename T, typename Context, EnableIfInteger<T> = 0> \
67+
void OpType##WithIndexGradKernel(const Context& dev_ctx, \
68+
const DenseTensor& x, \
69+
const DenseTensor& values, \
70+
const DenseTensor& indices, \
71+
const DenseTensor& values_grad, \
72+
const Scalar& dim, \
73+
bool keepdim, \
74+
DenseTensor* x_grad) { \
75+
std::string dtype_name = phi::DataTypeToString(values.dtype()); \
76+
PADDLE_ENFORCE_EQ( \
77+
0, \
78+
1, \
79+
phi::errors::InvalidArgument( \
80+
"Integer type '%s' is not allowed to have stop_gradient=False.", \
81+
dtype_name.c_str())); \
82+
}
83+
84+
DEFINE_WITH_INDEX_GRAD_KERNEL(Max)
85+
DEFINE_WITH_INDEX_GRAD_KERNEL(Min)
86+
87+
#undef DEFINE_WITH_INDEX_GRAD_KERNEL
88+
89+
} // namespace phi
90+
91+
// Integer dtypes are registered on purpose: they dispatch to the integer
// overload of MaxWithIndexGradKernel, which raises a clear InvalidArgument
// error instead of a confusing "kernel not found" failure.
PD_REGISTER_KERNEL(max_with_index_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::MaxWithIndexGradKernel,
                   float,
                   double,
                   uint8_t,
                   int,
                   int16_t,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
103+
104+
// Integer dtypes are registered on purpose: they dispatch to the integer
// overload of MinWithIndexGradKernel, which raises a clear InvalidArgument
// error instead of a confusing "kernel not found" failure.
PD_REGISTER_KERNEL(min_with_index_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::MinWithIndexGradKernel,
                   float,
                   double,
                   uint8_t,
                   int,
                   int16_t,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

paddle/phi/kernels/gpu/min_max_with_index_kernel.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,25 +254,25 @@ void MinMaxWithIndexOpCUDAKernel(const Context& dev_ctx,
254254
template <typename T, typename Context>
255255
void MinWithIndexKernel(const Context& dev_ctx,
256256
const DenseTensor& x,
257-
const Scalar& axis,
258-
bool keepdims,
257+
const Scalar& dim,
258+
bool keepdim,
259259
bool flatten,
260260
DenseTensor* val_out,
261261
DenseTensor* ind_out) {
262262
MinMaxWithIndexOpCUDAKernel<Context, T, cub::ArgMin>(
263-
dev_ctx, x, axis, keepdims, flatten, val_out, ind_out);
263+
dev_ctx, x, dim, keepdim, flatten, val_out, ind_out);
264264
}
265265

266266
template <typename T, typename Context>
267267
void MaxWithIndexKernel(const Context& dev_ctx,
268268
const DenseTensor& x,
269-
const Scalar& axis,
270-
bool keepdims,
269+
const Scalar& dim,
270+
bool keepdim,
271271
bool flatten,
272272
DenseTensor* val_out,
273273
DenseTensor* ind_out) {
274274
MinMaxWithIndexOpCUDAKernel<Context, T, cub::ArgMax>(
275-
dev_ctx, x, axis, keepdims, flatten, val_out, ind_out);
275+
dev_ctx, x, dim, keepdim, flatten, val_out, ind_out);
276276
}
277277

278278
#endif

paddle/phi/kernels/gpu/reduce_kernel.cu

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -160,80 +160,6 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
160160
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
161161
}
162162

163-
template <typename T>
164-
using EnableIfInteger =
165-
typename std::enable_if<std::is_integral<T>::value, int>::type;
166-
167-
template <typename T>
168-
using EnableIfNonInteger =
169-
typename std::enable_if<!std::is_integral<T>::value, int>::type;
170-
171-
template <typename T, typename Context, EnableIfNonInteger<T> = 0>
172-
void MinWithIndexGradKernel(const Context& dev_ctx,
173-
const DenseTensor& x,
174-
const DenseTensor& values,
175-
const DenseTensor& values_grad,
176-
const Scalar& dim,
177-
bool keepdims,
178-
bool flatten,
179-
DenseTensor* x_grad) {
180-
int64_t dim_val = dim.to<int64_t>();
181-
flatten = recompute_reduce_all(x, {dim_val}, flatten);
182-
ReduceCudaAMaxAMinGrad<T, Context>(
183-
dev_ctx, x, values, values_grad, {dim_val}, keepdims, flatten, x_grad);
184-
}
185-
186-
template <typename T, typename Context, EnableIfInteger<T> = 0>
187-
void MinWithIndexGradKernel(const Context& dev_ctx,
188-
const DenseTensor& x,
189-
const DenseTensor& values,
190-
const DenseTensor& values_grad,
191-
const Scalar& dim,
192-
bool keepdims,
193-
bool flatten,
194-
DenseTensor* x_grad) {
195-
std::string dtype_name = phi::DataTypeToString(x.dtype());
196-
PADDLE_ENFORCE_EQ(
197-
0,
198-
1,
199-
phi::errors::InvalidArgument(
200-
"Integer type '%s' is not allowed to have stop_gradient=False.",
201-
dtype_name.c_str()));
202-
}
203-
204-
template <typename T, typename Context, EnableIfNonInteger<T> = 0>
205-
void MaxWithIndexGradKernel(const Context& dev_ctx,
206-
const DenseTensor& x,
207-
const DenseTensor& values,
208-
const DenseTensor& values_grad,
209-
const Scalar& dim,
210-
bool keepdims,
211-
bool flatten,
212-
DenseTensor* x_grad) {
213-
int64_t dim_val = dim.to<int64_t>();
214-
flatten = recompute_reduce_all(x, {dim_val}, flatten);
215-
ReduceCudaAMaxAMinGrad<T, Context>(
216-
dev_ctx, x, values, values_grad, {dim_val}, keepdims, flatten, x_grad);
217-
}
218-
219-
template <typename T, typename Context, EnableIfInteger<T> = 0>
220-
void MaxWithIndexGradKernel(const Context& dev_ctx,
221-
const DenseTensor& x,
222-
const DenseTensor& values,
223-
const DenseTensor& values_grad,
224-
const Scalar& dim,
225-
bool keepdims,
226-
bool flatten,
227-
DenseTensor* x_grad) {
228-
std::string dtype_name = phi::DataTypeToString(x.dtype());
229-
PADDLE_ENFORCE_EQ(
230-
0,
231-
1,
232-
phi::errors::InvalidArgument(
233-
"Integer type '%s' is not allowed to have stop_gradient=False.",
234-
dtype_name.c_str()));
235-
}
236-
237163
template <typename T, typename Context>
238164
void ReduceMaxGradKernel(const Context& dev_ctx,
239165
const DenseTensor& x,
@@ -359,19 +285,6 @@ PD_REGISTER_KERNEL(max_grad,
359285
phi::dtype::float16,
360286
phi::dtype::bfloat16) {}
361287

362-
PD_REGISTER_KERNEL(max_with_index_grad,
363-
GPU,
364-
ALL_LAYOUT,
365-
phi::MaxWithIndexGradKernel,
366-
float,
367-
double,
368-
uint8_t,
369-
int,
370-
int16_t,
371-
int64_t,
372-
phi::dtype::float16,
373-
phi::dtype::bfloat16) {}
374-
375288
PD_REGISTER_KERNEL(mean_grad,
376289
GPU,
377290
ALL_LAYOUT,
@@ -398,19 +311,6 @@ PD_REGISTER_KERNEL(min_grad,
398311
phi::dtype::float16,
399312
phi::dtype::bfloat16) {}
400313

401-
PD_REGISTER_KERNEL(min_with_index_grad,
402-
GPU,
403-
ALL_LAYOUT,
404-
phi::MinWithIndexGradKernel,
405-
float,
406-
double,
407-
uint8_t,
408-
int,
409-
int16_t,
410-
int64_t,
411-
phi::dtype::float16,
412-
phi::dtype::bfloat16) {}
413-
414314
PD_REGISTER_KERNEL(sum_grad,
415315
GPU,
416316
ALL_LAYOUT,

paddle/phi/kernels/min_max_with_index_grad_kernel.h

Lines changed: 0 additions & 42 deletions
This file was deleted.

paddle/phi/kernels/min_max_with_index_kernel.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,17 @@ namespace phi {
2222
template <typename T, typename Context>
2323
void MinWithIndexKernel(const Context& dev_ctx,
2424
const DenseTensor& x,
25-
const Scalar& axis,
26-
bool keepdims,
25+
const Scalar& dim,
26+
bool keepdim,
2727
bool flatten,
2828
DenseTensor* val_out,
2929
DenseTensor* ind_out);
3030

3131
template <typename T, typename Context>
3232
void MaxWithIndexKernel(const Context& dev_ctx,
3333
const DenseTensor& x,
34-
const Scalar& axis,
35-
bool keepdims,
34+
const Scalar& dim,
35+
bool keepdim,
3636
bool flatten,
3737
DenseTensor* val_out,
3838
DenseTensor* ind_out);

paddle/phi/ops/yaml/backward.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2278,8 +2278,8 @@
22782278
func : max_pool3d_with_index_grad
22792279

22802280
- backward_op : max_with_index_grad
2281-
forward : max_with_index (Tensor x, Scalar axis, bool keepdims, bool flatten) -> Tensor(values), Tensor(indices)
2282-
args : (Tensor x, Tensor values, Tensor values_grad, Scalar axis, bool keepdims, bool flatten)
2281+
forward : max_with_index (Tensor x, Scalar dim, bool keepdim, bool flatten) -> Tensor(values), Tensor(indices)
2282+
args : (Tensor x, Tensor values, Tensor indices, Tensor values_grad, Scalar dim, bool keepdim)
22832283
output : Tensor(x_grad)
22842284
infer_meta :
22852285
func : UnchangedInferMeta
@@ -2351,8 +2351,8 @@
23512351
data_type : out_grad
23522352

23532353
- backward_op : min_with_index_grad
2354-
forward : min_with_index (Tensor x, Scalar axis, bool keepdims, bool flatten) -> Tensor(values), Tensor(indices)
2355-
args : (Tensor x, Tensor values, Tensor values_grad, Scalar axis, bool keepdims, bool flatten)
2354+
forward : min_with_index (Tensor x, Scalar dim, bool keepdim, bool flatten) -> Tensor(values), Tensor(indices)
2355+
args : (Tensor x, Tensor values, Tensor indices, Tensor values_grad, Scalar dim, bool keepdim)
23562356
output : Tensor(x_grad)
23572357
infer_meta :
23582358
func : UnchangedInferMeta

paddle/phi/ops/yaml/ops.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3554,7 +3554,7 @@
35543554
interfaces : paddle::dialect::InferSymbolicShapeInterface
35553555

35563556
- op : max_with_index
3557-
args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false)
3557+
args : (Tensor x, Scalar(int64_t) dim, bool keepdim = false, bool flatten = false)
35583558
output : Tensor(values), Tensor(indices)
35593559
infer_meta :
35603560
func : MinMaxWithIndexInferMeta
@@ -3674,7 +3674,7 @@
36743674
interfaces : paddle::dialect::InferSymbolicShapeInterface
36753675

36763676
- op : min_with_index
3677-
args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false)
3677+
args : (Tensor x, Scalar(int64_t) dim, bool keepdim = false, bool flatten = false)
36783678
output : Tensor(values), Tensor(indices)
36793679
infer_meta :
36803680
func : MinMaxWithIndexInferMeta

0 commit comments

Comments
 (0)