
Commit de9f833

[API] support set None to Validator.loss (#1214)
* support set None to validator
1 parent 42a308b commit de9f833

File tree

ppsci/solver/eval.py
ppsci/solver/printer.py
ppsci/solver/solver.py
ppsci/utils/expression.py
ppsci/validate/base.py
ppsci/validate/geo_validator.py
ppsci/validate/sup_validator.py

7 files changed: +41 -34 lines changed

ppsci/solver/eval.py

Lines changed: 8 additions & 6 deletions

@@ -108,9 +108,10 @@ def _eval_by_dataset(
             weight_dict,
         )

-        loss_dict[f"{_validator.name}/loss"] = float(
-            sum(list(validator_loss.values()))
-        )
+        if len(validator_loss) > 0:
+            loss_dict[f"{_validator.name}/loss"] = float(
+                sum(list(validator_loss.values()))
+            )

         for key, output in output_dict.items():
             all_output[key].append(
@@ -235,9 +236,10 @@ def _eval_by_batch(
             weight_dict,
         )

-        loss_dict[f"{_validator.name}/loss"] = float(
-            sum(list(validator_loss.values()))
-        )
+        if len(validator_loss) > 0:
+            loss_dict[f"{_validator.name}/loss"] = float(
+                sum(list(validator_loss.values()))
+            )

         # collect batch metric
         for metric_name, metric_func in _validator.metric.items():
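
The guarded aggregation above means an empty validator_loss dict (which now happens when a validator's loss is None) simply produces no "<name>/loss" entry instead of a spurious 0.0. A minimal sketch of the same pattern outside ppsci; collect_validator_loss is a hypothetical helper, not part of the codebase:

    from typing import Dict

    def collect_validator_loss(
        loss_dict: Dict[str, float],
        name: str,
        validator_loss: Dict[str, float],
    ) -> None:
        # Mirror of the change above: only record an aggregated entry when the
        # validator actually produced losses; it yields an empty dict when its
        # loss functor is None.
        if len(validator_loss) > 0:
            loss_dict[f"{name}/loss"] = float(sum(validator_loss.values()))

    losses: Dict[str, float] = {}
    collect_validator_loss(losses, "u_mse", {})             # no loss configured -> no entry
    collect_validator_loss(losses, "v_mse", {"loss": 0.12}) # -> {"v_mse/loss": 0.12}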

ppsci/solver/printer.py

Lines changed: 5 additions & 5 deletions

@@ -76,8 +76,8 @@ def log_train_info(
     iters_width = len(str(solver.iters_per_epoch))
     log_str = (
         f"[Train][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
-        f"[Iter {iter_id:>{iters_width}}/{solver.iters_per_epoch}] {lr_msg}, "
-        f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
+        + f"[Iter {iter_id:>{iters_width}}/{solver.iters_per_epoch}] {lr_msg}, "
+        + ", ".join(filter(None, [metric_msg, time_msg, ips_msg, eta_msg]))
     )
     if solver.benchmark_flag:
         max_mem_reserved_msg = (
@@ -136,13 +136,13 @@ def log_eval_info(
     if isinstance(epoch_id, int):
         logger.info(
             f"[Eval][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
-            f"[Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
-            f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
+            + f"[Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
+            + ", ".join(filter(None, [metric_msg, time_msg, ips_msg, eta_msg]))
         )
     else:
         logger.info(
             f"[Eval][Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
-            f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
+            + ", ".join(filter(None, [metric_msg, time_msg, ips_msg, eta_msg]))
         )

     # reset time information after printing
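
With the loss message now possibly empty, the log line is assembled with ", ".join(filter(None, ...)), which drops empty fragments so no dangling commas appear. A small standalone illustration (the message values below are made up):

    # Any fragment may be an empty string, e.g. metric_msg when no loss is configured.
    metric_msg = ""
    time_msg = "batch_cost: 0.05s"
    ips_msg = "ips: 1200.0 samples/s"
    eta_msg = "eta: 0:01:30"

    # filter(None, ...) removes empty strings before joining, avoiding ", ," artifacts.
    tail = ", ".join(filter(None, [metric_msg, time_msg, ips_msg, eta_msg]))
    print(tail)  # batch_cost: 0.05s, ips: 1200.0 samples/s, eta: 0:01:30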

ppsci/solver/solver.py

Lines changed: 10 additions & 4 deletions

@@ -491,10 +491,16 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:
         # whether enable static for forward pass. Defaults to False
         if not cfg:
             self.to_static = to_static
-        jit.enable_to_static(self.to_static)
-        logger.message(
-            f"Set to_static={self.to_static} for computational optimization."
-        )
+
+        if self.to_static:
+            jit.enable_to_static(self.to_static)
+            logger.message("Enable jit.to_static for computational optimization.")
+            self.forward_helper.train_forward = paddle.jit.to_static(
+                self.forward_helper.train_forward
+            )
+            self.forward_helper.eval_forward = paddle.jit.to_static(
+                self.forward_helper.eval_forward
+            )

         # convert sympy to callable object if exist
         extra_parameters = []
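
Instead of decorating the expression-solver methods with @jit.to_static unconditionally (see the ppsci/utils/expression.py diff below), dynamic-to-static conversion is now applied in the Solver only when to_static is enabled. A hedged sketch of that pattern, assuming a helper object that exposes plain train_forward/eval_forward methods:

    import paddle
    from paddle import jit

    def maybe_enable_to_static(forward_helper, to_static: bool):
        # Only convert to static graph when explicitly requested; otherwise the
        # methods stay as ordinary dynamic-graph Python callables.
        if to_static:
            jit.enable_to_static(True)
            forward_helper.train_forward = paddle.jit.to_static(forward_helper.train_forward)
            forward_helper.eval_forward = paddle.jit.to_static(forward_helper.eval_forward)
        return forward_helper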

ppsci/utils/expression.py

Lines changed: 7 additions & 8 deletions

@@ -20,7 +20,6 @@
 from typing import Optional
 from typing import Tuple

-from paddle import jit
 from paddle import nn
 from paddle.framework import core

@@ -57,7 +56,6 @@ def forward(self, *args, **kwargs):
             "Use train_forward/eval_forward/visu_forward instead of forward."
         )

-    @jit.to_static
     def train_forward(
         self,
         expr_dicts: Tuple[Dict[str, Callable], ...],
@@ -145,7 +143,6 @@ def train_forward(

         return losses_all, losses_constraint

-    @jit.to_static
     def eval_forward(
         self,
         expr_dict: Dict[str, Callable],
@@ -187,11 +184,13 @@ def eval_forward(
             clear()

         # compute loss for each validator according to its' own output, label and weight
-        validator_losses = validator.loss(
-            output_dict,
-            label_dict,
-            weight_dict,
-        )
+        validator_losses: Dict[str, "paddle.Tensor"] = {}
+        if callable(validator.loss):
+            validator_losses = validator.loss(
+                output_dict,
+                label_dict,
+                weight_dict,
+            )
         return output_dict, validator_losses

     def visu_forward(
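
The callable(validator.loss) check is what actually enables loss=None: when no loss functor is configured, eval_forward returns an empty loss dict, which the evaluation and logging code above then skips. A self-contained sketch of the guard; ToyValidator and eval_losses are illustrative stand-ins, not ppsci classes:

    from typing import Callable, Dict, Optional

    class ToyValidator:
        # Stand-in for ppsci's Validator: loss may now be None.
        def __init__(self, loss: Optional[Callable] = None):
            self.loss = loss

    def eval_losses(validator: ToyValidator, output, label, weight) -> Dict[str, float]:
        # Same guard as eval_forward above: fall back to an empty dict when no
        # loss functor was configured.
        validator_losses: Dict[str, float] = {}
        if callable(validator.loss):
            validator_losses = validator.loss(output, label, weight)
        return validator_losses

    print(eval_losses(ToyValidator(loss=None), {}, {}, {}))  # -> {}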

ppsci/validate/base.py

Lines changed: 6 additions & 6 deletions

@@ -34,18 +34,18 @@ class Validator:
     Args:
         dataset (io.Dataset): Dataset for validator.
         dataloader_cfg (Dict[str, Any]): Dataloader config.
-        loss (loss.Loss): Loss functor.
-        metric (Optional[Dict[str, metric.Metric]]): Named metric functors in dict.
-        name (str): Name of validator.
+        loss (Optional[loss.Loss]): Loss functor. Defaults to None.
+        metric (Optional[Dict[str, metric.Metric]]): Named metric functors in dict. Defaults to None.
+        name (str): Name of validator. Defaults to "validator".
     """

     def __init__(
         self,
         dataset: io.Dataset,
         dataloader_cfg: Dict[str, Any],
-        loss: "loss.Loss",
-        metric: Optional[Dict[str, "metric.Metric"]],
-        name: str,
+        loss: Optional["loss.Loss"] = None,
+        metric: Optional[Dict[str, "metric.Metric"]] = None,
+        name: str = "validator",
     ):
         self.data_loader = data.build_dataloader(dataset, dataloader_cfg)
         self.data_iter = iter(self.data_loader)
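
The base class now gives loss, metric and name defaults, so subclasses and callers may omit all three. A toy restatement of the new defaults, independent of ppsci, just to show which arguments become optional:

    from typing import Any, Callable, Dict, Optional

    class MiniValidator:
        # Toy mirror of the new Validator signature; not the real ppsci class.
        def __init__(
            self,
            dataset: Any,
            dataloader_cfg: Dict[str, Any],
            loss: Optional[Callable] = None,
            metric: Optional[Dict[str, Callable]] = None,
            name: str = "validator",
        ):
            self.loss = loss
            self.metric = metric
            self.name = name

    v = MiniValidator(dataset=[], dataloader_cfg={})
    print(v.name, v.loss)  # validator None -> a loss-free validator is now valid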

ppsci/validate/geo_validator.py

Lines changed: 2 additions & 2 deletions

@@ -43,7 +43,7 @@ class GeometryValidator(base.Validator):
             label, which will be a reference value to participate in the loss calculation.
         geom (geometry.Geometry): Geometry where data sampled from.
         dataloader_cfg (Dict[str, Any]): Dataloader config.
-        loss (loss.Loss): Loss functor.
+        loss (Optional[loss.Loss]): Loss functor. Defaults to None.
         random (Literal["pseudo", "Halton", "LHS"], optional): Random method for sampling data in
             geometry. Defaults to "pseudo".
         criteria (Optional[Callable]): Criteria for refining specified domain. Defaults to None.
@@ -75,7 +75,7 @@ def __init__(
         label_dict: Dict[str, Union[float, Callable]],
         geom: geometry.Geometry,
         dataloader_cfg: Dict[str, Any],
-        loss: loss.Loss,
+        loss: Optional["loss.Loss"] = None,
         random: Literal["pseudo", "Halton", "LHS"] = "pseudo",
         criteria: Optional[Callable] = None,
         evenly: bool = False,

ppsci/validate/sup_validator.py

Lines changed: 3 additions & 3 deletions

@@ -30,8 +30,8 @@ class SupervisedValidator(base.Validator):

     Args:
         dataloader_cfg (Dict[str, Any]): Config of building a dataloader.
-        loss (loss.Loss): Loss functor.
-        output_expr (Optional[Dict[str, Callable]]): List of label expression.
+        loss (Optional[loss.Loss]): Loss functor. Defaults to None.
+        output_expr (Optional[Dict[str, Callable]]): List of label expression. Defaults to None.
         metric (Optional[Dict[str, metric.Metric]]): Named metric functors in dict. Defaults to None.
         name (Optional[str]): Name of validator. Defaults to None.

@@ -63,7 +63,7 @@ class SupervisedValidator(base.Validator):
     def __init__(
         self,
         dataloader_cfg: Dict[str, Any],
-        loss: loss.Loss,
+        loss: Optional["loss.Loss"] = None,
         output_expr: Optional[Dict[str, Callable]] = None,
         metric: Optional[Dict[str, metric.Metric]] = None,
         name: Optional[str] = None,
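
In practice this allows a metric-only validator, e.g. one that reports an error metric without contributing a loss term. A usage sketch against the new signature; the dataloader config (dataset name, file path, keys) and the chosen metric are illustrative and assume a suitable data file exists:

    import ppsci

    eval_dataloader_cfg = {
        "dataset": {
            "name": "CSVDataset",          # illustrative dataset config
            "file_path": "./data/eval.csv",
            "input_keys": ("x",),
            "label_keys": ("u",),
        },
        "batch_size": 32,
        "sampler": {"name": "BatchSampler", "shuffle": False},
    }

    # With this commit, loss may be omitted (defaults to None):
    validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        loss=None,
        metric={"L2Rel": ppsci.metric.L2Rel()},
        name="metric_only_validator",
    )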
