Commit f4c45e2
[API Compatibility] Support paddle.isclose and paddle.nn.functional.softplus (PaddlePaddle#76255)
* sink isclose and softplus
* fix for cast bool to float
* rm unused code
* Revert "rm unused code" (this reverts commit 0523ebb)
* Revert "fix for cast bool to float" (this reverts commit 93e220f)
* rm rtol and atol check
1 parent fb806ad commit f4c45e2

File tree

9 files changed: +383 -198 lines changed

paddle/fluid/pybind/arg_pre_process.cc (72 additions, 2 deletions)

@@ -30,6 +30,20 @@
 namespace paddle {
 namespace pybind {
 constexpr char kStopGradientAttrName[] = "stop_gradient";  // NOLINT
+static void CheckDataType(const std::string& op_name,
+                          const std::string var_name,
+                          const phi::DataType& var_dtype,
+                          const std::vector<phi::DataType>& expect_dtype) {
+  for (auto& t : expect_dtype) {
+    if (var_dtype == t) return;
+  }
+  PADDLE_THROW(common::errors::InvalidType(
+      "The dtype of %s of %s must be one of %s, but received %s.",
+      var_name,
+      op_name,
+      phi::DataTypeToString(expect_dtype),
+      phi::DataTypeToString(var_dtype)));
+}
 void ExpandAsPreProcess(paddle::Tensor* x,
                         paddle::optional<paddle::Tensor>* y,
                         std::vector<int64_t>* target_shape) {
@@ -136,11 +150,67 @@ void LogsumexpPreProcess(pir::Value* x,
   }
   return;
 }
-
-void SumPreProcess(Tensor* x, IntArray* axis) {}
 void SumPreProcess(Value* x, Value* axis) {
   paddle::dialect::SetStopGradient(axis);
 }
+void IsClosePreProcess(Value* x, Value* y, Value* rtol, Value* atol) {
+  /*
+  if in_pir_mode():
+      check_variable_and_dtype(
+          x,
+          "input",
+          ['float16', 'float32', 'float64', 'complex64', 'complex128'],
+          'isclose',
+      )
+      check_variable_and_dtype(
+          y,
+          "input",
+          ['float16', 'float32', 'float64', 'complex64', 'complex128'],
+          'isclose',
+      )
+      if isinstance(rtol, paddle.pir.Value):
+          check_variable_and_dtype(
+              rtol,
+              "input",
+              ['float64'],
+              'isclose',
+          )
+      else:
+          check_type(rtol, 'rtol', float, 'isclose')
+      if isinstance(atol, paddle.pir.Value):
+          check_variable_and_dtype(
+              atol,
+              "input",
+              ['float64'],
+              'isclose',
+          )
+      else:
+          check_type(atol, 'atol', float, 'isclose')
+
+  */
+  // 'float16', 'float32', 'float64', 'complex64', 'complex128'
+  CheckDataType("is_close",
+                "x",
+                pir::GetValueDtype(*x),
+                {phi::DataType::FLOAT16,
+                 phi::DataType::FLOAT32,
+                 phi::DataType::FLOAT64,
+                 phi::DataType::COMPLEX64,
+                 phi::DataType::COMPLEX128});
+  CheckDataType("is_close",
+                "y",
+                pir::GetValueDtype(*y),
+                {phi::DataType::FLOAT16,
+                 phi::DataType::FLOAT32,
+                 phi::DataType::FLOAT64,
+                 phi::DataType::COMPLEX64,
+                 phi::DataType::COMPLEX128});
+  // 'float64'
+  CheckDataType(
+      "is_close", "rtol", pir::GetValueDtype(*rtol), {phi::DataType::FLOAT64});
+  CheckDataType(
+      "is_close", "atol", pir::GetValueDtype(*atol), {phi::DataType::FLOAT64});
+}
 }  // namespace pybind
 
 }  // namespace paddle

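For intuition, here is a minimal Python sketch of the guard that the new CheckDataType helper implements on the C++ side. The function name check_data_type and the use of TypeError are illustrative only, not Paddle API:

def check_data_type(op_name, var_name, var_dtype, expect_dtypes):
    # Accept the value if its dtype is in the allow-list; otherwise raise
    # with the same message shape as the PADDLE_THROW in the diff above.
    if var_dtype in expect_dtypes:
        return
    raise TypeError(
        f"The dtype of {var_name} of {op_name} must be one of "
        f"{', '.join(expect_dtypes)}, but received {var_dtype}."
    )

# The x/y check that IsClosePreProcess performs, expressed the same way:
check_data_type(
    "is_close",
    "x",
    "float32",
    ["float16", "float32", "float64", "complex64", "complex128"],
)
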
paddle/fluid/pybind/arg_pre_process.h (1 addition, 1 deletion)

@@ -42,8 +42,8 @@ void RollPreProcess(Value* x, Value* shifts, IntVector* axis);
 void LogsumexpPreProcess(Tensor* x, std::vector<int>* axis, bool* reduce_all);
 void LogsumexpPreProcess(Value* x, std::vector<int>* axis, bool* reduce_all);
 
-void SumPreProcess(Tensor* x, IntArray* axis);
 void SumPreProcess(Value* x, Value* axis);
+void IsClosePreProcess(Value* x, Value* y, Value* rtol, Value* atol);
 }  // namespace pybind
 
 }  // namespace paddle

paddle/phi/common/data_type.h (11 additions, 1 deletion)

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-
+#include <vector>
 #include "paddle/common/exception.h"
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/complex.h"
@@ -275,6 +275,16 @@ inline std::string DataTypeToString(const DataType& dtype) {
       PD_THROW("Invalid enum data type `", static_cast<int>(dtype), "`.");
   }
 }
+inline std::string DataTypeToString(const std::vector<DataType>& dtypes) {
+  std::string dtype_str;
+  for (size_t i = 0; i < dtypes.size(); ++i) {
+    dtype_str += DataTypeToString(dtypes[i]);
+    if (i != dtypes.size() - 1) {
+      dtype_str += ", ";
+    }
+  }
+  return dtype_str;
+}
 
 inline DataType StringToDataType(const std::string& dtype) {
   if (dtype == "Undefined(ALL_DTYPE)") {

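The new overload just comma-joins the allow-list so it can be spliced into the "%s" of the error message. A one-line Python equivalent, purely illustrative:

def dtypes_to_string(dtypes):
    # Same joining behavior as the new DataTypeToString overload.
    return ", ".join(dtypes)

print(dtypes_to_string(["float16", "float32", "float64"]))
# prints: float16, float32, float64
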
paddle/phi/ops/yaml/python_api_info.yaml (12 additions, 1 deletion)

@@ -193,7 +193,7 @@
   args_alias:
     use_default_mapping : True
   pre_process:
-    func : SumPreProcess(x, axis)
+    static_func : SumPreProcess(x, axis)
   args_mapper :
     func : ArgSumMapper
 
@@ -240,3 +240,14 @@
   args_alias :
     value : [values]
     use_default_mapping : True
+
+- op : softplus
+  name : [paddle.nn.functional.softplus]
+  args_alias :
+    use_default_mapping : True
+- op : isclose
+  name : [paddle.isclose, paddle.Tensor.isclose]
+  args_alias :
+    use_default_mapping : True
+  pre_process:
+    static_func: IsClosePreProcess(x, y, rtol,atol)

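Assuming the "name :" lists above take effect the way they do for the other entries in this file, the same kernel becomes reachable under every listed spelling. A quick usage sketch (the expected outputs come from the isclose docstring added below, not re-verified here):

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([10000.0, 1e-07])
y = paddle.to_tensor([10000.1, 1e-08])

# Both spellings listed under `name : [paddle.isclose, paddle.Tensor.isclose]`.
print(paddle.isclose(x, y))  # Tensor([True , False])
print(x.isclose(y))          # same result via the Tensor method

print(F.softplus(paddle.to_tensor([-0.4, 0.3])))  # the softplus binding
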
python/paddle/_paddle_docs.py (96 additions, 0 deletions)

@@ -793,6 +793,102 @@ def logsumexp(
 ) -> Tensor
 """,
 )
+add_doc_and_signature(
+    "softplus",
+    """
+    softplus activation
+
+    .. math::
+        softplus(x)=\begin{cases}
+                \frac{1}{\beta} * \\log(1 + e^{\beta * x}),&x\\leqslant\frac{\varepsilon}{\beta};\\
+                x,&x>\frac{\varepsilon}{\beta}.
+            \\end{cases}
+
+    Parameters:
+        x (Tensor): The input Tensor with data type float32, float64, complex64, complex128.
+        beta (float, optional): The value of :math:`\beta` for softplus. Default is 1
+        threshold (float, optional): The value of :math:`\varepsilon` for softplus. Default is 20
+        name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        A Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> import paddle.nn.functional as F
+
+            >>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3], dtype='float32')
+            >>> out = F.softplus(x)
+            >>> print(out)
+            Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [0.51301527, 0.59813893, 0.74439669, 0.85435522])
+    """,
+    """
+def softplus(
+    x: Tensor, beta: float = 1, threshold: float = 20, name: str | None = None
+) -> Tensor
+    """,
+)
+add_doc_and_signature(
+    "isclose",
+    """
+    Check if all :math:`x` and :math:`y` satisfy the condition:
+    .. math::
+        \\left| x - y \right| \\leq atol + rtol \times \\left| y \right|
+    elementwise, for all elements of :math:`x` and :math:`y`. The behaviour of this
+    operator is analogous to :math:`numpy.isclose`, namely that it returns :math:`True` if
+    two tensors are elementwise equal within a tolerance.
+    Args:
+        x(Tensor): The input tensor, it's data type should be float16, float32, float64, complex64, complex128.
+        y(Tensor): The input tensor, it's data type should be float16, float32, float64, complex64, complex128.
+        rtol(float, optional): The relative tolerance. Default: :math:`1e-5` .
+        atol(float, optional): The absolute tolerance. Default: :math:`1e-8` .
+        equal_nan(bool, optional): If :math:`True` , then two :math:`NaNs` will be compared as equal. Default: :math:`False` .
+        name (str|None, optional): Name for the operation. For more information, please
+            refer to :ref:`api_guide_Name`. Default: None.
+    Returns:
+        Tensor: The output tensor, it's data type is bool.
+    Examples:
+        .. code-block:: python
+            >>> import paddle
+            >>> x = paddle.to_tensor([10000., 1e-07])
+            >>> y = paddle.to_tensor([10000.1, 1e-08])
+            >>> result1 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
+            ...                          equal_nan=False, name="ignore_nan")
+            >>> print(result1)
+            Tensor(shape=[2], dtype=bool, place=Place(cpu), stop_gradient=True,
+            [True , False])
+            >>> result2 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
+            ...                          equal_nan=True, name="equal_nan")
+            >>> print(result2)
+            Tensor(shape=[2], dtype=bool, place=Place(cpu), stop_gradient=True,
+            [True , False])
+            >>> x = paddle.to_tensor([1.0, float('nan')])
+            >>> y = paddle.to_tensor([1.0, float('nan')])
+            >>> result1 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
+            ...                          equal_nan=False, name="ignore_nan")
+            >>> print(result1)
+            Tensor(shape=[2], dtype=bool, place=Place(cpu), stop_gradient=True,
+            [True , False])
+            >>> result2 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
+            ...                          equal_nan=True, name="equal_nan")
+            >>> print(result2)
+            Tensor(shape=[2], dtype=bool, place=Place(cpu), stop_gradient=True,
+            [True, True])
+    """,
+    """
+def isclose(
+    x: Tensor,
+    y: Tensor,
+    rtol: float = 1e-05,
+    atol: float = 1e-08,
+    equal_nan: bool = False,
+    name: str | None = None,
+) -> Tensor
+    """,
+)
 
 
 # zhengsheng

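The documented condition |x - y| <= atol + rtol * |y| can be reproduced by hand. A small sketch checking the docstring's first example against the formula (equal_nan handling aside):

import paddle

x = paddle.to_tensor([10000.0, 1e-07])
y = paddle.to_tensor([10000.1, 1e-08])
rtol, atol = 1e-05, 1e-08

# Elementwise |x - y| <= atol + rtol * |y|, the formula from the docstring.
manual = (x - y).abs() <= atol + rtol * y.abs()
print(manual)                                      # [True , False]
print(paddle.isclose(x, y, rtol=rtol, atol=atol))  # should agree
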
python/paddle/nn/functional/activation.py (1 addition, 61 deletions)

@@ -38,6 +38,7 @@
 
 from paddle._C_ops import (  # noqa: F401
     gelu,
+    softplus,
 )
 
 
@@ -1266,67 +1267,6 @@ def softmax_(
     return _C_ops.softmax_(outs_cast, axis)
 
 
-def softplus(
-    x: Tensor, beta: float = 1, threshold: float = 20, name: str | None = None
-) -> Tensor:
-    r"""
-    softplus activation
-
-    .. math::
-        softplus(x)=\begin{cases}
-                \frac{1}{\beta} * \log(1 + e^{\beta * x}),&x\leqslant\frac{\varepsilon}{\beta};\\
-                x,&x>\frac{\varepsilon}{\beta}.
-            \end{cases}
-
-    Parameters:
-        x (Tensor): The input Tensor with data type float32, float64, complex64, complex128.
-        beta (float, optional): The value of :math:`\beta` for softplus. Default is 1
-        threshold (float, optional): The value of :math:`\varepsilon` for softplus. Default is 20
-        name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
-
-    Returns:
-        A Tensor with the same data type and shape as ``x`` .
-
-    Examples:
-        .. code-block:: python
-
-            >>> import paddle
-            >>> import paddle.nn.functional as F
-
-            >>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3], dtype='float32')
-            >>> out = F.softplus(x)
-            >>> print(out)
-            Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
-            [0.51301527, 0.59813893, 0.74439669, 0.85435522])
-    """
-
-    if in_dynamic_or_pir_mode():
-        return _C_ops.softplus(x, beta, threshold)
-    else:
-        check_variable_and_dtype(
-            x,
-            'x',
-            [
-                'float16',
-                'uint16',
-                'float32',
-                'float64',
-                'complex64',
-                'complex128',
-            ],
-            'softplus',
-        )
-        helper = LayerHelper('softplus', **locals())
-        out = helper.create_variable_for_type_inference(x.dtype)
-        helper.append_op(
-            type='softplus',
-            inputs={'X': x},
-            outputs={'Out': out},
-            attrs={'beta': beta, 'threshold': threshold},
-        )
-        return out
-
-
 def softshrink(
     x: Tensor, threshold: float = 0.5, name: str | None = None
 ) -> Tensor:

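With this change F.softplus is re-exported from paddle._C_ops rather than defined in Python, but call sites are unchanged. As a sanity sketch, its value can be cross-checked against the closed form (1/beta) * log(1 + e^(beta*x)), which holds while beta*x stays below the threshold:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3], dtype='float32')
beta = 1.0

# Closed form of softplus for beta * x <= threshold (default threshold=20).
expected = (1.0 / beta) * paddle.log1p(paddle.exp(beta * x))

print(F.softplus(x, beta=beta))  # [0.51301527, 0.59813893, 0.74439669, 0.85435522]
print(expected)                  # should match to float32 precision
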