Skip to content

Commit 8b9601d

Browse files
check return type of FunctionalLoss (#854)
1 parent d2ba413 commit 8b9601d

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

ppsci/loss/func.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from typing import Optional
2020
from typing import Union
2121

22+
import paddle
23+
2224
from ppsci.loss import base
2325

2426

@@ -34,7 +36,7 @@ class FunctionalLoss(base.Loss):
3436
$$
3537
3638
Args:
37-
loss_expr (Callable): expression of loss calculation.
39+
loss_expr (Callable[..., paddle.Tensor]): Function for custom loss computation.
3840
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
3941
4042
Examples:
@@ -63,11 +65,21 @@ class FunctionalLoss(base.Loss):
6365

6466
def __init__(
6567
self,
66-
loss_expr: Callable,
68+
loss_expr: Callable[..., paddle.Tensor],
6769
weight: Optional[Union[float, Dict[str, float]]] = None,
6870
):
6971
super().__init__(None, weight)
7072
self.loss_expr = loss_expr
7173

72-
def forward(self, output_dict, label_dict=None, weight_dict=None):
73-
return self.loss_expr(output_dict, label_dict, weight_dict)
74+
def forward(self, output_dict, label_dict=None, weight_dict=None) -> paddle.Tensor:
75+
loss = self.loss_expr(output_dict, label_dict, weight_dict)
76+
77+
assert isinstance(
78+
loss, (paddle.Tensor, paddle.static.Variable, paddle.pir.Value)
79+
), (
80+
"Loss computed by custom function should be type of 'paddle.Tensor', "
81+
f"'paddle.static.Variable' or 'paddle.pir.Value', but got {type(loss)}."
82+
" Please check the return type of custom loss function."
83+
)
84+
85+
return loss

test/loss/func.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import paddle
2+
import pytest
3+
4+
from ppsci import loss
5+
6+
__all__ = []
7+
8+
9+
def test_non_tensor_return_type():
10+
"""Test that FunctionalLoss rejects non-tensor return values from the custom loss function."""
11+
12+
def loss_func_return_tensor(input_dict, label_dict, weight_dict):
13+
return (0.5 * (input_dict["x"] - label_dict["x"]) ** 2).sum()
14+
15+
def loss_func_return_builtin_float(input_dict, label_dict, weight_dict):
16+
return (0.5 * (input_dict["x"] - label_dict["x"]) ** 2).sum().item()
17+
18+
wrapped_loss1 = loss.FunctionalLoss(loss_func_return_tensor)
19+
wrapped_loss2 = loss.FunctionalLoss(loss_func_return_builtin_float)
20+
21+
input_dict = {"x": paddle.randn([10, 1])}
22+
label_dict = {"x": paddle.zeros([10, 1])}
23+
24+
wrapped_loss1(input_dict, label_dict)
25+
with pytest.raises(AssertionError):
26+
wrapped_loss2(input_dict, label_dict)
27+
28+
29+
if __name__ == "__main__":
30+
pytest.main()

0 commit comments

Comments
 (0)