diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py
index f8510480b2fca4..60840cc60ec5e9 100644
--- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py
@@ -135,6 +135,7 @@
     'KthvalueInferMeta',
     'MaxPoolWithIndexInferMeta',
     'MaxPoolV2InferMeta',
+    'MinMaxWithIndexInferMeta',
     'MultinomialInferMeta',
     'OverlapAddInferMeta',
     'PadInferMeta',
diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
index 6750759633d0b8..ab9e020aea41ea 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
@@ -315,26 +315,44 @@ bool AnyOpInferSymbolicShape(pir::Operation *op,
                 axis.size() == 0 /*reduce_all*/);
 }
 
-bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
-                                pir::InferSymbolicShapeContext *infer_context) {
+bool MinMaxOpInferSymbolicShape(pir::Operation *op,
+                                pir::InferSymbolicShapeContext *infer_context,
+                                bool output_val_and_ind = false) {
   bool flatten = GetBoolAttr(op, "flatten");
-  bool keepdims = GetBoolAttr(op, "keepdims");
+  bool keepdims = false;
+  int axis = 0;
+
+  if (output_val_and_ind) {
+    keepdims = GetBoolAttr(op, "keepdim");
+    PADDLE_ENFORCE_NE(
+        op->attributes().find("dim"),
+        op->attributes().end(),
+        common::errors::InvalidArgument(
+            "'dim' attribute is expected for Min/MaxWithIndexOp."));
+    axis = op->attributes()
+               .at("dim")
+               .dyn_cast<paddle::dialect::ScalarAttribute>()
+               .data()
+               .to<int>();
+  } else {
+    keepdims = GetBoolAttr(op, "keepdims");
+    const auto &axis_shape_or_data =
+        infer_context->GetShapeOrDataForValue(op->operand_source(1));
+    axis = static_cast<int>(
+        axis_shape_or_data.data().value().at(0).Get<int64_t>());
+  }
   const auto &input_sym_shape =
       infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
-  int rank = input_sym_shape.size();
-  const auto &axis_shape_or_data =
-      infer_context->GetShapeOrDataForValue(op->operand_source(1));
-  int axis =
-      static_cast<int>(axis_shape_or_data.data().value().at(0).Get<int64_t>());
+
+  int rank = input_sym_shape.size();
   if (axis < 0) axis += rank;
 
   const auto &out_sym_shape = [&] {
     std::vector<symbol::DimExpr> out_sym_shape;
     if (flatten) {
       if (keepdims) {
-        out_sym_shape.emplace_back(std::int64_t(rank));
+        out_sym_shape.resize(rank, std::int64_t(1));
       } else {
         out_sym_shape = {};
       }
@@ -357,14 +375,31 @@ bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
       symbol::TensorShapeOrDataDimExprs(out_sym_shape)};
 
   infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
+  if (output_val_and_ind)
+    infer_context->SetShapeOrDataForValue(op->result(1), shape_data);
   return true;
 }
 
+#define DEFINE_MINMAX_OP_INFER_FUNC(OpName, output_val_and_ind)               \
+  bool OpName##OpInferSymbolicShape(                                          \
+      pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {    \
+    return MinMaxOpInferSymbolicShape(op, infer_context, output_val_and_ind); \
+  }
+
+DEFINE_MINMAX_OP_INFER_FUNC(Argmax, false)
+DEFINE_MINMAX_OP_INFER_FUNC(MaxWithIndex, true)
+#undef DEFINE_MINMAX_OP_INFER_FUNC
+
 bool ArgminOpInferSymbolicShape(pir::Operation *op,
                                 pir::InferSymbolicShapeContext *infer_context) {
   return ArgmaxOpInferSymbolicShape(op, infer_context);
 }
 
+bool MinWithIndexOpInferSymbolicShape(
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
+  return MaxWithIndexOpInferSymbolicShape(op, infer_context);
+}
+
 bool AsComplexOpInferSymbolicShape(
     pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
   pir::Value operand_source = op->operand_source(0);
diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
index 9868d08d8a290d..8d21b51eb2719f 100755
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
@@ -93,8 +93,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu_)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mode)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxWithIndex)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Maxout)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(MinWithIndex)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mean)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixPower)
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 405528589b824e..ab8dff4a9e8d2d 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -2950,6 +2950,70 @@ void ModeInferMeta(const MetaTensor& x,
   indices->set_dtype(DataType::INT64);
 }
 
+void MinMaxWithIndexInferMeta(const MetaTensor& x,
+                              const Scalar& axis,
+                              bool keepdims,
+                              bool flatten,
+                              MetaTensor* val_out,
+                              MetaTensor* ind_out,
+                              MetaConfig config) {
+  DataType val_dtype = x.dtype();
+
+  // axis.FromTensor will never be true for this op
+  auto int_axis = axis.to<int64_t>();
+  const auto& x_dims = x.dims();
+
+  auto x_rank = x.dims().size();
+  if (x_rank > 0) {
+    PADDLE_ENFORCE_GE(int_axis,
+                      -x_rank,
+                      common::errors::InvalidArgument(
+                          "'axis'(%d) must be greater than or equal to"
+                          " -Rank(X)(%d).",
+                          int_axis,
+                          -x_rank));
+    PADDLE_ENFORCE_LT(
+        int_axis,
+        x_rank,
+        common::errors::InvalidArgument(
+            "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
+            int_axis,
+            x_rank));
+  } else {
+    // 0-dim tensor
+    PADDLE_ENFORCE_EQ(int_axis == 0 || int_axis == -1,
+                      true,
+                      common::errors::InvalidArgument(
+                          "'axis'(%d) must be 0 or -1 if input tensor is "
+                          "0-dim.",
+                          int_axis));
+  }
+
+  if (int_axis < 0) int_axis += x_rank;
+
+  std::vector<int64_t> vec;
+  if (flatten) {
+    if (keepdims) {  // NOLINT
+      vec = std::vector<int64_t>(x.dims().size(), 1);
+    } else {
+      vec = {};
+    }
+  } else {
+    for (int64_t i = 0; i < int_axis; i++)
+      vec.emplace_back(x_dims[static_cast<int>(i)]);
+    if (keepdims) {
+      vec.emplace_back(static_cast<int64_t>(1));
+    }
+    for (int64_t i = int_axis + 1; i < x_rank; i++)
+      vec.emplace_back(x_dims[static_cast<int>(i)]);
+  }
+
+  val_out->set_dims(common::make_ddim(vec));
+  val_out->set_dtype(val_dtype);
+  ind_out->set_dims(common::make_ddim(vec));
+  ind_out->set_dtype(DataType::INT64);
+}
+
 void MultinomialInferMeta(const MetaTensor& x,
                           const Scalar& num_samples,
                           bool replacement,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 7334ee476c0ad9..ea6c95748c16c5 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -66,6 +66,14 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
                         MetaTensor* out,
                         MetaConfig config = MetaConfig());
 
+void MinMaxWithIndexInferMeta(const MetaTensor& x,
+                              const Scalar& axis,
+                              bool keepdims,
+                              bool flatten,
+                              MetaTensor* val_out,
+                              MetaTensor* ind_out,
+                              MetaConfig config = MetaConfig());
+
 void ArgsortInferMeta(const MetaTensor& input,
                       int axis,
                       bool descending,
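The shape rule that `MinMaxWithIndexInferMeta` implements (and that the symbolic-shape function above mirrors) is compact enough to restate. A minimal Python sketch for review purposes only — the helper name is invented and nothing below is part of the patch:

```python
def min_max_with_index_out_shape(x_shape, dim, keepdim=False, flatten=False):
    # Mirrors MinMaxWithIndexInferMeta: `values` and `indices` share this shape.
    rank = len(x_shape)
    if rank > 0:
        assert -rank <= dim < rank, "'dim' must lie in [-rank, rank)"
    else:
        assert dim in (0, -1), "'dim' must be 0 or -1 for a 0-dim tensor"
    if dim < 0:
        dim += rank
    if flatten:
        # reduce over all elements; keepdim keeps `rank` dims of size 1
        return [1] * rank if keepdim else []
    return list(x_shape[:dim]) + ([1] if keepdim else []) + list(x_shape[dim + 1:])

assert min_max_with_index_out_shape([2, 3, 4], dim=1) == [2, 4]
assert min_max_with_index_out_shape([2, 3, 4], dim=-2, keepdim=True) == [2, 1, 4]
assert min_max_with_index_out_shape([2, 3, 4], dim=0, keepdim=True, flatten=True) == [1, 1, 1]
```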
diff --git a/paddle/phi/kernels/gpu/min_max_with_index_grad_kernel.cu b/paddle/phi/kernels/gpu/min_max_with_index_grad_kernel.cu
new file mode 100644
index 00000000000000..2cbffdb67cb3ae
--- /dev/null
+++ b/paddle/phi/kernels/gpu/min_max_with_index_grad_kernel.cu
@@ -0,0 +1,115 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T>
+using EnableIfInteger =
+    typename std::enable_if<std::is_integral<T>::value, int>::type;
+
+template <typename T>
+using EnableIfNonInteger =
+    typename std::enable_if<!std::is_integral<T>::value, int>::type;
+
+// Here if keepdim=True, this will fall back to a simplified version of
+// take_along_axis. However, if keepdim=False (the default), indices will
+// not have the same rank as the input values (and values_grad), and
+// therefore need an unsqueeze, done by shallow copying indices and Resize
+#define DEFINE_WITH_INDEX_GRAD_KERNEL(OpType)                                  \
+  template <typename T, typename Context, EnableIfNonInteger<T> = 0>          \
+  void OpType##WithIndexGradKernel(const Context& dev_ctx,                    \
+                                   const DenseTensor& x,                      \
+                                   const DenseTensor& values,                 \
+                                   const DenseTensor& indices,                \
+                                   const DenseTensor& values_grad,            \
+                                   const Scalar& dim,                         \
+                                   bool keepdim,                              \
+                                   DenseTensor* x_grad) {                     \
+    x_grad->Resize(x.dims());                                                 \
+    dev_ctx.template Alloc<T>(x_grad);                                        \
+    if (x_grad->numel() == 0) {                                               \
+      return;                                                                 \
+    }                                                                         \
+    int64_t dim_val = dim.to<int64_t>();                                      \
+    if (dim_val < 0) {                                                        \
+      dim_val += x.dims().size();                                             \
+    }                                                                         \
+    DenseTensor shallow_copied_inds(indices);                                 \
+    if (!keepdim) {                                                           \
+      auto indices_dim = x.dims();                                            \
+      indices_dim[dim_val] = 1;                                               \
+      shallow_copied_inds.Resize(indices_dim);                                \
+    }                                                                         \
+    phi::funcs::SetConstant<Context, T> functor;                              \
+    functor(dev_ctx, x_grad, static_cast<T>(0));                              \
+    phi::funcs::gpu_scatter_add_kernel<T, int64_t>(                           \
+        *x_grad, dim_val, shallow_copied_inds, values_grad, true, dev_ctx);   \
+  }                                                                           \
+  template <typename T, typename Context, EnableIfInteger<T> = 0>             \
+  void OpType##WithIndexGradKernel(const Context& dev_ctx,                    \
+                                   const DenseTensor& x,                      \
+                                   const DenseTensor& values,                 \
+                                   const DenseTensor& indices,                \
+                                   const DenseTensor& values_grad,            \
+                                   const Scalar& dim,                         \
+                                   bool keepdim,                              \
+                                   DenseTensor* x_grad) {                     \
+    std::string dtype_name = phi::DataTypeToString(values.dtype());           \
+    PADDLE_ENFORCE_EQ(                                                        \
+        0,                                                                    \
+        1,                                                                    \
+        phi::errors::InvalidArgument(                                         \
+            "Integer type '%s' is not allowed to have stop_gradient=False.",  \
+            dtype_name.c_str()));                                             \
+  }
+
+DEFINE_WITH_INDEX_GRAD_KERNEL(Max)
+DEFINE_WITH_INDEX_GRAD_KERNEL(Min)
+
+#undef DEFINE_WITH_INDEX_GRAD_KERNEL
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(max_with_index_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MaxWithIndexGradKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(min_with_index_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MinWithIndexGradKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
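The backward rule defined above amounts to scattering `values_grad` back into a zero tensor at the argmin/argmax positions, unsqueezing the indices first when `keepdim=False`. A NumPy sketch of the same semantics, for illustration only (`np.put_along_axis` overwrites rather than accumulates, which coincides with the kernel's scatter-add here because each output element has exactly one source index):

```python
import numpy as np

def with_index_grad(x_shape, indices, values_grad, dim, keepdim):
    # Zero everywhere except the positions that produced the min/max values.
    grad = np.zeros(x_shape, dtype=values_grad.dtype)
    if not keepdim:
        # Restore the reduced axis, as the kernel does via shallow copy + Resize.
        indices = np.expand_dims(indices, axis=dim)
        values_grad = np.expand_dims(values_grad, axis=dim)
    np.put_along_axis(grad, indices, values_grad, axis=dim)
    return grad

x = np.array([[3.0, 1.0], [2.0, 5.0]])
g = with_index_grad(x.shape, x.argmin(axis=1), np.ones(2), dim=1, keepdim=False)
assert (g == np.array([[0.0, 1.0], [1.0, 0.0]])).all()
```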
diff --git a/paddle/phi/kernels/gpu/min_max_with_index_kernel.cu b/paddle/phi/kernels/gpu/min_max_with_index_kernel.cu
new file mode 100644
index 00000000000000..521444ef9e9481
--- /dev/null
+++ b/paddle/phi/kernels/gpu/min_max_with_index_kernel.cu
@@ -0,0 +1,312 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/min_max_with_index_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+#include <limits>
+
+#include "paddle/common/ddim.h"
+#include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+namespace phi {
+
+namespace {  // NOLINT
+template <typename K, typename V>
+using KeyValuePair = cub::KeyValuePair<K, V>;
+
+}  // namespace
+
+#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
+  case (1 << (log2_block_dim)): {                       \
+    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
+    __VA_ARGS__;                                        \
+  } break
+
+#define FIXED_BLOCK_DIM_CASE(...)               \
+  FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
+
+template <typename T,
+          typename IndType,
+          class Reducer,
+          typename IndexType,
+          size_t BlockDim>
+__global__ void MinMaxWithIndexKernel(const int64_t height,     // n * h
+                                      const int64_t width,      // c
+                                      const int64_t post_size,  // h
+                                      const Reducer reducer,
+                                      const T init,
+                                      const T* in,
+                                      T* val_out,
+                                      IndType* key_out) {
+  typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  for (IndexType idx = blockIdx.x; idx < height; idx += gridDim.x) {
+    KeyValuePair<int, T> kv_pair = {-1, init};
+    IndexType h = idx / post_size;
+    IndexType w = idx % post_size;
+    for (IndexType k = threadIdx.x; k < width; k += blockDim.x) {
+      kv_pair =
+          reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
+    }
+    kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
+    if (threadIdx.x == 0) {
+      val_out[idx] = static_cast<T>(kv_pair.value);
+      key_out[idx] = static_cast<IndType>(kv_pair.key);
+    }
+    __syncthreads();
+  }
+}
+
+template <typename T, typename IndType, class Reducer, typename IndexType>
+void ComputeMinMaxWithIndex(const phi::GPUContext& dev_ctx,
+                            const DenseTensor& input,
+                            DenseTensor* values,
+                            DenseTensor* indices,
+                            const int64_t pre,
+                            const int64_t post,
+                            const int64_t n) {
+  auto cu_stream = dev_ctx.stream();
+  auto ComputeBlockSize = [](int64_t col) {
+    auto block_size = 8;
+    if (col > 512)
+      block_size = 1024;
+    else if (col > 256)
+      block_size = 512;
+    else if (col > 128)
+      block_size = 256;
+    else if (col > 64)
+      block_size = 128;
+    else if (col > 32)
+      block_size = 64;
+    else if (col > 16)
+      block_size = 32;
+    else if (col > 8)
+      block_size = 16;
+    return block_size;
+  };
+
+  int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
+  int64_t height = pre * post;
+  int64_t width = n;
+  int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
+
+  const T* in_data = input.data<T>();
+
+  T* val_data = dev_ctx.template Alloc<T>(values);
+  IndType* ind_data = dev_ctx.template Alloc<IndType>(indices);
+
+  if (typeid(Reducer) == typeid(cub::ArgMax)) {
+    switch (ComputeBlockSize(width)) {
+      FIXED_BLOCK_DIM_CASE(
+          MinMaxWithIndexKernel<T, IndType, Reducer, IndexType, kBlockDim>
+          <<<grid_size, kBlockDim, 0, cu_stream>>>(
+              height,
+              width,
+              post,
+              Reducer(),
+              std::numeric_limits<T>::lowest(),
+              in_data,
+              val_data,
+              ind_data));
+    }
+  } else {
+    switch (ComputeBlockSize(width)) {
+      FIXED_BLOCK_DIM_CASE(
+          MinMaxWithIndexKernel<T, IndType, Reducer, IndexType, kBlockDim>
+          <<<grid_size, kBlockDim, 0, cu_stream>>>(
+              height,
+              width,
+              post,
+              Reducer(),
+              std::numeric_limits<T>::max(),
+              in_data,
+              val_data,
+              ind_data));
+    }
+  }
+}
+
+template <typename Context, typename T, class Reducer>
+struct VisitDataCudaMinMaxWithIndexFunctor {
+  const Context& dev_ctx;
+  const DenseTensor& x;
+  int64_t axis;
+  bool keepdims;
+  bool flatten;
+  DenseTensor* val_out;
+  DenseTensor* ind_out;
+
+  explicit VisitDataCudaMinMaxWithIndexFunctor(const Context& dev_ctx,
+                                               const DenseTensor& x,
+                                               int64_t axis,
+                                               bool keepdims,
+                                               bool flatten,
+                                               DenseTensor* val_out,
+                                               DenseTensor* ind_out)
+      : dev_ctx(dev_ctx),
+        x(x),
+        keepdims(keepdims),
+        axis(axis),
+        flatten(flatten),
+        val_out(val_out),
+        ind_out(ind_out) {}
+
+  template <typename IndType>
+  void apply() const {
+    phi::DDim x_dims;
+    int new_axis = axis;
+    if (flatten) {
+      x_dims = common::make_ddim({x.numel()});
+      // if flatten, treat the axis as 0
+      new_axis = 0;
+    } else {
+      x_dims = x.dims();
+      if (axis < 0) new_axis = axis + x.dims().size();
+    }
+    if (x.numel() == 0) {
+      dev_ctx.template Alloc<T>(val_out);
+      dev_ctx.template Alloc<IndType>(ind_out);
+      return;
+    }
+    // For 0D Tensor
+    if (x.dims().size() == 0) {
+      dev_ctx.template Alloc<T>(val_out);
+      dev_ctx.template Alloc<IndType>(ind_out);
+      phi::funcs::set_constant(dev_ctx, ind_out, static_cast<IndType>(0));
+      phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, val_out);
+      return;
+    }
+
+    int64_t numel = x.numel();
+    int64_t groups = numel / x_dims[new_axis];
+    int64_t pre = 1;
+    int64_t post = 1;
+    int64_t n = x_dims[new_axis];
+
+    for (int i = 0; i < new_axis; i++) {
+      pre *= x_dims[i];
+    }
+
+    for (int i = new_axis + 1; i < x_dims.size(); i++) {
+      post *= x_dims[i];
+    }
+
+    if (numel > std::numeric_limits<int32_t>::max()) {
+      ComputeMinMaxWithIndex<T, IndType, Reducer, int64_t>(
+          dev_ctx, x, val_out, ind_out, pre, post, n);
+    } else {
+      ComputeMinMaxWithIndex<T, IndType, Reducer, int32_t>(
+          dev_ctx, x, val_out, ind_out, pre, post, n);
+    }
+  }
+};
+
+template <typename T, typename Context, class Reducer>
+void MinMaxWithIndexOpCUDAKernel(const Context& dev_ctx,
+                                 const DenseTensor& x,
+                                 const Scalar& axis,
+                                 bool keepdims,
+                                 bool flatten,
+                                 DenseTensor* val_out,
+                                 DenseTensor* ind_out) {
+  PADDLE_ENFORCE_GE(
+      x.numel(),
+      0,
+      common::errors::InvalidArgument(
+          "(min/max)_with_index input numel must be >= 0, but got %d",
+          x.numel()));
+  phi::VisitDataTypeTiny(
+      phi::DataType::INT64,
+      VisitDataCudaMinMaxWithIndexFunctor<Context, T, Reducer>(
+          dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, val_out, ind_out));
+}
+
+template <typename T, typename Context>
+void MinWithIndexKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const Scalar& dim,
+                        bool keepdim,
+                        bool flatten,
+                        DenseTensor* val_out,
+                        DenseTensor* ind_out) {
+  MinMaxWithIndexOpCUDAKernel<T, Context, cub::ArgMin>(
+      dev_ctx, x, dim, keepdim, flatten, val_out, ind_out);
+}
+
+template <typename T, typename Context>
+void MaxWithIndexKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const Scalar& dim,
+                        bool keepdim,
+                        bool flatten,
+                        DenseTensor* val_out,
+                        DenseTensor* ind_out) {
+  MinMaxWithIndexOpCUDAKernel<T, Context, cub::ArgMax>(
+      dev_ctx, x, dim, keepdim, flatten, val_out, ind_out);
+}
+
+#endif
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(min_with_index,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MinWithIndexKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   float,
+                   double,
+                   int32_t,
+                   int64_t,
+                   int16_t,
+                   uint8_t) {
+  kernel->OutputAt(0).SetDataType(kernel->InputAt(0).dtype);
+  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
+}
+
+PD_REGISTER_KERNEL(max_with_index,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MaxWithIndexKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   float,
+                   double,
+                   int32_t,
+                   int64_t,
+                   int16_t,
+                   uint8_t) {
+  kernel->OutputAt(0).SetDataType(kernel->InputAt(0).dtype);
+  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
+}
diff --git a/paddle/phi/kernels/min_max_with_index_kernel.h b/paddle/phi/kernels/min_max_with_index_kernel.h
new file mode 100644
index 00000000000000..56e733fcdbeef8
--- /dev/null
+++ b/paddle/phi/kernels/min_max_with_index_kernel.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MinWithIndexKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const Scalar& dim,
+                        bool keepdim,
+                        bool flatten,
+                        DenseTensor* val_out,
+                        DenseTensor* ind_out);
+
+template <typename T, typename Context>
+void MaxWithIndexKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const Scalar& dim,
+                        bool keepdim,
+                        bool flatten,
+                        DenseTensor* val_out,
+                        DenseTensor* ind_out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml
index 5364fa6ff73b9c..154b99e557fabf 100644
--- a/paddle/phi/ops/yaml/backward.yaml
+++ b/paddle/phi/ops/yaml/backward.yaml
@@ -2277,6 +2277,16 @@
     kernel :
       func : max_pool3d_with_index_grad
 
+- backward_op : max_with_index_grad
+  forward : max_with_index (Tensor x, Scalar dim, bool keepdim, bool flatten) -> Tensor(values), Tensor(indices)
+  args : (Tensor x, Tensor values, Tensor indices, Tensor values_grad, Scalar dim, bool keepdim)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : max_with_index_grad
+
 - backward_op : maxout_grad
   forward : maxout(Tensor x, int groups, int axis) -> Tensor(out)
   args : (Tensor x, Tensor out, Tensor out_grad, int groups, int axis)
@@ -2340,6 +2350,16 @@
       func : meshgrid_grad
       data_type : out_grad
 
+- backward_op : min_with_index_grad
+  forward : min_with_index (Tensor x, Scalar dim, bool keepdim, bool flatten) -> Tensor(values), Tensor(indices)
+  args : (Tensor x, Tensor values, Tensor indices, Tensor values_grad, Scalar dim, bool keepdim)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : min_with_index_grad
+
 - backward_op : mish_grad
   forward : mish (Tensor x, float lambda) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, float lambda)
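With the forward op, backward op, and kernels registered, the pair can be exercised directly through `_C_ops` on a CUDA build (the kernels above are GPU-only). A small usage sketch; the argument order `(x, dim, keepdim, flatten)` follows the yaml signature in this patch:

```python
import paddle  # assumes a CUDA build of Paddle with this patch applied

x = paddle.to_tensor([[3.0, 1.0], [2.0, 5.0]], stop_gradient=False)
values, indices = paddle._C_ops.max_with_index(x, 1, False, False)
indices.stop_gradient = True  # indices carry no gradient

values.sum().backward()
print(values)   # [3., 5.]
print(indices)  # [0, 1]
print(x.grad)   # one-hot per row: [[1., 0.], [0., 1.]]
```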
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index b5f4d6371a82b1..694b19cbe62188 100644
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -3577,6 +3577,17 @@
   backward : max_pool3d_with_index_grad
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
+- op : max_with_index
+  args : (Tensor x, Scalar(int64_t) dim, bool keepdim = false, bool flatten = false)
+  output : Tensor(values), Tensor(indices)
+  infer_meta :
+    func : MinMaxWithIndexInferMeta
+  kernel :
+    func : max_with_index
+    data_type : x
+  backward : max_with_index_grad
+  interfaces : paddle::dialect::InferSymbolicShapeInterface
+
 - op : maxout
   args : (Tensor x, int groups, int axis = 1)
   output : Tensor(out)
@@ -3686,6 +3697,17 @@
   backward : meshgrid_grad
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
+- op : min_with_index
+  args : (Tensor x, Scalar(int64_t) dim, bool keepdim = false, bool flatten = false)
+  output : Tensor(values), Tensor(indices)
+  infer_meta :
+    func : MinMaxWithIndexInferMeta
+  kernel :
+    func : min_with_index
+    data_type : x
+  backward : min_with_index_grad
+  interfaces : paddle::dialect::InferSymbolicShapeInterface
+
 - op : mish
   args : (Tensor x, float lambda)
   output : Tensor
diff --git a/python/paddle/compat.py b/python/paddle/compat.py
index 2a37393e9053f8..023fe2efcbe325 100644
--- a/python/paddle/compat.py
+++ b/python/paddle/compat.py
@@ -14,8 +14,10 @@
 
 from .tensor.compat import (
     Unfold,
+    max,
+    min,
     sort,
     split,
 )
 
-__all__ = ['split', 'sort', 'Unfold']
+__all__ = ['split', 'sort', 'Unfold', 'min', 'max']
diff --git a/python/paddle/tensor/compat.py b/python/paddle/tensor/compat.py
index ad7ec15d1cfae0..3995a274309144 100644
--- a/python/paddle/tensor/compat.py
+++ b/python/paddle/tensor/compat.py
@@ -14,7 +14,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING, Any, NamedTuple
 
 import paddle
 from paddle import _C_ops
@@ -224,6 +224,11 @@ class SortRetType(NamedTuple):
     indices: Tensor
 
 
+class MinMaxRetType(NamedTuple):
+    values: Tensor
+    indices: Tensor
+
+
 def _check_out_status(
     out: Tensor | tuple[Tensor, Tensor] | list[Tensor],
     expect_multiple: bool = False,
@@ -398,3 +403,428 @@ def to_list_if_necessary(x, size_check=False):
             dilations=to_list_if_necessary(self.dilations),
             name=self.name,
         )
+
+
+def _min_max_param_checker(func_name: str, *args: Any, **kwargs: Any):
+    def invalid_arguments_exception(error_prefix=""):
+        type_strs = [type(v).__name__ for v in args]
+        type_strs.extend([f"{k}={type(v).__name__}" for k, v in kwargs.items()])
+        signature = ", ".join(type_strs)
+
+        error_msg = (
+            f"Invalid arguments for `paddle.compat.{func_name}`:\n{error_prefix}"
+            f"Got: (paddle.Tensor input, {signature}), but expected one of:\n"
+            f" - (input: paddle.Tensor) for reduce_{func_name} on all dims.\n"
+            f" - (input: paddle.Tensor, other: paddle.Tensor) -> see paddle.{func_name}imum\n"
+            f" - (input: paddle.Tensor, int dim (cannot be None), bool keepdim = False)\n"
+        )
+        return TypeError(error_msg)
+
+    def try_get_keys(key):
+        res = None
+        try:
+            res = kwargs[key]
+        except KeyError:
+            raise invalid_arguments_exception() from None
+        return res
+
+    dim_or_other = None
+    keepdim = False
+
+    num_args = len(args)
+    total_arg_num = num_args + len(kwargs)
+    if total_arg_num > 2:
+        raise invalid_arguments_exception()
+    elif total_arg_num == 2:
+        if num_args == 2:
+            dim_or_other, keepdim = args
+        elif num_args == 1:
+            dim_or_other = args[0]
+            keepdim = try_get_keys("keepdim")
+        else:
+            dim_or_other = try_get_keys("dim")
keepdim = try_get_keys("keepdim") + if dim_or_other is None or isinstance( + dim_or_other, (Variable, paddle.pir.Value) + ): + raise invalid_arguments_exception() + elif total_arg_num == 1: + if num_args: + dim_or_other = args[0] + else: + if "dim" in kwargs: + dim_or_other = kwargs["dim"] + elif "other" in kwargs: + dim_or_other = kwargs["other"] + if not isinstance(dim_or_other, (Variable, paddle.pir.Value)): + raise invalid_arguments_exception() + if dim_or_other is None: + raise invalid_arguments_exception() + + if ( + dim_or_other is not None + and not isinstance(dim_or_other, (Variable, paddle.pir.Value)) + and type(dim_or_other) is not int + ): + raise invalid_arguments_exception( + f"The second input must be int or Tensor or implicit None in compat.{func_name}, but received {type(dim_or_other)}.\n" + ) + + return dim_or_other, keepdim + + +def _min_max_tensor_allow_grad(input: Tensor): + """Prevent integral input tensor type to have `stop_gradient=False`""" + in_dtype = input.dtype + if ( + in_dtype == paddle.int32 + or in_dtype == paddle.int64 + or in_dtype == paddle.uint8 + or in_dtype == paddle.int16 + ): + if not input.stop_gradient: + raise TypeError( + f"Tensors with integral type: '{in_dtype}' should stop gradient." + ) + + +def _min_max_allow_cpu_composite(input: Tensor): + """paddle.min/argmin(max/argmax), paddle.take_along_axis reject the following types""" + in_dtype = input.dtype + if ( + in_dtype == paddle.float16 + or in_dtype == paddle.bfloat16 + or in_dtype == paddle.int16 + ): + raise TypeError( + f"Non-CUDA GPU placed Tensor does not have '{in_dtype}' op registered.\n" + "Paddle support following DataTypes: int32, int64, float64, float32, uint8" + ) + + +def _check_out_status( + out: Tensor | tuple[Tensor, Tensor] | list[Tensor], + expect_multiple: bool = False, +): + if out is None: + return + if not in_dynamic_mode(): + raise RuntimeError( + "Using `out` static graph CINN backend is currently not supported. Directly return the tensor tuple instead.\n" + ) + if expect_multiple: + if not isinstance(out, (tuple, list)) or len(out) != 2: + raise TypeError( + f"Expected a list or tuple of two tensors, got {type(out)} instead." + ) + if not ( + isinstance(out[0], paddle.Tensor) + and isinstance(out[1], paddle.Tensor) + ): + raise TypeError( + f"Expected Tensor type in the tuple/list, got ({type(out[0])}, {type(out[1])}) instead." + ) + else: + if not isinstance(out, paddle.Tensor): + raise TypeError(f"Expected a Tensor, got {type(out)} instead.") + + +@ForbidKeywordsDecorator( + illegal_keys={"x", "axis"}, + func_name="paddle.compat.min", + correct_name="paddle.min", +) +def min( + input: Tensor, + *args: Any, + out: Tensor | tuple[Tensor, Tensor] | list[Tensor] = None, + **kwargs: Any, +) -> Tensor | MinMaxRetType: + """ + + Computes the minimum of tensor elements. There are mainly 3 cases (functionalities): + 1. paddle.compat.min(input: Tensor): reduce min over all dims, return a single value Tensor + 2. paddle.compat.min(input: Tensor, dim: int (cannot be None), keepdim=False): reduce min over the given dim, + returns a named tuple MinMaxRetType(values: Tensor, indices: Tensor) + 3. paddle.compat.min(input: Tensor, other: Tensor): see `paddle.minimum` + + Special warning: the gradient behavior is NOT well-documented by PyTorch, the actual behavior should be: + 1. Case 1: the same as `min` + 2. Case 2: NOT evenly distributing the gradient for equal minimum elements! 
PyTorch actually only propagates the gradient to the elements at the returned indices,
+       for example: Tensor([1, 1, 1]) -> min(..., dim=0) -> values=Tensor(1.), indices=Tensor(0); the gradient for the input tensor won't be
+       Tensor([1/3, 1/3, 1/3]) as stated in their documentation, but will be Tensor([1, 0, 0]). This API implements a similar backward kernel.
+    3. Case 3: the same as `minimum`
+
+    Args:
+        input (Tensor): A tensor, the data type is bfloat16, float16, float32, float64, int32, int64 on GPU.
+            uint8, int32, int64, float32, float64 are allowed on CPU.
+        dim (int, optional): The dim along which the minimum is computed.
+            If not specified (case 1), the minimum is computed over all elements of `input` and a Tensor with a single element is returned;
+            note that `None` cannot be passed explicitly (a TypeError will be thrown).
+            Otherwise `dim` must be in the range :math:`[-input.ndim, input.ndim)`.
+            If :math:`dim < 0`, the axis to reduce is :math:`input.ndim + dim`.
+            Warning: if `dim` is specified, executing a static graph will throw exceptions
+            when not on a GPU device, since min_with_index is not implemented for non-GPU devices
+        keepdim (bool, optional): Whether to reserve the reduced dimension in the
+            output Tensor. The result tensor will have one fewer dimension
+            than the `input` unless :attr:`keepdim` is true, default
+            value is False. Note that if `dim` appears in neither (*args) nor (**kwargs), this parameter cannot be passed alone
+        other (Tensor, optional): the other tensor to perform `paddle.minimum` with. This Tensor should
+            have the same or broadcast-able shape as the `input`. Note that (`dim` & `keepdim`) and `other` are mutually exclusive,
+            meaning that trying to combine both will result in a TypeError
+        out (Tensor|tuple[Tensor, Tensor], optional): the output Tensor or tuple of (Tensor, int64 Tensor) that can be optionally
+            given to be used as output buffers. For cases 1 and 3 out is just a Tensor, while for case 2 we expect a tuple
+
+
+    Returns:
+        - For case 1: a single value Tensor (0-dim)
+        - For case 2: a named tuple MinMaxRetType(values: Tensor, indices: Tensor), `values` has the same data type as the `input`,
+          while indices is always an int64 Tensor, with exactly the same shape as `values`.
+          MinMaxRetType can be used (indexed, packed, unpacked) in the same way as a regular tuple
+        - For case 3: see `paddle.minimum`
+
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+
+            >>> # data_x is a Tensor with shape [2, 4]
+            >>> # the axis is an int element
+            >>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
+            ...                       [0.1, 0.2, 0.6, 0.7]],
+            ...                      dtype='float64', stop_gradient=False)
+            >>> # Case 1: reduce over all dims
+            >>> result1 = paddle.compat.min(x)
+            >>> result1
+            Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
+            0.10000000)
+
+            >>> # Case 2: reduce over specified dim
+            >>> x.clear_grad()
+            >>> result2 = paddle.compat.min(x, dim=1)
+            >>> result2
+            MinMaxRetType(values=Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=False,
+            [0.20000000, 0.10000000]), indices=Tensor(shape=[2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
+            [0, 0]))
+            >>> result2[0].backward()
+            >>> x.grad
+            Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
+            [[1., 0., 0., 0.],
+             [1., 0., 0., 0.]])
+
+            >>> # Case 3: equivalent to `paddle.minimum`
+            >>> x.clear_grad()
+            >>> y = paddle.to_tensor([[0.5, 0.4, 0.1, 0.2],
+            ...                       [0.3, 0.1, 0.6, 0.7]],
+            ...                      dtype='float64', stop_gradient=False)
+            >>> result3 = paddle.compat.min(x, y)
+            >>> result3
+            Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
+            [[0.20000000, 0.30000000, 0.10000000, 0.20000000],
+             [0.10000000, 0.10000000, 0.60000000, 0.70000000]])
+    """
+    if not isinstance(input, (paddle.pir.Value, paddle.Tensor)):
+        raise TypeError(
+            f"input should be a tensor, but got an instance with type '{type(input).__name__}'"
+        )
+    _min_max_tensor_allow_grad(input)
+
+    dim_or_other, keepdim = _min_max_param_checker("min", *args, **kwargs)
+
+    ret = None
+    if dim_or_other is None:
+        # paddle.min and paddle.amin actually share the same grad op (ReduceAminKernel)
+        _check_out_status(out, False)
+        ret = paddle.min(input)
+    elif isinstance(dim_or_other, int):
+        _check_out_status(out, True)
+        if input.ndim:
+            if in_dynamic_mode() and not input.place.is_gpu_place():
+                _min_max_allow_cpu_composite(input)
+                # CPUPlace and other placements are implemented by composition
+
+                indices = paddle.argmin(input, axis=dim_or_other, keepdim=True)
+                values = paddle.take_along_axis(
+                    input, indices, axis=dim_or_other
+                )
+                if keepdim:
+                    ret = MinMaxRetType(values=values, indices=indices)
+                else:
+                    ret = MinMaxRetType(
+                        values=values.squeeze_(axis=dim_or_other),
+                        indices=indices.squeeze_(axis=dim_or_other),
+                    )
+            else:
+                vals, inds = _C_ops.min_with_index(
+                    input, dim_or_other, keepdim, False
+                )
+                inds.stop_gradient = True
+                ret = MinMaxRetType(values=vals, indices=inds)
+        else:
+            ret = MinMaxRetType(
+                values=input,
+                indices=paddle.zeros(
+                    [], dtype=paddle.int64, device=input.place
+                ),
+            )
+    else:
+        _check_out_status(out, False)
+        ret = _C_ops.minimum(input, dim_or_other)
+
+    if out is not None:
+        if isinstance(ret, MinMaxRetType):
+            paddle.assign(ret.values, out[0])
+            paddle.assign(ret.indices, out[1])
+        else:
+            paddle.assign(ret, out)
+    return ret
+
+
+@ForbidKeywordsDecorator(
+    illegal_keys={"x", "axis"},
+    func_name="paddle.compat.max",
+    correct_name="paddle.max",
+)
+def max(
+    input: Tensor,
+    *args: Any,
+    out: Tensor | tuple[Tensor, Tensor] | list[Tensor] = None,
+    **kwargs: Any,
+) -> Tensor | MinMaxRetType:
+    """
+
+    Computes the maximum of tensor elements. There are mainly 3 cases (functionalities):
+    1. paddle.compat.max(input: Tensor): reduce max over all dims, return a single value Tensor
+    2. paddle.compat.max(input: Tensor, dim: int (cannot be None), keepdim=False): reduce max over the given dim,
+       returns a named tuple MinMaxRetType(values: Tensor, indices: Tensor)
+    3. paddle.compat.max(input: Tensor, other: Tensor): see `paddle.maximum`
+
+    Special warning: the gradient behavior is NOT well-documented by PyTorch, the actual behavior should be:
+    1. Case 1: the same as `max`
+    2. Case 2: NOT evenly distributing the gradient for equal maximum elements! PyTorch actually only propagates the gradient to the elements at the returned indices,
+       for example: Tensor([1, 1, 1]) -> max(..., dim=0) -> values=Tensor(1.), indices=Tensor(0); the gradient for the input tensor won't be
+       Tensor([1/3, 1/3, 1/3]) as stated in their documentation, but will be Tensor([1, 0, 0]). This API implements a similar backward kernel.
+    3. Case 3: the same as `maximum`
+
+    Args:
+        input (Tensor): A tensor, the data type is bfloat16, float16, float32, float64, int32, int64 on GPU.
+            uint8, int32, int64, float32, float64 are allowed on CPU.
+        dim (int, optional): The dim along which the maximum is computed.
+            If not specified (case 1), the maximum is computed over all elements of `input` and a Tensor with a single element is returned;
+            note that `None` cannot be passed explicitly (a TypeError will be thrown).
+            Otherwise `dim` must be in the range :math:`[-input.ndim, input.ndim)`.
+            If :math:`dim < 0`, the axis to reduce is :math:`input.ndim + dim`.
+            Warning: if `dim` is specified, executing a static graph will throw exceptions
+            when not on a GPU device, since max_with_index is not implemented for non-GPU devices
+        keepdim (bool, optional): Whether to reserve the reduced dimension in the
+            output Tensor. The result tensor will have one fewer dimension
+            than the `input` unless :attr:`keepdim` is true, default
+            value is False. Note that if `dim` appears in neither (*args) nor (**kwargs), this parameter cannot be passed alone
+        other (Tensor, optional): the other tensor to perform `paddle.maximum` with. This Tensor should
+            have the same or broadcast-able shape as the `input`. Note that (`dim` & `keepdim`) and `other` are mutually exclusive,
+            meaning that trying to combine both will result in a TypeError
+        out (Tensor|tuple[Tensor, Tensor], optional): the output Tensor or tuple of (Tensor, int64 Tensor) that can be optionally
+            given to be used as output buffers. For cases 1 and 3 out is just a Tensor, while for case 2 we expect a tuple
+
+
+    Returns:
+        - For case 1: a single value Tensor (0-dim)
+        - For case 2: a named tuple MinMaxRetType(values: Tensor, indices: Tensor), `values` has the same data type as the `input`,
+          while indices is always an int64 Tensor, with exactly the same shape as `values`.
+          MinMaxRetType can be used (indexed, packed, unpacked) in the same way as a regular tuple
+        - For case 3: see `paddle.maximum`
+
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+
+            >>> # data_x is a Tensor with shape [2, 4]
+            >>> # the axis is an int element
+            >>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
+            ...                       [0.1, 0.2, 0.6, 0.7]],
+            ...                      dtype='float64', stop_gradient=False)
+            >>> # Case 1: reduce over all dims
+            >>> result1 = paddle.compat.max(x)
+            >>> result1
+            Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
+            0.90000000)
+
+            >>> # Case 2: reduce over specified dim
+            >>> x.clear_grad()
+            >>> result2 = paddle.compat.max(x, dim=1)
+            >>> result2
+            MinMaxRetType(values=Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=False,
+            [0.90000000, 0.70000000]), indices=Tensor(shape=[2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
+            [3, 3]))
+            >>> result2[0].backward()
+            >>> x.grad
+            Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
+            [[0., 0., 0., 1.],
+             [0., 0., 0., 1.]])
+
+            >>> # Case 3: equivalent to `paddle.maximum`
+            >>> x.clear_grad()
+            >>> y = paddle.to_tensor([[0.5, 0.4, 0.1, 0.2],
+            ...                       [0.3, 0.1, 0.6, 0.7]],
+            ...
dtype='float64', stop_gradient=False) + >>> result3 = paddle.compat.max(x, y) + >>> result3 + Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False, + [[0.50000000, 0.40000000, 0.50000000, 0.90000000], + [0.30000000, 0.20000000, 0.60000000, 0.70000000]]) + """ + if not isinstance(input, (paddle.pir.Value, paddle.Tensor)): + raise TypeError( + f"input should be a tensor, but got an instance with type '{type(input).__name__}'" + ) + _min_max_tensor_allow_grad(input) + + dim_or_other, keepdim = _min_max_param_checker("max", *args, **kwargs) + + ret = None + if dim_or_other is None: + _check_out_status(out, False) + ret = paddle.max(input) + elif isinstance(dim_or_other, int): + _check_out_status(out, True) + if input.ndim: + if in_dynamic_mode() and not input.place.is_gpu_place(): + _min_max_allow_cpu_composite(input) + indices = paddle.argmax(input, axis=dim_or_other, keepdim=True) + values = paddle.take_along_axis( + input, indices, axis=dim_or_other + ) + if keepdim: + ret = MinMaxRetType(values=values, indices=indices) + else: + ret = MinMaxRetType( + values=values.squeeze_(axis=dim_or_other), + indices=indices.squeeze_(axis=dim_or_other), + ) + else: + vals, inds = _C_ops.max_with_index( + input, dim_or_other, keepdim, False + ) + inds.stop_gradient = True + ret = MinMaxRetType(values=vals, indices=inds) + else: + ret = MinMaxRetType( + values=input, + indices=paddle.zeros( + [], dtype=paddle.int64, device=input.place + ), + ) + else: + _check_out_status(out, False) + ret = _C_ops.maximum(input, dim_or_other) + + if out is not None: + if isinstance(ret, MinMaxRetType): + paddle.assign(ret.values, out[0]) + paddle.assign(ret.indices, out[1]) + else: + paddle.assign(ret, out) + return ret diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 1f84b1d6067e4f..62ff59ac412546 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -110,6 +110,8 @@ from paddle import Tensor from paddle._typing import DTypeLike +from paddle.utils.decorator_utils import ForbidKeywordsDecorator + __all__ = [] _supported_int_dtype_ = [ @@ -3272,6 +3274,11 @@ def _check_input(x): return out +@ForbidKeywordsDecorator( + illegal_keys={"input", "dim", "other"}, + func_name="paddle.max", + correct_name="paddle.compat.max", +) def max( x: Tensor, axis: int | Sequence[int] | None = None, @@ -3431,6 +3438,11 @@ def max( return out +@ForbidKeywordsDecorator( + illegal_keys={"input", "dim", "other"}, + func_name="paddle.min", + correct_name="paddle.compat.min", +) def min( x: Tensor, axis: int | Sequence[int] | None = None, diff --git a/test/ir/pir/cinn/symbolic/test_minmax_infer_sym.py b/test/ir/pir/cinn/symbolic/test_minmax_infer_sym.py new file mode 100644 index 00000000000000..81975c8029bb33 --- /dev/null +++ b/test/ir/pir/cinn/symbolic/test_minmax_infer_sym.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import unittest +from os.path import dirname + +import numpy as np +from test_infer_sym_shape_utils import ( + TestBase, + check_infer_results, +) + +import paddle +from paddle.static import InputSpec + +sys.path.append(dirname(dirname(__file__))) +from utils import apply_to_static + +# NOTE(SigureMo): Disable the CSE optimization to avoid op number change. +paddle.set_flags({"FLAGS_enable_cse_in_dy2st": False}) + + +class MaxMinWithIndexNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + min_vals, min_inds = paddle.compat.min(x, dim=-1, keepdim=False) + max_vals, max_inds = paddle.compat.max(x, dim=-1, keepdim=True) + return min_vals + max_vals.squeeze(axis=-1), min_inds + max_inds + + +class MinMaxWithIndexOpInferSymbolicShapeTest(TestBase): + def prepare_data(self): + self.cases = [np.random.rand(3, 4, 5, 6), np.random.rand(257)] + self.expected = [ + [ + 'shape[S0, S1, S2], data[NULL]', + 'shape[S0, Broadcast(S0, S1), Broadcast(S1, S2), S2], data[NULL]', + ], + ['shape[], data[NULL]', 'shape[1], data[NULL]'], + ] + + def test_eval_symbolic(self): + net = MaxMinWithIndexNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + check_infer_results( + net, input_spec, 'builtin.shadow_output', self.expected[i] + ) + + return True + + +class MinMaxWithIndexRawNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x * 2 + 1 + min_vals, min_inds = paddle._C_ops.min_with_index(x, 1, False, True) + max_vals, max_inds = paddle._C_ops.max_with_index(x, 2, True, True) + return min_vals + max_vals.squeeze(), min_inds * max_inds + + +class MinMaxWithIndexOpRawInferShapeTest(TestBase): + def prepare_data(self): + self.cases = [np.random.rand(4, 5, 6), np.random.rand(3, 7, 1, 2)] + self.expected = [ + [ + 'shape[], data[NULL]', + 'shape[1, 1, 1], data[NULL]', + ], + ['shape[], data[NULL]', 'shape[1, 1, 1, 1], data[NULL]'], + ] + + @unittest.skipIf( + not paddle.core.is_compiled_with_cuda(), + "core is not compiled with CUDA, skipping", + ) + def test_eval_symbolic(self): + net = MinMaxWithIndexRawNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + check_infer_results( + net, input_spec, 'builtin.shadow_output', self.expected[i] + ) + + return True + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_compat_minmax.py b/test/legacy_test/test_compat_minmax.py new file mode 100644 index 00000000000000..0354e72a3759b9 --- /dev/null +++ b/test/legacy_test/test_compat_minmax.py @@ -0,0 +1,564 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import paddle +from paddle.base import core + + +class TestCompatMinMaxBase(unittest.TestCase): + """The default base class is for testing min-related ops""" + + def __init__( + self, + *args, + test_op=paddle.compat.min, + origin_op=paddle.min, + index_op=paddle.argmin, + test_op_name="paddle.compat.min", + origin_op_name="paddle.min", + **kwargs, + ): + super().__init__(*args, **kwargs) + paddle.disable_static() + self.test_op = test_op + self.origin_op = origin_op + self.index_op = index_op + self.test_op_name = test_op_name + self.origin_op_name = origin_op_name + np.random.seed(1) + + def test_case1_simple_reduce_all(self): + data = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]], dtype='float32') + val = self.test_op(data) + if self.test_op_name.endswith("min"): + self.assertAlmostEqual(val.item(), 1.0) + else: + self.assertAlmostEqual(val.item(), 4.0) + + def test_case2_reduce_dim(self): + """Test dim/keepdim""" + data = paddle.to_tensor( + [[[5, 8], [2, 1]], [[7, 3], [9, 6]]], dtype='float32' + ) + if self.test_op_name.endswith("min"): + in_dim = 1 + result = self.test_op(data, dim=in_dim) + expected_res = np.array([[[5, 3], [2, 1]]]) + self.assertEqual(result.values.shape, [2, 2]) + np.testing.assert_array_equal( + result.values.numpy(), np.array([[2, 1], [7, 3]]) + ) + np.testing.assert_array_equal( + result.indices.numpy(), np.array([[1, 1], [0, 0]]) + ) + else: + in_dim = 2 + result = self.test_op(data, dim=in_dim) + expected_res = np.array([[[7, 8], [9, 6]]]) + self.assertEqual(result.values.shape, [2, 2]) + np.testing.assert_array_equal( + result.values.numpy(), np.array([[8, 2], [7, 9]]) + ) + np.testing.assert_array_equal( + result.indices.numpy(), np.array([[1, 0], [0, 0]]) + ) + + result_keep = self.test_op(data, dim=0, keepdim=True) + self.assertEqual(result_keep.values.shape, [1, 2, 2]) + np.testing.assert_array_equal(result_keep.values.numpy(), expected_res) + result_keep = self.test_op(data, 0, keepdim=True) + np.testing.assert_array_equal(result_keep.values.numpy(), expected_res) + + result_neg = self.test_op(data, dim=in_dim - 3) + np.testing.assert_array_equal( + result_neg.values.numpy(), result.values.numpy() + ) + + def test_case2_grad(self): + data = paddle.to_tensor( + [[[1.0, 2.0], [1.0, 3.0]], [[4.0, 1.0], [5.0, 1.0]]], + dtype='float32', + stop_gradient=False, + ) + y = data * 2 + + result = self.test_op(y, dim=2) + result.values.backward() + + if self.test_op_name.endswith("min"): + expected_grad = np.array( + [[[2.0, 0.0], [2.0, 0.0]], [[0.0, 2.0], [0.0, 2.0]]] + ) + expected_grad2 = np.array( + [[[2.0, 4.0], [0.0, 0.0]], [[8.0, 2.0], [0.0, 0.0]]] + ) + else: + expected_grad = np.array( + [[[0.0, 2.0], [0.0, 2.0]], [[2.0, 0.0], [2.0, 0.0]]] + ) + expected_grad2 = np.array( + [[[2.0, 0.0], [0.0, 6.0]], [[0.0, 2.0], [10.0, 0.0]]] + ) + np.testing.assert_allclose(data.grad.numpy(), expected_grad, atol=1e-6) + + data.clear_grad() + y = data * data + result = self.test_op(y, dim=1) + result[0].backward() + np.testing.assert_allclose(data.grad.numpy(), expected_grad2, atol=1e-6) + + def test_case3_elementwise(self): + x = paddle.to_tensor([[1, 5], [4, 2]], dtype='float32') + y = paddle.to_tensor([[3, 2], [1, 6]], dtype='float32') + z = paddle.to_tensor([3, 4], dtype='float32') + broadcast_res = self.test_op(x, z) + + result = self.test_op(x, y) + if self.test_op_name.endswith("min"): + np.testing.assert_array_equal( + result.numpy(), np.array([[1, 2], [1, 2]]) + ) + np.testing.assert_array_equal( + broadcast_res.numpy(), 
np.array([[1, 4], [3, 2]]) + ) + else: + np.testing.assert_array_equal( + result.numpy(), np.array([[3, 5], [4, 6]]) + ) + np.testing.assert_array_equal( + broadcast_res.numpy(), np.array([[3, 5], [4, 4]]) + ) + + def test_case3_grad(self): + x = paddle.to_tensor( + [[1.0, 2.0], [3.0, 4.0]], dtype=paddle.float32, stop_gradient=False + ) + y = paddle.to_tensor( + [[0.5, 2.5], [2.0, 3.5]], dtype=paddle.float32, stop_gradient=False + ) + + val = self.test_op(x, y) + val.backward() + + expected_x_grad = np.array([[0.0, 1.0], [0.0, 0.0]]) + expected_y_grad = np.array([[1.0, 0.0], [1.0, 1.0]]) + if self.test_op_name.endswith("max"): + expected_x_grad = 1 - expected_x_grad + expected_y_grad = 1 - expected_y_grad + + np.testing.assert_allclose(x.grad.numpy(), expected_x_grad) + np.testing.assert_allclose(y.grad.numpy(), expected_y_grad) + + def test_edge_cases(self): + """Edge cases test""" + # uniform distributed gradient + uniform_data = paddle.ones([2, 3], dtype='float64') + uniform_data.stop_gradient = False + val = self.test_op(uniform_data) + val.sum().backward() + # uniformly distributed + expected_grad = np.full((2, 3), 1.0 / 6.0) + np.testing.assert_allclose(uniform_data.grad.numpy(), expected_grad) + + uniform_data.clear_grad() + val = self.test_op(uniform_data, 0) + val.values.sum().backward() + # take_along_axis like gradient behavior + expected_grad = np.array([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]) + np.testing.assert_allclose(uniform_data.grad.numpy(), expected_grad) + + # 0-dim tensor + dim0_tensor = paddle.to_tensor(2, dtype='float32') + val = self.test_op(dim0_tensor) + np.testing.assert_allclose(val.numpy(), np.array(2.0, dtype=np.float32)) + + # 1-dim tensor + dim1_tensor = paddle.to_tensor([1], dtype='uint8') + val = self.test_op(dim1_tensor, dim=-1, keepdim=True) + np.testing.assert_array_equal( + val[0].numpy(), np.array([1], dtype=np.uint8) + ) + np.testing.assert_array_equal( + val[1].numpy(), np.array([0], dtype=np.int64) + ) + + def test_compare_with_index_ops_to_origin(self): + dtypes = ['float32', 'float64', 'int32', 'int64', 'uint8'] + + for i, dtype in enumerate(dtypes): + data = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype) + # `bfloat16`, `uint8` and `float16` are rejected for min/argmin + vals_inds = self.test_op(data, dim=0) + self.assertEqual(vals_inds.values.dtype, data.dtype) + self.assertEqual(vals_inds.indices.dtype, paddle.int64) + + origin_indices = self.index_op(data, axis=0, dtype="int64") + if dtype != 'uint8': + origin_values = self.origin_op(data, axis=0) + else: + origin_values = paddle.take_along_axis( + data, origin_indices.unsqueeze(0), axis=0 + ) + origin_values.squeeze_(axis=0) + if i < 4: # floating point + np.testing.assert_allclose( + vals_inds.values.numpy(), origin_values.numpy() + ) + else: + np.testing.assert_array_equal( + vals_inds.values.numpy(), origin_values.numpy() + ) + np.testing.assert_array_equal( + vals_inds[1].numpy(), origin_indices.numpy() + ) + + def test_case1_out(self): + data = np.random.randn(4, 5, 6).astype(np.float32) + x = paddle.to_tensor(data, stop_gradient=False) + y = paddle.to_tensor(data, stop_gradient=False) + out = paddle.to_tensor(0) + self.test_op(x, out=out) + gt_out = self.origin_op(y) + gt_out.backward() + out.backward() + + np.testing.assert_allclose(out.numpy(), gt_out.numpy()) + np.testing.assert_allclose(x.grad.numpy(), y.grad.numpy()) + + def test_case2_out(self): + for type_to_use in (list, tuple): + data = np.random.randn(3, 17, 5).astype(np.float32) + x = paddle.to_tensor(data, 
stop_gradient=False) + y = paddle.to_tensor(data, stop_gradient=False) + out = type_to_use((paddle.to_tensor(0), paddle.to_tensor(0))) + self.test_op(x, dim=1, out=out) + gt_vals = self.origin_op(y, axis=1) + gt_inds = self.index_op(y, axis=1) + gt_vals.backward() + out[0].backward() + + np.testing.assert_allclose(out[0].numpy(), gt_vals.numpy()) + np.testing.assert_array_equal(out[1].numpy(), gt_inds.numpy()) + np.testing.assert_allclose(x.grad.numpy(), y.grad.numpy()) + + def test_case3_out(self): + data = np.random.randn(3, 4, 5).astype(np.float32) + x = paddle.to_tensor(data) + y = paddle.to_tensor(data) + out = paddle.to_tensor(0) + self.test_op(x, paddle.ones_like(x), out=out) + if self.test_op_name.endswith("min"): + gt_vals = paddle.minimum(x, paddle.ones_like(x)) + else: + gt_vals = paddle.maximum(x, paddle.ones_like(x)) + np.testing.assert_allclose(out.numpy(), gt_vals.numpy()) + + def test_error_handling(self): + """Test whether correct exception will be thrown. Skip error messages (some of them are long)""" + + err_msg1 = ( + "Tensors with integral type: 'paddle.int32' should stop gradient." + ) + err_msg2 = ( + f"{self.origin_op_name}() received unexpected keyword arguments 'dim', 'input'. " + f"\nDid you mean to use {self.test_op_name}() instead?" + ) + err_msg3 = ( + f"{self.test_op_name}() received unexpected keyword argument 'axis'. " + f"\nDid you mean to use {self.origin_op_name}() instead?" + ) + err_msg4 = ( + "Non-CUDA GPU placed Tensor does not have 'paddle.float16' op registered.\n" + "Paddle support following DataTypes: int32, int64, float64, float32, uint8" + ) + err_msg5 = ( + "input should be a tensor, but got an instance with type 'list'" + ) + + # empty tensor + empty_tensor = paddle.to_tensor([], dtype='float32') + with self.assertRaises(ValueError): + self.test_op(empty_tensor) + + # mixed parameters case 1 + input_ts = paddle.to_tensor([1, 2, 3], dtype='float32') + other_ts = paddle.to_tensor([1]) + with self.assertRaises(TypeError): + self.test_op(input_ts, other=other_ts, dim=0) + + # mixed parameters case 2 + with self.assertRaises(TypeError): + self.test_op(input_ts, 0, other=other_ts) + + # trying to perform grad ops for integral types + with self.assertRaises(TypeError) as cm: + tensor = paddle.ones([2, 2], dtype=paddle.int32) + tensor.stop_gradient = False + tensors = self.test_op(tensor, dim=0) + self.assertEqual(str(cm.exception), err_msg1) + + # explicit None case 1 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, dim=None) + + # explicit None case 2 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, None, keepdim=True) + + # keepdim specified without specifying dim + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, keepdim=True) + + # Wrong *args specification case 1 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, False) + + # Wrong *args specification case 2 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, other_ts, True) + + # Tensor input for dim case 1 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, dim=paddle.to_tensor([0])) + + # Tensor input for dim case 2 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, dim=paddle.to_tensor(0)) + + # Tensor input for dim case 3 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, paddle.to_tensor([0]), keepdim=True) + + # Tensor input for dim case 4 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, paddle.to_tensor([0]), True) + + # Duplicate 
Arguments case 1 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, 0, dim=0) + + # Duplicate Arguments case 2 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, other_ts, other=0) + + # Duplicate Arguments case 3 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, dim=0, other=0, keepdim=True) + + # Wrong API used case 1 + with self.assertRaises(TypeError) as cm: + self.origin_op(input=input_ts, dim=0) + self.assertEqual(str(cm.exception), err_msg2) + + # Wrong API used case 2 + with self.assertRaises(TypeError) as cm: + self.test_op(input_ts, axis=0) + self.assertEqual(str(cm.exception), err_msg3) + + # Rejected on CPU types + with self.assertRaises(TypeError) as cm: + tensor = paddle.to_tensor([1, 2, 3], dtype="float16") + cpu_tensor = tensor.to("cpu") + self.test_op(cpu_tensor, dim=0) + self.assertEqual(str(cm.exception), err_msg4) + + # Wrong input type + with self.assertRaises(TypeError) as cm: + self.test_op([1, 2]) + self.assertEqual(str(cm.exception), err_msg5) + + # Wrong second parameter type + with self.assertRaises(TypeError): + self.test_op(input_ts, "first_dim") + + paddle.enable_static() + with ( + self.assertRaises(RuntimeError) as cm, + paddle.static.program_guard(paddle.static.Program()), + ): + x = paddle.static.data(name='x', shape=[None, 6], dtype='float32') + result0, result1 = self.test_op( + paddle.zeros([3, 4]), + dim=1, + out=( + paddle.zeros([3, 4]), + paddle.zeros([3, 4], dtype=paddle.int64), + ), + ) + + place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + paddle.static.Executor(place).run() + self.assertEqual( + str(cm.exception), + "Using `out` static graph CINN backend is currently not supported. Directly return the tensor tuple instead.\n", + ) + paddle.disable_static() + + def test_wrong_out_input(dim, out_input): + with self.assertRaises(TypeError) as cm: + if dim is None: + self.test_op(input_ts, out=out_input) + else: + self.test_op(input_ts, dim=dim, out=out_input) + + test_wrong_out_input(0, [0, paddle.to_tensor(0)]) + test_wrong_out_input(0, paddle.to_tensor(0)) + test_wrong_out_input(None, 0) + test_wrong_out_input(None, (paddle.to_tensor(0),)) + + def _compare_with_origin_static( + self, input_shape, axis_or_other=0, keepdim=False, use_out=False + ): + """Test Case 2 and Case 3 for return output or param output in static graph mode + + TODO(heqianyue): DO NOT set use_out for now! + Currently, static graph + CINN backend will result in unresolved dependency bug for assign op + This test is disabled for now, but will be useful when dy2st bug is fixed. 
+ """ + numel = 1 + for v in input_shape: + numel *= v + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + input_tensor = paddle.arange(numel, dtype=paddle.float32).reshape( + input_shape + ) + + y = input_tensor**2 + if isinstance(axis_or_other, int): + if use_out: + out = [paddle.to_tensor(0), paddle.to_tensor([0])] + self.test_op(y, dim=axis_or_other, keepdim=keepdim, out=out) + values, indices = out + else: + values, indices = self.test_op( + y, dim=axis_or_other, keepdim=keepdim + ) + gt_values = self.origin_op( + y, axis=axis_or_other, keepdim=keepdim + ) + gt_indices = self.index_op( + y, axis=axis_or_other, keepdim=keepdim + ) + else: + if use_out: + out = paddle.to_tensor(0) + self.test_op(y, axis_or_other, out=out) + values, indices = out, paddle.to_tensor(0) + else: + values, indices = self.test_op(y, axis_or_other) + if self.test_op_name.endswith("min"): + gt_values = paddle.minimum(y, axis=axis_or_other, out=None) + else: + gt_values = paddle.maximum(y, axis=axis_or_other) + gt_indices = paddle.to_tensor(0) + + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + values_np, indices_np, gt_values_np, gt_indices_np = exe.run( + fetch_list=[values, indices, gt_values, gt_indices] + ) + np.testing.assert_allclose(values_np, gt_values_np) + np.testing.assert_equal(indices_np, gt_indices_np) + paddle.disable_static() + + @unittest.skipIf( + not core.is_compiled_with_cuda(), + "core is not compiled with CUDA, skipping", + ) + def test_static_graph(self): + self._compare_with_origin_static([3, 10, 2], 1) + self._compare_with_origin_static([3, 10, 2], 0, keepdim=True) + self._compare_with_origin_static([17], 0) + + @unittest.skipIf( + not core.is_compiled_with_cuda(), + "core is not compiled with CUDA, skipping", + ) + def test_static_unary_shape_infer_1(self): + # min/max with index is a GPU only op, no need for testing if there is no GPU + + @paddle.jit.to_static(full_graph=True) + def static_func1(x): + y = paddle.zeros([2, 3, 4]) + return paddle._C_ops.min_with_index(y, x.shape[0], False, False) + + @paddle.jit.to_static(full_graph=True) + def static_func2(x): + y = paddle.zeros([2, 3, 4]) + return paddle._C_ops.min_with_index(y, x.shape[0], True, False) + + input_ts1 = paddle.to_tensor([1]) + input_ts2 = paddle.to_tensor([1, 2]) + val1, ind1 = static_func1(input_ts1) + val2, ind2 = static_func2(input_ts2) + + self.assertEqual(val1.shape, [2, 4]) + self.assertEqual(ind1.shape, [2, 4]) + self.assertEqual(val2.shape, [2, 3, 1]) + self.assertEqual(ind2.shape, [2, 3, 1]) + + @unittest.skipIf( + not core.is_compiled_with_cuda(), + "core is not compiled with CUDA, skipping", + ) + def test_static_unary_shape_infer_2(self): + # min/max with index is a GPU only op, no need for testing if there is no GPU + + @paddle.jit.to_static(full_graph=True) + def static_func1(x): + dim = paddle.arange(0, 1).shape[0] + y = paddle.zeros([2, 3, 4]) + return paddle._C_ops.max_with_index(y, dim, False, True) + + @paddle.jit.to_static(full_graph=True) + def static_func2(x): + dim = paddle.arange(0, 2).shape[0] + y = paddle.zeros([2, 3, 4]) + return paddle._C_ops.max_with_index(y, dim, True, True) + + x1 = paddle.to_tensor([1]) + x2 = paddle.to_tensor([1, 2]) + val1, ind1 = static_func1(x1) + val2, ind2 = static_func2(x2) + + self.assertEqual(val1.shape, []) + self.assertEqual(ind1.shape, []) + self.assertEqual(val2.shape, [1, 1, 1]) + self.assertEqual(ind2.shape, [1, 1, 1]) + + +class TestCompatMax(TestCompatMinMaxBase): + def __init__(self, *args, 
+
+
+class TestCompatMax(TestCompatMinMaxBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(
+            *args,
+            test_op=paddle.compat.max,
+            origin_op=paddle.max,
+            index_op=paddle.argmax,
+            test_op_name="paddle.compat.max",
+            origin_op_name="paddle.max",
+            **kwargs,
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/legacy_test/test_minmax_with_index_op.py b/test/legacy_test/test_minmax_with_index_op.py
new file mode 100644
index 00000000000000..d80d89ae3e3c09
--- /dev/null
+++ b/test/legacy_test/test_minmax_with_index_op.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from op_test import OpTest
+
+import paddle
+from paddle.base import core
+
+np.random.seed(0)
+paddle.enable_static()
+
+
+def max_with_index(x, dim=None, keepdim=False):
+    """Makeshift wrapper for the C++ op, extracted from compat.max."""
+    vals, inds = paddle._C_ops.max_with_index(x, dim, keepdim, False)
+    inds.stop_gradient = True
+    return vals, inds
+
+
+def min_with_index(x, dim=None, keepdim=False):
+    """Makeshift wrapper for the C++ op, extracted from compat.min."""
+    vals, inds = paddle._C_ops.min_with_index(x, dim, keepdim, False)
+    inds.stop_gradient = True
+    return vals, inds
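+
+
+# A minimal usage sketch of the wrappers (hypothetical values, CUDA build):
+#   x = paddle.rand([2, 3])
+#   vals, inds = max_with_index(x, dim=1)                # shapes [2], [2]
+#   vals, inds = max_with_index(x, dim=1, keepdim=True)  # shapes [2, 1], [2, 1]
+# The trailing False passed to the C++ op is the `flatten` flag; `inds` is
+# int64 and detached from the autograd graph.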
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA, skipping",
+)
+class TestMaxWithIndexBasic(OpTest):
+    def setUp(self):
+        self.set_op_input_attr()
+        self.set_testing_op()
+        self.set_data_type()
+        self.set_input_shape()
+        if self.is_int:
+            inputs = np.random.randint(0, 255, self.input_shape).astype(
+                self.dtype
+            )
+        else:
+            inputs = np.random.rand(*self.input_shape).astype(self.dtype)
+
+        self.prim_op_type = "prim"
+        self.python_out_sig = ["values", "indices"]
+        self.attrs = {"dim": self.dim, "keepdim": self.keepdim}
+
+        gt_values = self.value_op(inputs, axis=self.dim, keepdims=self.keepdim)
+        gt_indices = self.index_op(inputs, axis=self.dim, keepdims=self.keepdim)
+        self.inputs = {
+            'x': inputs,
+        }
+        self.outputs = {
+            'values': gt_values,
+            'indices': gt_indices,
+        }
+
+    def compute_grad(self):
+        grad = np.zeros_like(self.inputs['x'], dtype=self.dtype)
+        indices = (
+            self.outputs['indices']
+            if self.keepdim
+            else np.expand_dims(self.outputs['indices'], axis=self.dim)
+        )
+        np.put_along_axis(grad, indices, 1, axis=self.dim)
+        return grad
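+
+    # compute_grad builds the expected analytic gradient: a 1 at every
+    # selected (arg-min/arg-max) position, 0 elsewhere. check_grad below
+    # appears to compare against the gradient of the mean of the outputs,
+    # so the indicator is scaled by 1 / grad.sum(), i.e. one over the number
+    # of output elements.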
+
+    def set_testing_op(self):
+        self.op_type = "max_with_index"
+        self.python_api = max_with_index
+        self.public_python_api = max_with_index
+        self.value_op = np.max
+        self.index_op = np.argmax
+
+    def set_data_type(self):
+        self.dtype = np.float64
+        self.is_int = False
+
+    def set_input_shape(self):
+        self.input_shape = [30, 257, 21]
+
+    def set_op_input_attr(self):
+        self.dim = 0
+        self.keepdim = False
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad(self):
+        grad = self.compute_grad()
+        self.check_grad(
+            ['x'],
+            'values',
+            check_pir=True,
+            user_defined_grads=[grad * (1.0 / grad.sum())],
+        )
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA, skipping",
+)
+class TestMinWithIndexBasic(TestMaxWithIndexBasic):
+    def set_testing_op(self):
+        self.op_type = "min_with_index"
+        self.python_api = min_with_index
+        self.public_python_api = min_with_index
+        self.value_op = np.min
+        self.index_op = np.argmin
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA, skipping",
+)
+class TestMinWithIndexKeepDim(TestMinWithIndexBasic):
+    def set_op_input_attr(self):
+        self.dim = 1
+        self.keepdim = True
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA, skipping",
+)
+class TestMaxWithIndexKeepDim(TestMaxWithIndexBasic):
+    def set_op_input_attr(self):
+        self.dim = 1
+        self.keepdim = True
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA, skipping",
+)
+class TestMinWithIndexNegDim(TestMinWithIndexBasic):
+    def set_op_input_attr(self):
+        self.dim = -1
+        self.keepdim = False
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA, skipping",
+)
+class TestMaxWithIndexNegDim(TestMaxWithIndexBasic):
+    def set_op_input_attr(self):
+        # use an actual negative dim, as the class name suggests
+        self.dim = -1
+        self.keepdim = False
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA, skipping",
+)
+class TestMinWithIndexMoreTypeAndShape(TestMinWithIndexBasic):
+    def set_op_input_attr(self):
+        self.dim = 1
+        self.keepdim = True
+
+    def set_data_type(self):
+        self.dtype = np.float32
+        self.is_int = False
+
+    def set_input_shape(self):
+        self.input_shape = [10, 20, 16]
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA, skipping",
+)
+class TestMinWithIndexFP16(TestMinWithIndexBasic):
+    def set_data_type(self):
+        self.dtype = np.float16
+        self.is_int = False
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA, skipping",
+)
+class TestMaxWithIndexU8(TestMaxWithIndexBasic):
+    def set_data_type(self):
+        self.dtype = np.uint8
+        self.is_int = True
+
+    @unittest.skipIf(
+        True,
+        "integral types do not need gradient checks",
+    )
+    def test_check_grad(self):
+        pass
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA, skipping",
+)
+class TestMaxWithIndexMoreTypeAndShape(TestMaxWithIndexBasic):
+    def set_op_input_attr(self):
+        self.dim = -1
+        self.keepdim = False
+
+    def set_data_type(self):
+        self.dtype = np.uint8
+        self.is_int = True
+
+    def set_input_shape(self):
+        self.input_shape = [4095]
+
+    @unittest.skipIf(
+        True,
+        "integral types do not need gradient checks",
+    )
+    def test_check_grad(self):
+        pass
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/legacy_test/test_zero_dim_sundry_dygraph_api.py b/test/legacy_test/test_zero_dim_sundry_dygraph_api.py
index bc958ca42bf242..29d3c5961d6241 100644
--- a/test/legacy_test/test_zero_dim_sundry_dygraph_api.py
+++ b/test/legacy_test/test_zero_dim_sundry_dygraph_api.py
@@ -551,6 +551,98 @@ def test_argmax(self):
         out = paddle.argmax(x, keepdim=True)
         self.assertEqual(out.shape, [1, 1])
 
+    def _make_compat_minmax_test(self, func_name):
+        # 1) x is 0D
+        x = paddle.rand([])
+        val1, ind1 = func_name(x, 0)
+        val2, ind2 = func_name(x, -1)
+        val3 = func_name(x)
+
+        self.assertEqual(val1.shape, [])
+        self.assertEqual(ind1.shape, [])
+        np.testing.assert_allclose(val1, x)
+        np.testing.assert_allclose(ind1, 0)
+
+        self.assertEqual(val2.shape, [])
+        self.assertEqual(ind2.shape, [])
+        np.testing.assert_allclose(val2, x)
+        np.testing.assert_allclose(ind2, 0)
+
+        self.assertEqual(val3.shape, [])
+        np.testing.assert_allclose(val3, x)
+
+        # 2) x is 1D
+        x = paddle.rand([5])
+        val, ind = func_name(x, 0)
+        self.assertEqual(val.shape, [])
+        self.assertEqual(ind.shape, [])
+
+        # 3) x is ND
+        x = paddle.rand([3, 5])
+        val, ind = func_name(x, dim=1)
+        self.assertEqual(val.shape, [3])
+        self.assertEqual(ind.shape, [3])
+
+        val = func_name(x)
+        self.assertEqual(val.shape, [])
+
+        # 4) x is ND, keepdim=True
+        x = paddle.rand([3, 5])
+        val, ind = func_name(x, dim=0, keepdim=True)
+        self.assertEqual(val.shape, [1, 5])
+        self.assertEqual(ind.shape, [1, 5])
+
+        # 5) test backward
+        x = paddle.randn([4, 5])
+        x.stop_gradient = False
+
+        val, ind = func_name(x, dim=0)
+        val.backward()
+        self.assertEqual(x.grad.shape, [4, 5])
+
+    def test_minmax_with_index(self):
+        # min/max_with_index is a GPU-only op
+        if not paddle.is_compiled_with_cuda():
+            return
+        # 1) x is 0D
+        x = paddle.to_tensor(1)
+        val1, ind1 = paddle._C_ops.min_with_index(x, 0, False, True)
+
+        self.assertEqual(val1.shape, [])
+        self.assertEqual(ind1.shape, [])
+        np.testing.assert_allclose(val1, 1)
+        np.testing.assert_allclose(ind1, 0)
+
+        # 2) x is 1D
+        x = paddle.to_tensor([1, 1, 1])
+        val1, ind1 = paddle._C_ops.max_with_index(x, 0, False, True)
+
+        self.assertEqual(val1.shape, [])
+        self.assertEqual(ind1.shape, [])
+        np.testing.assert_allclose(val1, 1)
+        np.testing.assert_allclose(ind1, 0)
+
+        # 3) x is 2D
+        x = paddle.zeros([2, 3])
+        val1, ind1 = paddle._C_ops.min_with_index(x, 1, False, True)
+        val2, ind2 = paddle._C_ops.max_with_index(x, 1, True, True)
+
+        self.assertEqual(val1.shape, [])
+        self.assertEqual(ind1.shape, [])
+        np.testing.assert_allclose(val1, 0)
+        np.testing.assert_allclose(ind1, 0)
+
+        self.assertEqual(val2.shape, [1, 1])
+        self.assertEqual(ind2.shape, [1, 1])
+        np.testing.assert_allclose(val2, 0)
+        np.testing.assert_allclose(ind2, 0)
+
+    def test_compat_min(self):
+        self._make_compat_minmax_test(paddle.compat.min)
+
+    def test_compat_max(self):
+        self._make_compat_minmax_test(paddle.compat.max)
+
     def test_kthvalue(self):
         # 1) x is 0D
         x = paddle.randn([])