Merged

24 commits
35707c5
[API-Compat] paddle.compat.split is added and tested
Enigmatisms Aug 5, 2025
23c422d
[API-Compat] paddle.compat.split is rigorously tested
Enigmatisms Aug 6, 2025
309b44a
[API-Compat] Make the forbid_keywords decorator transparent
Enigmatisms Aug 6, 2025
2a33744
[API-Compat] Fixed decorator str input
Enigmatisms Aug 6, 2025
11d9640
[API-Compat] More unittest & static graph check & updated decorator
Enigmatisms Aug 9, 2025
6a58470
[API-Compat] Add paddle.compat.min/max and new PHI kernel (min/max_wi…
Enigmatisms Aug 8, 2025
6255ed9
[API-Compat] Add compat.min/max EN doc
Enigmatisms Aug 9, 2025
6fa8807
[WIP][API-Compat] Add dyna-graph unittests for min/max
Enigmatisms Aug 10, 2025
adb4c25
[WIP][API-Compat] Fixed CPU failure
Enigmatisms Aug 10, 2025
fd6adf0
[API-Compat] Correct min/max_with index gradient behavior
Enigmatisms Aug 10, 2025
3081556
[API-Compat] XPU fix (attempt)
Enigmatisms Aug 11, 2025
cd8d6ae
[API-Compat] Updated ForbidKeywordsDecorator
Enigmatisms Aug 11, 2025
085801e
some create api support more usage (#74494)
zhwesky2010 Aug 11, 2025
2864eb0
[API-Compat] Static Graph and CPU end debug
Enigmatisms Aug 11, 2025
693ff52
[API-Compat] Resolved conflicts in decorator_utils.py
Enigmatisms Aug 11, 2025
f3d7353
[API-Compat] Added static graph min/max_with_index op check, simplifi…
Enigmatisms Aug 13, 2025
bfd5134
[API-Compat] min/max static graph op test and out tensor support
Enigmatisms Aug 14, 2025
fb8bba0
[API-Compat] Resolved merge conflicts.
Enigmatisms Aug 14, 2025
47a08dc
[API-Compat] Fixed CPU static graph bugs
Enigmatisms Aug 14, 2025
9300d17
[API-Compat] Resolved merged conflicts, add symbolic shape test.
Enigmatisms Aug 19, 2025
17d848c
[API-Compat] Updated unittests
Enigmatisms Aug 19, 2025
822e8d7
[API-Compat] Update version year
Enigmatisms Aug 20, 2025
17f080e
[API-Compat] Fixed min/max out mechanism
Enigmatisms Aug 20, 2025
0fbbb99
[API-Compat] Try adding even more unittests.
Enigmatisms Aug 22, 2025
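The commits above add torch-compatible paddle.compat.min/paddle.compat.max backed by new min/max_with_index PHI kernels. A rough usage sketch only — the Python signature is not shown in this diff, so the argument names (dim, keepdim) are taken from the kernel signature below and the two-output return form is an assumption based on the "with_index" outputs:

import paddle

x = paddle.to_tensor([[1.0, 5.0, 3.0],
                      [4.0, 2.0, 6.0]])

# Reduction along a dimension: assumed to return a (values, indices) pair,
# with int64 indices, as produced by the new min/max_with_index kernels.
values, indices = paddle.compat.max(x, dim=1, keepdim=False)
# values  -> [5.0, 6.0]
# indices -> [1, 2]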
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
@@ -135,6 +135,7 @@
'KthvalueInferMeta',
'MaxPoolWithIndexInferMeta',
'MaxPoolV2InferMeta',
'MinMaxWithIndexInferMeta',
'MultinomialInferMeta',
'OverlapAddInferMeta',
'PadInferMeta',
@@ -315,26 +315,44 @@ bool AnyOpInferSymbolicShape(pir::Operation *op,
axis.size() == 0 /*reduce_all*/);
}

bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
bool MinMaxOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context,
bool output_val_and_ind = false) {
bool flatten = GetBoolAttr(op, "flatten");
bool keepdims = GetBoolAttr(op, "keepdims");
bool keepdims = false;
int axis = 0;

if (output_val_and_ind) {
keepdims = GetBoolAttr(op, "keepdim");

PADDLE_ENFORCE_NE(
op->attributes().find("dim"),
op->attributes().end(),
common::errors::InvalidArgument(
"'dim' Attribute is expected for Min/MaxWithIndexOp. "));
axis = op->attributes()
.at("dim")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int64_t>();
} else {
keepdims = GetBoolAttr(op, "keepdims");
const auto &axis_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
axis = static_cast<int>(
axis_shape_or_data.data().value().at(0).Get<int64_t>());
}
const auto &input_sym_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
int rank = input_sym_shape.size();

const auto &axis_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
int axis =
static_cast<int>(axis_shape_or_data.data().value().at(0).Get<int64_t>());
int rank = input_sym_shape.size();
if (axis < 0) axis += rank;

const auto &out_sym_shape = [&] {
std::vector<symbol::DimExpr> out_sym_shape;
if (flatten) {
if (keepdims) {
out_sym_shape.emplace_back(std::int64_t(rank));
out_sym_shape.resize(rank, std::int64_t(1));
} else {
out_sym_shape = {};
}
@@ -357,14 +375,31 @@ bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
symbol::TensorShapeOrDataDimExprs(out_sym_shape)};

infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
if (output_val_and_ind)
infer_context->SetShapeOrDataForValue(op->result(1), shape_data);
return true;
}

#define DEFINE_MINMAX_OP_INFER_FUNC(OpName, output_val_and_ind) \
bool OpName##OpInferSymbolicShape( \
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { \
return MinMaxOpInferSymbolicShape(op, infer_context, output_val_and_ind); \
}

DEFINE_MINMAX_OP_INFER_FUNC(Argmax, false)
DEFINE_MINMAX_OP_INFER_FUNC(MaxWithIndex, true)
#undef DEFINE_MINMAX_OP_INFER_FUNC

bool ArgminOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
return ArgmaxOpInferSymbolicShape(op, infer_context);
}

bool MinWithIndexOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return MaxWithIndexOpInferSymbolicShape(op, infer_context);
}

bool AsComplexOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
pir::Value operand_source = op->operand_source(0);
@@ -93,8 +93,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mode)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxWithIndex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Maxout)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MinWithIndex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mean)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixPower)
64 changes: 64 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -2950,6 +2950,70 @@ void ModeInferMeta(const MetaTensor& x,
indices->set_dtype(DataType::INT64);
}

void MinMaxWithIndexInferMeta(const MetaTensor& x,
const Scalar& axis,
bool keepdims,
bool flatten,
MetaTensor* val_out,
MetaTensor* ind_out,
MetaConfig config) {
DataType val_dtype = x.dtype();

// axis.FromTensor will never be true for this op
auto int_axis = axis.to<int64_t>();
const auto& x_dims = x.dims();

auto x_rank = x.dims().size();
if (x_rank > 0) {
PADDLE_ENFORCE_GE(int_axis,
-x_rank,
common::errors::InvalidArgument(
"'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
int_axis,
-x_rank));
PADDLE_ENFORCE_LT(
int_axis,
x_rank,
common::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
int_axis,
x_rank));
} else {
// 0-dim tensor
PADDLE_ENFORCE_EQ(int_axis == 0 || int_axis == -1,
true,
common::errors::InvalidArgument(
"'axis'(%d) must be 0 or -1 if input tensor is "
"0-dim.",
int_axis));
}

if (int_axis < 0) int_axis += x_rank;

std::vector<int64_t> vec;
if (flatten) {
if (keepdims) { // NOLINT
vec = std::vector<int64_t>(x.dims().size(), 1);
} else {
vec = {};
}
} else {
for (int64_t i = 0; i < int_axis; i++)
vec.emplace_back(x_dims[static_cast<int>(i)]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = int_axis + 1; i < x_rank; i++)
vec.emplace_back(x_dims[static_cast<int>(i)]);
}

val_out->set_dims(common::make_ddim(vec));
val_out->set_dtype(val_dtype);
ind_out->set_dims(common::make_ddim(vec));
ind_out->set_dtype(DataType::INT64);
}

void MultinomialInferMeta(const MetaTensor& x,
const Scalar& num_samples,
bool replacement,
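For reference, a small Python sketch (illustration only, not Paddle code) of the output-shape rule that MinMaxWithIndexInferMeta above implements; the values output and the int64 indices output share this shape:

def minmax_with_index_out_shape(x_shape, axis, keepdim=False, flatten=False):
    # Mirrors the axis checks and the vec construction in MinMaxWithIndexInferMeta.
    rank = len(x_shape)
    if rank > 0:
        assert -rank <= axis < rank
    else:
        assert axis in (0, -1)  # 0-dim input only allows axis 0 or -1
    if axis < 0:
        axis += rank
    if flatten:
        # Reduce over all elements: all-ones dims if keepdim, else a 0-dim result.
        return [1] * rank if keepdim else []
    out = list(x_shape[:axis])
    if keepdim:
        out.append(1)
    out += list(x_shape[axis + 1:])
    return out

# minmax_with_index_out_shape([2, 3, 4], axis=1)                              -> [2, 4]
# minmax_with_index_out_shape([2, 3, 4], axis=-1, keepdim=True)               -> [2, 3, 1]
# minmax_with_index_out_shape([2, 3, 4], axis=0, flatten=True, keepdim=True)  -> [1, 1, 1]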
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -66,6 +66,14 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void MinMaxWithIndexInferMeta(const MetaTensor& x,
const Scalar& axis,
bool keepdims,
bool flatten,
MetaTensor* val_out,
MetaTensor* ind_out,
MetaConfig config = MetaConfig());

void ArgsortInferMeta(const MetaTensor& input,
int axis,
bool descending,
115 changes: 115 additions & 0 deletions paddle/phi/kernels/gpu/min_max_with_index_grad_kernel.cu
@@ -0,0 +1,115 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T>
using EnableIfInteger =
typename std::enable_if<std::is_integral<T>::value, int>::type;

template <typename T>
using EnableIfNonInteger =
typename std::enable_if<!std::is_integral<T>::value, int>::type;

// Here if keepdim=True, this falls back to a simplified version of
// take_along_axis. However, if keepdim=False (the default), indices will
// not have the same rank as the input values (and values_grad), so an
// unsqueeze is needed, done by shallow-copying indices and calling Resize.
#define DEFINE_WITH_INDEX_GRAD_KERNEL(OpType) \
template <typename T, typename Context, EnableIfNonInteger<T> = 0> \
void OpType##WithIndexGradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& values, \
const DenseTensor& indices, \
const DenseTensor& values_grad, \
const Scalar& dim, \
bool keepdim, \
DenseTensor* x_grad) { \
x_grad->Resize(x.dims()); \
dev_ctx.template Alloc<T>(x_grad); \
if (x_grad->numel() == 0) { \
return; \
} \
int64_t dim_val = dim.to<int64_t>(); \
if (dim_val < 0) { \
dim_val += x.dims().size(); \
} \
DenseTensor shallow_copied_inds(indices); \
if (!keepdim) { \
auto indices_dim = x.dims(); \
indices_dim[dim_val] = 1; \
shallow_copied_inds.Resize(indices_dim); \
} \
phi::funcs::SetConstant<Context, T> functor; \
functor(dev_ctx, x_grad, static_cast<T>(0)); \
phi::funcs::gpu_scatter_add_kernel<T, int64_t>( \
*x_grad, dim_val, shallow_copied_inds, values_grad, true, dev_ctx); \
} \
template <typename T, typename Context, EnableIfInteger<T> = 0> \
void OpType##WithIndexGradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& values, \
const DenseTensor& indices, \
const DenseTensor& values_grad, \
const Scalar& dim, \
bool keepdim, \
DenseTensor* x_grad) { \
std::string dtype_name = phi::DataTypeToString(values.dtype()); \
PADDLE_ENFORCE_EQ( \
0, \
1, \
phi::errors::InvalidArgument( \
"Integer type '%s' is not allowed to have stop_gradient=False.", \
dtype_name.c_str())); \
}

DEFINE_WITH_INDEX_GRAD_KERNEL(Max)
DEFINE_WITH_INDEX_GRAD_KERNEL(Min)

#undef DEFINE_WITH_INDEX_GRAD_KERNEL

} // namespace phi

PD_REGISTER_KERNEL(max_with_index_grad,
GPU,
ALL_LAYOUT,
phi::MaxWithIndexGradKernel,
float,
double,
uint8_t,
int,
int16_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(min_with_index_grad,
GPU,
ALL_LAYOUT,
phi::MinWithIndexGradKernel,
float,
double,
uint8_t,
int,
int16_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
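To make the behavior of the macro-generated grad kernels above concrete, here is a NumPy illustration (a sketch of the semantics, not the actual CUDA implementation): the gradient with respect to x is zero everywhere except at the positions recorded in indices along dim, which receive values_grad. The kernel does this via gpu_scatter_add_kernel on a zero-initialized x_grad, after Resize-ing a shallow copy of indices when keepdim is false:

import numpy as np

def max_with_index_grad_reference(x_shape, indices, values_grad, dim, keepdim=False):
    if dim < 0:
        dim += len(x_shape)
    if not keepdim:
        # Restore the reduced dimension so indices/values_grad line up along dim,
        # matching the kernel's Resize of the shallow-copied indices tensor.
        indices = np.expand_dims(indices, dim)
        values_grad = np.expand_dims(values_grad, dim)
    x_grad = np.zeros(x_shape, dtype=values_grad.dtype)
    # Each slice along dim writes exactly one position, so an overwrite here is
    # equivalent to the scatter-add used by the kernel.
    np.put_along_axis(x_grad, indices, values_grad, axis=dim)
    return x_grad

# Example: x of shape (2, 3), dim=1, keepdim=False
#   indices     = np.array([1, 2])        # positions returned by max_with_index
#   values_grad = np.array([1.0, 1.0])
#   result      -> [[0., 1., 0.],
#                   [0., 0., 1.]]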