19
19
from typing import Optional
20
20
from typing import Union
21
21
22
+ import paddle
23
+
22
24
from ppsci .loss import base
23
25
24
26
@@ -34,7 +36,7 @@ class FunctionalLoss(base.Loss):
34
36
$$
35
37
36
38
Args:
37
- loss_expr (Callable): expression of loss calculation .
39
+ loss_expr (Callable[..., paddle.Tensor] ): Function for custom loss computation .
38
40
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
39
41
40
42
Examples:
@@ -63,11 +65,21 @@ class FunctionalLoss(base.Loss):
63
65
64
66
def __init__ (
65
67
self ,
66
- loss_expr : Callable ,
68
+ loss_expr : Callable [..., paddle . Tensor ] ,
67
69
weight : Optional [Union [float , Dict [str , float ]]] = None ,
68
70
):
69
71
super ().__init__ (None , weight )
70
72
self .loss_expr = loss_expr
71
73
72
- def forward (self , output_dict , label_dict = None , weight_dict = None ):
73
- return self .loss_expr (output_dict , label_dict , weight_dict )
74
+ def forward (self , output_dict , label_dict = None , weight_dict = None ) -> paddle .Tensor :
75
+ loss = self .loss_expr (output_dict , label_dict , weight_dict )
76
+
77
+ assert isinstance (
78
+ loss , (paddle .Tensor , paddle .static .Variable , paddle .pir .Value )
79
+ ), (
80
+ "Loss computed by custom function should be type of 'paddle.Tensor', "
81
+ f"'paddle.static.Variable' or 'paddle.pir.Value', but got { type (loss )} ."
82
+ " Please check the return type of custom loss function."
83
+ )
84
+
85
+ return loss
0 commit comments