Skip to content

Commit 3fb103f

Browse files
committed
[API-Compat] Add compat.min/max EN doc
Attempting to fix integral type gradient computation (rejection)
1 parent e8c78b7 commit 3fb103f

File tree

2 files changed

+226
-4
lines changed

2 files changed

+226
-4
lines changed

paddle/phi/kernels/gpu/reduce_kernel.cu

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/phi/kernels/reduce_kernel.h"
16+
#include <type_traits>
1617

1718
#include "paddle/phi/kernels/gpu/reduce_amin_amax_common.h"
1819
#include "paddle/phi/kernels/reduce_amin_grad_kernel.h"
@@ -159,7 +160,15 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
159160
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
160161
}
161162

162-
template <typename T, typename Context>
163+
template <typename T>
164+
using EnableIfInteger =
165+
typename std::enable_if<std::is_integral<T>::value, int>::type;
166+
167+
template <typename T>
168+
using EnableIfNonInteger =
169+
typename std::enable_if<!std::is_integral<T>::value, int>::type;
170+
171+
template <typename T, typename Context, EnableIfNonInteger<T> = 0>
163172
void MinWithIndexGradKernel(const Context& dev_ctx,
164173
const DenseTensor& x,
165174
const DenseTensor& values,
@@ -174,7 +183,25 @@ void MinWithIndexGradKernel(const Context& dev_ctx,
174183
dev_ctx, x, values, values_grad, {dim_val}, keepdims, flatten, x_grad);
175184
}
176185

177-
template <typename T, typename Context>
186+
template <typename T, typename Context, EnableIfInteger<T> = 0>
187+
void MinWithIndexGradKernel(const Context& dev_ctx,
188+
const DenseTensor& x,
189+
const DenseTensor& values,
190+
const DenseTensor& values_grad,
191+
const Scalar& dim,
192+
bool keepdims,
193+
bool flatten,
194+
DenseTensor* x_grad) {
195+
std::string dtype_name = phi::DataTypeToString(x.dtype());
196+
PADDLE_ENFORCE_EQ(
197+
0,
198+
1,
199+
phi::errors::InvalidArgument(
200+
"Integer type '%s' is not allowed to have stop_gradient=False.",
201+
dtype_name.c_str()));
202+
}
203+
204+
template <typename T, typename Context, EnableIfNonInteger<T> = 0>
178205
void MaxWithIndexGradKernel(const Context& dev_ctx,
179206
const DenseTensor& x,
180207
const DenseTensor& values,
@@ -189,6 +216,24 @@ void MaxWithIndexGradKernel(const Context& dev_ctx,
189216
dev_ctx, x, values, values_grad, {dim_val}, keepdims, flatten, x_grad);
190217
}
191218

