Commit b95cb68

[API-Compat] Add paddle.compat.min/max and new PHI kernel (min/max_with_index) (#74547)
* [API-Compat] paddle.compat.split is added and tested
* [API-Compat] paddle.compat.split is rigorously tested
* [API-Compat] Make the forbid_keywords decorator transparent
* [API-Compat] Fixed decorator str input
* [API-Compat] More unittests & static graph check & updated decorator
* [API-Compat] Add paddle.compat.min/max and new PHI kernel (min/max_with_index)
* [API-Compat] Add compat.min/max EN doc; attempting to fix integral type gradient computation (rejection)
* [WIP][API-Compat] Add dynamic-graph unittests for min/max
* [WIP][API-Compat] Fixed CPU failure
* [API-Compat] Correct min/max_with_index gradient behavior
* [API-Compat] XPU fix (attempt)
* [API-Compat] Updated ForbidKeywordsDecorator
* some create api support more usage (#74494)
* [API-Compat] Static graph and CPU end debug
* [API-Compat] Resolved conflicts in decorator_utils.py
* [API-Compat] Added static graph min/max_with_index op check, simplified implementation
* [API-Compat] min/max static graph op test and out tensor support
* [API-Compat] Resolved merge conflicts
* [API-Compat] Fixed CPU static graph bugs; removed split API for independence
* [API-Compat] Resolved merge conflicts, added symbolic shape test
* [API-Compat] Updated unittests
* [API-Compat] Update version year
* [API-Compat] Fixed min/max out mechanism
* [API-Compat] Try adding even more unittests

---------

Co-authored-by: zhwesky2010 <[email protected]>
1 parent 76c2c4d commit b95cb68
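The user-facing change is the torch-style paddle.compat.min/max pair, backed by the new min/max_with_index PHI kernels shown in the diffs below. A minimal usage sketch, assuming the PyTorch-compatible signature the commit message describes (the Python-side file is not shown in this view, so the exact keyword names are inferred):

import paddle

x = paddle.to_tensor([[1.0, 5.0, 3.0],
                      [4.0, 2.0, 6.0]])

# Full reduction: a single tensor, as in torch.max(x).
paddle.compat.max(x)  # -> 6.0

# Reduction along a dimension: values and their positions together,
# which is what the new max_with_index kernel computes in one pass.
values, indices = paddle.compat.max(x, dim=1)
# values -> [5.0, 6.0], indices -> [1, 2]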

File tree

17 files changed (+2084, -11 lines)


paddle/fluid/pir/dialect/op_generator/op_build_gen.py

Lines changed: 1 addition & 0 deletions
@@ -135,6 +135,7 @@
     'KthvalueInferMeta',
     'MaxPoolWithIndexInferMeta',
     'MaxPoolV2InferMeta',
+    'MinMaxWithIndexInferMeta',
     'MultinomialInferMeta',
     'OverlapAddInferMeta',
     'PadInferMeta',

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 44 additions & 9 deletions
@@ -315,26 +315,44 @@ bool AnyOpInferSymbolicShape(pir::Operation *op,
                     axis.size() == 0 /*reduce_all*/);
 }

-bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
-                                pir::InferSymbolicShapeContext *infer_context) {
+bool MinMaxOpInferSymbolicShape(pir::Operation *op,
+                                pir::InferSymbolicShapeContext *infer_context,
+                                bool output_val_and_ind = false) {
   bool flatten = GetBoolAttr(op, "flatten");
-  bool keepdims = GetBoolAttr(op, "keepdims");
+  bool keepdims = false;
+  int axis = 0;
+
+  if (output_val_and_ind) {
+    keepdims = GetBoolAttr(op, "keepdim");

+    PADDLE_ENFORCE_NE(
+        op->attributes().find("dim"),
+        op->attributes().end(),
+        common::errors::InvalidArgument(
+            "'dim' Attribute is expected for Min/MaxWithIndexOp. "));
+    axis = op->attributes()
+               .at("dim")
+               .dyn_cast<paddle::dialect::ScalarAttribute>()
+               .data()
+               .to<int64_t>();
+  } else {
+    keepdims = GetBoolAttr(op, "keepdims");
+    const auto &axis_shape_or_data =
+        infer_context->GetShapeOrDataForValue(op->operand_source(1));
+    axis = static_cast<int>(
+        axis_shape_or_data.data().value().at(0).Get<int64_t>());
+  }
   const auto &input_sym_shape =
       infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
-  int rank = input_sym_shape.size();

-  const auto &axis_shape_or_data =
-      infer_context->GetShapeOrDataForValue(op->operand_source(1));
-  int axis =
-      static_cast<int>(axis_shape_or_data.data().value().at(0).Get<int64_t>());
+  int rank = input_sym_shape.size();
   if (axis < 0) axis += rank;

   const auto &out_sym_shape = [&] {
     std::vector<symbol::DimExpr> out_sym_shape;
     if (flatten) {
       if (keepdims) {
-        out_sym_shape.emplace_back(std::int64_t(rank));
+        out_sym_shape.resize(rank, std::int64_t(1));
       } else {
         out_sym_shape = {};
       }
@@ -357,14 +375,31 @@ bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
       symbol::TensorShapeOrDataDimExprs(out_sym_shape)};

   infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
+  if (output_val_and_ind)
+    infer_context->SetShapeOrDataForValue(op->result(1), shape_data);
   return true;
 }

+#define DEFINE_MINMAX_OP_INFER_FUNC(OpName, output_val_and_ind)               \
+  bool OpName##OpInferSymbolicShape(                                          \
+      pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {    \
+    return MinMaxOpInferSymbolicShape(op, infer_context, output_val_and_ind); \
+  }
+
+DEFINE_MINMAX_OP_INFER_FUNC(Argmax, false)
+DEFINE_MINMAX_OP_INFER_FUNC(MaxWithIndex, true)
+#undef DEFINE_MINMAX_OP_INFER_FUNC
+
 bool ArgminOpInferSymbolicShape(pir::Operation *op,
                                 pir::InferSymbolicShapeContext *infer_context) {
   return ArgmaxOpInferSymbolicShape(op, infer_context);
 }

+bool MinWithIndexOpInferSymbolicShape(
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
+  return MaxWithIndexOpInferSymbolicShape(op, infer_context);
+}
+
 bool AsComplexOpInferSymbolicShape(
     pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
   pir::Value operand_source = op->operand_source(0);
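Argmax/Argmin and the new Max/MinWithIndex ops now share this one shape rule (note the hunk also fixes the flatten + keepdims case to produce a rank-length shape of ones rather than a single dimension of size rank). A Python sketch of the rule, with reduced_shape as a hypothetical name for illustration:

def reduced_shape(in_shape, axis, keepdims, flatten):
    rank = len(in_shape)
    if axis < 0:
        axis += rank
    if flatten:
        # All dims reduced: rank ones if kept, else a 0-dim result.
        return [1] * rank if keepdims else []
    out = list(in_shape[:axis])
    if keepdims:
        out.append(1)  # the reduced axis stays, with extent 1
    return out + list(in_shape[axis + 1:])

assert reduced_shape([2, 3, 4], 1, keepdims=True, flatten=False) == [2, 1, 4]
assert reduced_shape([2, 3, 4], -1, keepdims=False, flatten=False) == [2, 3]
assert reduced_shape([2, 3, 4], 0, keepdims=True, flatten=True) == [1, 1, 1]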

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h

Lines changed: 2 additions & 0 deletions
@@ -93,8 +93,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu_)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mode)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxWithIndex)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Maxout)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(MinWithIndex)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mean)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixPower)

paddle/phi/infermeta/unary.cc

Lines changed: 64 additions & 0 deletions
@@ -2950,6 +2950,70 @@ void ModeInferMeta(const MetaTensor& x,
   indices->set_dtype(DataType::INT64);
 }

+void MinMaxWithIndexInferMeta(const MetaTensor& x,
+                              const Scalar& axis,
+                              bool keepdims,
+                              bool flatten,
+                              MetaTensor* val_out,
+                              MetaTensor* ind_out,
+                              MetaConfig config) {
+  DataType val_dtype = x.dtype();
+
+  // axis.FromTensor will never be true for this op
+  auto int_axis = axis.to<int64_t>();
+  const auto& x_dims = x.dims();
+
+  auto x_rank = x.dims().size();
+  if (x_rank > 0) {
+    PADDLE_ENFORCE_GE(int_axis,
+                      -x_rank,
+                      common::errors::InvalidArgument(
+                          "'axis'(%d) must be greater than or equal to"
+                          " -Rank(X)(%d).",
+                          int_axis,
+                          -x_rank));
+    PADDLE_ENFORCE_LT(
+        int_axis,
+        x_rank,
+        common::errors::InvalidArgument(
+            "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
+            int_axis,
+            x_rank));
+  } else {
+    // 0-dim tensor
+    PADDLE_ENFORCE_EQ(int_axis == 0 || int_axis == -1,
+                      true,
+                      common::errors::InvalidArgument(
+                          "'axis'(%d) must be 0 or -1 if input tensor is "
+                          "0-dim.",
+                          int_axis));
+  }
+
+  if (int_axis < 0) int_axis += x_rank;
+
+  std::vector<int64_t> vec;
+  if (flatten) {
+    if (keepdims) {  // NOLINT
+      vec = std::vector<int64_t>(x.dims().size(), 1);
+    } else {
+      vec = {};
+    }
+  } else {
+    for (int64_t i = 0; i < int_axis; i++)
+      vec.emplace_back(x_dims[static_cast<int>(i)]);
+    if (keepdims) {
+      vec.emplace_back(static_cast<int64_t>(1));
+    }
+    for (int64_t i = int_axis + 1; i < x_rank; i++)
+      vec.emplace_back(x_dims[static_cast<int>(i)]);
+  }
+
+  val_out->set_dims(common::make_ddim(vec));
+  val_out->set_dtype(val_dtype);
+  ind_out->set_dims(common::make_ddim(vec));
+  ind_out->set_dtype(DataType::INT64);
+}
+
 void MultinomialInferMeta(const MetaTensor& x,
                           const Scalar& num_samples,
                           bool replacement,
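MinMaxWithIndexInferMeta applies the same shape rule statically and validates the axis up front; values keep the input dtype while indices are always INT64. The accepted axis range, restated as a short Python sketch (check_axis is a hypothetical name):

def check_axis(axis, rank):
    # 0-dim inputs accept only axis 0 or -1; otherwise -rank <= axis < rank.
    if rank == 0:
        assert axis in (0, -1), f"'axis'({axis}) must be 0 or -1 for 0-dim input"
        return 0
    assert -rank <= axis < rank, f"'axis'({axis}) out of range for rank {rank}"
    return axis + rank if axis < 0 else axis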

paddle/phi/infermeta/unary.h

Lines changed: 8 additions & 0 deletions
@@ -66,6 +66,14 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
                         MetaTensor* out,
                         MetaConfig config = MetaConfig());

+void MinMaxWithIndexInferMeta(const MetaTensor& x,
+                              const Scalar& axis,
+                              bool keepdims,
+                              bool flatten,
+                              MetaTensor* val_out,
+                              MetaTensor* ind_out,
+                              MetaConfig config = MetaConfig());
+
 void ArgsortInferMeta(const MetaTensor& input,
                       int axis,
                       bool descending,
New file: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T>
+using EnableIfInteger =
+    typename std::enable_if<std::is_integral<T>::value, int>::type;
+
+template <typename T>
+using EnableIfNonInteger =
+    typename std::enable_if<!std::is_integral<T>::value, int>::type;
+
+// If keepdim=True, this falls back to a simplified version of
+// take_along_axis. If keepdim=False (the default), indices does not
+// have the same rank as the input values (and values_grad), so it
+// needs an unsqueeze, done by shallow-copying indices and Resize.
+#define DEFINE_WITH_INDEX_GRAD_KERNEL(OpType)                                 \
+  template <typename T, typename Context, EnableIfNonInteger<T> = 0>         \
+  void OpType##WithIndexGradKernel(const Context& dev_ctx,                    \
+                                   const DenseTensor& x,                      \
+                                   const DenseTensor& values,                 \
+                                   const DenseTensor& indices,                \
+                                   const DenseTensor& values_grad,            \
+                                   const Scalar& dim,                         \
+                                   bool keepdim,                              \
+                                   DenseTensor* x_grad) {                     \
+    x_grad->Resize(x.dims());                                                 \
+    dev_ctx.template Alloc<T>(x_grad);                                        \
+    if (x_grad->numel() == 0) {                                               \
+      return;                                                                 \
+    }                                                                         \
+    int64_t dim_val = dim.to<int64_t>();                                      \
+    if (dim_val < 0) {                                                        \
+      dim_val += x.dims().size();                                             \
+    }                                                                         \
+    DenseTensor shallow_copied_inds(indices);                                 \
+    if (!keepdim) {                                                           \
+      auto indices_dim = x.dims();                                            \
+      indices_dim[dim_val] = 1;                                               \
+      shallow_copied_inds.Resize(indices_dim);                                \
+    }                                                                         \
+    phi::funcs::SetConstant<Context, T> functor;                              \
+    functor(dev_ctx, x_grad, static_cast<T>(0));                              \
+    phi::funcs::gpu_scatter_add_kernel<T, int64_t>(                           \
+        *x_grad, dim_val, shallow_copied_inds, values_grad, true, dev_ctx);   \
+  }                                                                           \
+  template <typename T, typename Context, EnableIfInteger<T> = 0>            \
+  void OpType##WithIndexGradKernel(const Context& dev_ctx,                    \
+                                   const DenseTensor& x,                      \
+                                   const DenseTensor& values,                 \
+                                   const DenseTensor& indices,                \
+                                   const DenseTensor& values_grad,            \
+                                   const Scalar& dim,                         \
+                                   bool keepdim,                              \
+                                   DenseTensor* x_grad) {                     \
+    std::string dtype_name = phi::DataTypeToString(values.dtype());           \
+    PADDLE_ENFORCE_EQ(                                                        \
+        0,                                                                    \
+        1,                                                                    \
+        phi::errors::InvalidArgument(                                         \
+            "Integer type '%s' is not allowed to have stop_gradient=False.",  \
+            dtype_name.c_str()));                                             \
+  }
+
+DEFINE_WITH_INDEX_GRAD_KERNEL(Max)
+DEFINE_WITH_INDEX_GRAD_KERNEL(Min)
+
+#undef DEFINE_WITH_INDEX_GRAD_KERNEL
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(max_with_index_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MaxWithIndexGradKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(min_with_index_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MinWithIndexGradKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
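The backward rule above is a scatter-add of values_grad into a zero-initialized x_grad at the recorded indices; when keepdim=False the reduced axis is first restored so the ranks line up, and integer dtypes are rejected outright (their registration exists only to raise that error). A NumPy sketch under those assumptions (hypothetical helper name):

import numpy as np

def with_index_grad(x, indices, values_grad, dim, keepdim):
    if np.issubdtype(x.dtype, np.integer):
        raise TypeError(f"Integer type '{x.dtype}' is not allowed to "
                        "have stop_gradient=False.")
    if dim < 0:
        dim += x.ndim
    if not keepdim:
        # Restore the reduced axis so indices/values_grad match x's rank.
        indices = np.expand_dims(indices, dim)
        values_grad = np.expand_dims(values_grad, dim)
    x_grad = np.zeros_like(x)
    # Plain assignment equals scatter-add onto zeros here, since each
    # reduced position carries exactly one index along `dim`.
    np.put_along_axis(x_grad, indices, values_grad, axis=dim)
    return x_grad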
