
Commit e8c78b7

[API-Compat] Add paddle.compat.min/max and new PHI kernel (min/max_with_index)

1 parent: c4a483f

12 files changed: +756 -7 lines

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 17 additions & 6 deletions
@@ -315,8 +315,9 @@ bool AnyOpInferSymbolicShape(pir::Operation *op,
                       axis.size() == 0 /*reduce_all*/);
 }
 
-bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
-                                pir::InferSymbolicShapeContext *infer_context) {
+bool MinMaxOpInferSymbolicShape(pir::Operation *op,
+                                pir::InferSymbolicShapeContext *infer_context,
+                                bool output_val_and_ind = false) {
   bool flatten = GetBoolAttr(op, "flatten");
   bool keepdims = GetBoolAttr(op, "keepdims");
 
@@ -357,13 +358,23 @@ bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
       symbol::TensorShapeOrDataDimExprs(out_sym_shape)};
 
   infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
+  if (output_val_and_ind)
+    infer_context->SetShapeOrDataForValue(op->result(1), shape_data);
   return true;
 }
 
-bool ArgminOpInferSymbolicShape(pir::Operation *op,
-                                pir::InferSymbolicShapeContext *infer_context) {
-  return ArgmaxOpInferSymbolicShape(op, infer_context);
-}
+#define DEFINE_MINMAX_OP_INFER_FUNC(OpName, output_val_and_ind)               \
+  bool OpName##OpInferSymbolicShape(                                          \
+      pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {    \
+    return MinMaxOpInferSymbolicShape(op, infer_context, output_val_and_ind); \
+  }
+
+DEFINE_MINMAX_OP_INFER_FUNC(Argmin, false)
+DEFINE_MINMAX_OP_INFER_FUNC(Argmax, false)
+DEFINE_MINMAX_OP_INFER_FUNC(MinWithIndex, true)
+DEFINE_MINMAX_OP_INFER_FUNC(MaxWithIndex, true)
+
+#undef DEFINE_MINMAX_OP_INFER_FUNC
 
 bool AsComplexOpInferSymbolicShape(
     pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
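
For reference, the mechanical expansion of one of the new macro invocations: DEFINE_MINMAX_OP_INFER_FUNC(MaxWithIndex, true) produces roughly the following (a sketch of the preprocessor output, not code that appears verbatim in the commit).

bool MaxWithIndexOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
  // Forwards to the shared helper; passing true also sets the symbolic
  // shape of result(1), the index output, to the same shape_data.
  return MinMaxOpInferSymbolicShape(op, infer_context, true);
}

Argmin and Argmax pass false, so only result(0) is populated, matching the previous behavior.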

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h

Lines changed: 2 additions & 0 deletions
@@ -93,8 +93,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu_)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mode)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxWithIndex)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Maxout)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(MinWithIndex)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mean)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixPower)
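
Assuming OP_DECLARE_INFER_SYMBOLIC_SHAPE follows the usual convention of declaring the corresponding *OpInferSymbolicShape function (the macro definition itself is not shown in this diff), the two added lines declare roughly:

bool MaxWithIndexOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context);
bool MinWithIndexOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context);

which matches the definitions generated by DEFINE_MINMAX_OP_INFER_FUNC in unary_infer_sym.cc above.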

paddle/phi/infermeta/unary.cc

Lines changed: 84 additions & 0 deletions
@@ -366,6 +366,90 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
   }
 }
 
+void MinMaxWithIndexInferMeta(const MetaTensor& x,
+                              const Scalar& axis,
+                              bool keepdims,
+                              bool flatten,
+                              MetaTensor* val_out,
+                              MetaTensor* ind_out,
+                              MetaConfig config) {
+  DataType val_dtype = x.dtype();
+
+  if (!config.is_runtime && axis.FromTensor()) {
+    std::vector<int64_t> vec;
+    if (flatten) {
+      if (keepdims) {  // NOLINT
+        vec = std::vector<int64_t>(x.dims().size(), -1);
+      } else {
+        vec = {};
+      }
+    } else {
+      if (keepdims) {
+        vec = std::vector<int64_t>(x.dims().size(), -1);
+      } else {
+        vec = std::vector<int64_t>(x.dims().size() - 1, -1);
+      }
+    }
+    val_out->set_dims(common::make_ddim(vec));
+    val_out->set_dtype(val_dtype);
+    ind_out->set_dims(common::make_ddim(vec));
+    ind_out->set_dtype(DataType::INT64);
+    return;
+  }
+  auto int_axis = axis.to<int64_t>();
+  const auto& x_dims = x.dims();
+
+  auto x_rank = x.dims().size();
+  if (x_rank > 0) {
+    PADDLE_ENFORCE_GE(int_axis,
+                      -x_rank,
+                      common::errors::InvalidArgument(
+                          "'axis'(%d) must be greater than or equal to"
+                          " -Rank(X)(%d).",
+                          int_axis,
+                          -x_rank));
+    PADDLE_ENFORCE_LT(
+        int_axis,
+        x_rank,
+        common::errors::InvalidArgument(
+            "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
+            int_axis,
+            x_rank));
+  } else {
+    // 0-dim tensor
+    PADDLE_ENFORCE_EQ(int_axis == 0 || int_axis == -1,
+                      true,
+                      common::errors::InvalidArgument(
+                          "'axis'(%d) must be 0 or -1 if input tensor is "
+                          "0-dim.",
+                          int_axis));
+  }
+
+  if (int_axis < 0) int_axis += x_rank;
+
+  std::vector<int64_t> vec;
+  if (flatten) {
+    if (keepdims) {  // NOLINT
+      vec = std::vector<int64_t>(x.dims().size(), 1);
+    } else {
+      vec = {};
+    }
+  } else {
+    for (int64_t i = 0; i < int_axis; i++)
+      vec.emplace_back(x_dims[static_cast<int>(i)]);
+    if (keepdims) {
+      vec.emplace_back(static_cast<int64_t>(1));
+    }
+    for (int64_t i = int_axis + 1; i < x_rank; i++)
+      vec.emplace_back(x_dims[static_cast<int>(i)]);
+  }
+
+  val_out->set_dims(common::make_ddim(vec));
+  val_out->set_dtype(val_dtype);
+  ind_out->set_dims(common::make_ddim(vec));
+  ind_out->set_dtype(DataType::INT64);
+}
+
 void ArgsortInferMeta(const MetaTensor& input,
                       int axis,
                       bool descending,
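
In this infermeta, val_out and ind_out always share the same dims; only the dtypes differ (val_out keeps x.dtype(), ind_out is fixed to INT64). The sketch below mirrors the shape rule of the non-flatten, compile-time-axis branch as a standalone function with a worked example; the function name is hypothetical and this is not Paddle code.

#include <cstdint>
#include <iostream>
#include <vector>

// Common output shape of val_out/ind_out for a reduction along `axis`.
std::vector<int64_t> MinMaxWithIndexOutShape(const std::vector<int64_t>& x_dims,
                                             int64_t axis,
                                             bool keepdims) {
  const int64_t rank = static_cast<int64_t>(x_dims.size());
  if (axis < 0) axis += rank;  // same normalization as the infermeta
  std::vector<int64_t> out;
  for (int64_t i = 0; i < axis; ++i) out.push_back(x_dims[i]);
  if (keepdims) out.push_back(1);  // reduced axis kept as size 1
  for (int64_t i = axis + 1; i < rank; ++i) out.push_back(x_dims[i]);
  return out;
}

int main() {
  // x: [2, 3, 4], axis = 1, keepdims = false  ->  [2, 4]
  for (int64_t d : MinMaxWithIndexOutShape({2, 3, 4}, 1, false))
    std::cout << d << ' ';
  std::cout << '\n';
  // x: [2, 3, 4], axis = -2, keepdims = true  ->  [2, 1, 4]
  for (int64_t d : MinMaxWithIndexOutShape({2, 3, 4}, -2, true))
    std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}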

paddle/phi/infermeta/unary.h

Lines changed: 8 additions & 0 deletions
@@ -66,6 +66,14 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
                         MetaTensor* out,
                         MetaConfig config = MetaConfig());
 
+void MinMaxWithIndexInferMeta(const MetaTensor& x,
+                              const Scalar& axis,
+                              bool keepdims,
+                              bool flatten,
+                              MetaTensor* val_out,
+                              MetaTensor* ind_out,
+                              MetaConfig config = MetaConfig());
+
 void ArgsortInferMeta(const MetaTensor& input,
                       int axis,
                       bool descending,