219+
template <typename T, typename Context, EnableIfInteger<T> = 0>
220+
void MaxWithIndexGradKernel(const Context& dev_ctx,
221+
const DenseTensor& x,
222+
const DenseTensor& values,
223+
const DenseTensor& values_grad,
224+
const Scalar& dim,
225+
bool keepdims,
226+
bool flatten,
227+
DenseTensor* x_grad) {
228+
std::string dtype_name = phi::DataTypeToString(x.dtype());
229+
PADDLE_ENFORCE_EQ(
230+
0,
231+
1,
232+
phi::errors::InvalidArgument(
233+
"Integer type '%s' is not allowed to have stop_gradient=False.",
234+
dtype_name.c_str()));
235+
}
236+
192237
template <typename T, typename Context>
193238
void ReduceMaxGradKernel(const Context& dev_ctx,
194239
const DenseTensor& x,
@@ -320,7 +365,9 @@ PD_REGISTER_KERNEL(max_with_index_grad,
320365
phi::MaxWithIndexGradKernel,
321366
float,
322367
double,
368+
uint8_t,
323369
int,
370+
int16_t,
324371
int64_t,
325372
phi::dtype::float16,
326373
phi::dtype::bfloat16) {}
@@ -357,7 +404,9 @@ PD_REGISTER_KERNEL(min_with_index_grad,
357404
phi::MinWithIndexGradKernel,
358405
float,
359406
double,
407+
uint8_t,
360408
int,
409+
int16_t,
361410
int64_t,
362411
phi::dtype::float16,
363412
phi::dtype::bfloat16) {}

python/paddle/tensor/compat.py

Lines changed: 175 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,14 +293,108 @@ def try_get_keys(key):
293293
return dim_or_other, keepdim
294294

295295

296-
@forbid_keywords(['x', 'axis'], 'paddle.min')
296+
def _min_max_tensor_allow_grad(input: Tensor):
297+
"""Prevent integral input tensor type to have `stop_gradient=False`"""
298+
in_dtype = input.dtype
299+
if (
300+
in_dtype == paddle.int32
301+
or in_dtype == paddle.int64
302+
or in_dtype == paddle.uint8
303+
or in_dtype == paddle.int16
304+
):
305+
if not input.stop_gradient:
306+
raise TypeError(
307+
f"Tensors with integral type: '{in_dtype}' should stop gradient."
308+
)
309+
310+
311+
@ForbidKeywordsDecorator(
312+
illegal_keys=['x', 'axis'],
313+
func_name="paddle.compat.min",
314+
correct_name='paddle.min',
315+
)
297316
def min(input: Tensor, *args: Any, **kwargs: Any) -> Tensor | MinMaxRetType:
317+
"""
318+
319+
Computes the minimum of tensor elements. There are mainly 3 cases (functionalities):
320+
1. paddle.compat.min(input: Tensor): reduce min over all dims, return a single value Tensor
321+
2. paddle.compat.min(input: Tensor, dim: int (cannot be None), keepdim=False): reduce min over the given dim,
322+
returns a named tuple MinMaxRetType(values: Tensor, indices: Tensor)
323+
3. paddle.compat.min(input: Tensor, other: Tensor): see `paddle.minimum`
324+
325+
Note: If there are multiple minimum elements, this API evenly distributes gradient between these equal values,
326+
following torch.min. The gradient behavior of `values` for case 2 is the same as `paddle.amin`.
327+
328+
Args:
329+
input (Tensor): A tensor, the data type is bfloat16, float16, float32, float64, uint8, int16, int32, int64.
330+
dim (int, optional): The dim along which the minimum is computed.
331+
If this is not specified: see case 1, note that: `None` cannot be passed to this (TypeError will be thrown)
332+
compute the minimum over all elements of `input` and return a Tensor with a single element,
333+
otherwise must be in the range :math:`[-input.ndim, input.ndim)`.
334+
If :math:`dim < 0`, the axis to reduce is :math:`input.ndim + dim`.
335+
keepdim (bool, optional): Whether to reserve the reduced dimension in the
336+
output Tensor. The result tensor will have one fewer dimension
337+
than the `input` unless :attr:`keepdim` is true, default
338+
value is False. Note that if `dim` appears in neither (*args) nor (**kwargs), this parameter cannot be passed alone
339+
other (Tensor, optional): the other tensor to perform `paddle.minimum` with. This Tensor should
340+
have the same or broadcastable shape as the `input`. Note that (`dim` & `keepdim`) and `other` are mutually exclusive
341+
meaning that trying to combine both will result in a TypeError
342+
343+
Returns:
344+
- For case 1: a single value Tensor (0-dim)
345+
- For case 2: a named tuple MinMaxRetType(values: Tensor, indices: Tensor), `values` has the same data type as the `input`,
346+
while indices is always an int64 Tensor, with exactly the same shape as `values`.
347+
MinMaxRetType can be used (indexed, packed, unpacked) in the same way as a regular tuple
348+
- For case 3: see `paddle.minimum`
349+
350+
351+
Examples:
352+
.. code-block:: python
353+
354+
>>> import paddle
355+
356+
>>> # data_x is a Tensor with shape [2, 4]
357+
>>> # the axis is an int element
358+
>>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
359+
... [0.1, 0.2, 0.6, 0.7]],
360+
... dtype='float64', stop_gradient=False)
361+
>>> # Case 1: reduce over all dims
362+
>>> result1 = paddle.compat.min(x)
363+
>>> result1
364+
Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
365+
0.10000000)
366+
367+
>>> # Case 2: reduce over specified dim
368+
>>> x.clear_grad()
369+
>>> result2 = paddle.compat.min(x, dim=1)
370+
>>> result2
371+
MinMaxRetType(values=Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=False,
372+
[0.20000000, 0.10000000]), indices=Tensor(shape=[2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
373+
[0, 0]))
374+
>>> result2[0].backward()
375+
>>> x.grad
376+
Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
377+
[[1., 0., 0., 0.],
378+
[1., 0., 0., 0.]])
379+
380+
>>> # Case 3: equivalent to `paddle.minimum`
381+
>>> x.clear_grad()
382+
>>> y = paddle.to_tensor([[0.5, 0.4, 0.1, 0.2],
383+
... [0.3, 0.1, 0.6, 0.7]],
384+
... dtype='float64', stop_gradient=False)
385+
>>> result3 = paddle.compat.min(x, y)
386+
>>> result3
387+
Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
388+
[[0.20000000, 0.30000000, 0.10000000, 0.20000000],
389+
[0.10000000, 0.10000000, 0.60000000, 0.70000000]])
390+
"""
298391
if not isinstance(input, paddle.pir.Value) and not isinstance(
299392
input, paddle.Tensor
300393
):
301394
raise TypeError(
302395
f"input should be a tensor, but got an instance with type '{type(input).__name__}'"
303396
)
397+
_min_max_tensor_allow_grad(input)
304398

305399
dim_or_other, keepdim = _min_max_param_checker("min", *args, **kwargs)
306400

@@ -329,14 +423,93 @@ def min(input: Tensor, *args: Any, **kwargs: Any) -> Tensor | MinMaxRetType:
329423
return _C_ops.minimum(input, dim_or_other)
330424

331425

332-
@forbid_keywords(['x', 'axis'], 'paddle.max')
426+
@ForbidKeywordsDecorator(
427+
illegal_keys=['x', 'axis'],
428+
func_name="paddle.compat.max",
429+
correct_name='paddle.max',
430+
)
333431
def max(input: Tensor, *args: Any, **kwargs: Any) -> Tensor | MinMaxRetType:
432+
"""
433+
434+
Computes the maximum of tensor elements. There are mainly 3 cases (functionalities):
435+
1. paddle.compat.max(input: Tensor): reduce max over all dims, return a single value Tensor
436+
2. paddle.compat.max(input: Tensor, dim: int (cannot be None), keepdim=False): reduce max over the given dim,
437+
returns a named tuple MinMaxRetType(values: Tensor, indices: Tensor)
438+
3. paddle.compat.max(input: Tensor, other: Tensor): see `paddle.maximum`
439+
440+
Note: If there are multiple maximum elements, this API evenly distributes gradient between these equal values,
441+
following torch.max. The gradient behavior of `values` for case 2 is the same as `paddle.amax`.
442+
443+
Args:
444+
input (Tensor): A tensor, the data type is bfloat16, float16, float32, float64, uint8, int16, int32, int64.
445+
dim (int, optional): The dim along which the maximum is computed.
446+
If this is not specified: see case 1, note that: `None` cannot be passed to this (TypeError will be thrown)
447+
compute the maximum over all elements of `input` and return a Tensor with a single element,
448+
otherwise must be in the range :math:`[-input.ndim, input.ndim)`.
449+
If :math:`dim < 0`, the axis to reduce is :math:`input.ndim + dim`.
450+
keepdim (bool, optional): Whether to reserve the reduced dimension in the
451+
output Tensor. The result tensor will have one fewer dimension
452+
than the `input` unless :attr:`keepdim` is true, default
453+
value is False. Note that if `dim` appears in neither (*args) nor (**kwargs), this parameter cannot be passed alone
454+
other (Tensor, optional): the other tensor to perform `paddle.maximum` with. This Tensor should
455+
have the same or broadcastable shape as the `input`. Note that (`dim` & `keepdim`) and `other` are mutually exclusive
456+
meaning that trying to combine both will result in a TypeError
457+
458+
Returns:
459+
- For case 1: a single value Tensor (0-dim)
460+
- For case 2: a named tuple MinMaxRetType(values: Tensor, indices: Tensor), `values` has the same data type as the `input`,
461+
while indices is always an int64 Tensor, with exactly the same shape as `values`.
462+
MinMaxRetType can be used (indexed, packed, unpacked) in the same way as a regular tuple
463+
- For case 3: see `paddle.maximum`
464+
465+
466+
Examples:
467+
.. code-block:: python
468+
469+
>>> import paddle
470+
471+
>>> # data_x is a Tensor with shape [2, 4]
472+
>>> # the axis is an int element
473+
>>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
474+
... [0.1, 0.2, 0.6, 0.7]],
475+
... dtype='float64', stop_gradient=False)
476+
>>> # Case 1: reduce over all dims
477+
>>> result1 = paddle.compat.max(x)
478+
>>> result1
479+
Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
480+
0.90000000)
481+
482+
>>> # Case 2: reduce over specified dim
483+
>>> x.clear_grad()
484+
>>> result2 = paddle.compat.max(x, dim=1)
485+
>>> result2
486+
MinMaxRetType(values=Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=False,
487+
[0.90000000, 0.70000000]), indices=Tensor(shape=[2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
488+
[3, 3]))
489+
>>> result2[0].backward()
490+
>>> x.grad
491+
Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
492+
[[0., 0., 0., 1.],
493+
[0., 0., 0., 1.]])
494+
495+
>>> # Case 3: equivalent to `paddle.maximum`
496+
>>> x.clear_grad()
497+
>>> y = paddle.to_tensor([[0.5, 0.4, 0.1, 0.2],
498+
... [0.3, 0.1, 0.6, 0.7]],
499+
... dtype='float64', stop_gradient=False)
500+
>>> result3 = paddle.compat.max(x, y)
501+
>>> result3
502+
Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
503+
[[0.50000000, 0.40000000, 0.50000000, 0.90000000],
504+
[0.30000000, 0.20000000, 0.60000000, 0.70000000]])
505+
"""
334506
if not isinstance(input, paddle.pir.Value) and not isinstance(
335507
input, paddle.Tensor
336508
):
337509
raise TypeError(
338510
f"input should be a tensor, but got an instance with type '{type(input).__name__}'"
339511
)
512+
_min_max_tensor_allow_grad(input)
340513

341514
dim_or_other, keepdim = _min_max_param_checker("max", *args, **kwargs)
342515

0 commit comments

Comments
 (0)