Skip to content

Commit 68fe0cf

Browse files
[Refine] Refine evaluation output (#866)
* fix(test=document_fix) * pretty evaluation output with prettytable * remove epoch id when ony evaluating * hidden epoch information when not evaluating on flying
1 parent 4b39a66 commit 68fe0cf

File tree

7 files changed

+43
-23
lines changed

7 files changed

+43
-23
lines changed

examples/darcy/darcy2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def poisson_ref_compute_func(_in):
109109
cfg.NPOINT_PDE + cfg.NPOINT_BC, evenly=True
110110
)
111111
visualizer = {
112-
"visualize_p": ppsci.visualize.VisualizerVtu(
112+
"visualize_p_ux_uy": ppsci.visualize.VisualizerVtu(
113113
vis_points,
114114
{
115115
"p": lambda d: d["p"],
@@ -246,7 +246,7 @@ def poisson_ref_compute_func(_in):
246246
cfg.NPOINT_PDE + cfg.NPOINT_BC, evenly=True
247247
)
248248
visualizer = {
249-
"visualize_p": ppsci.visualize.VisualizerVtu(
249+
"visualize_p_ux_uy": ppsci.visualize.VisualizerVtu(
250250
vis_points,
251251
{
252252
"p": lambda d: d["p"],

ppsci/solver/eval.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import time
1818
from typing import TYPE_CHECKING
1919
from typing import Dict
20+
from typing import Optional
2021
from typing import Tuple
2122
from typing import Union
2223

@@ -59,7 +60,7 @@ def _get_dataset_length(
5960

6061

6162
def _eval_by_dataset(
62-
solver: "solver.Solver", epoch_id: int, log_freq: int
63+
solver: "solver.Solver", epoch_id: Optional[int], log_freq: int
6364
) -> Tuple[float, Dict[str, Dict[str, float]]]:
6465
"""Evaluate with computing metric on total samples(default process).
6566
@@ -68,7 +69,7 @@ def _eval_by_dataset(
6869
6970
Args:
7071
solver (solver.Solver): Main Solver.
71-
epoch_id (int): Epoch id.
72+
epoch_id (Optional[int]): Epoch id.
7273
log_freq (int): Log evaluation information every `log_freq` steps.
7374
7475
Returns:
@@ -189,7 +190,7 @@ def _eval_by_dataset(
189190

190191

191192
def _eval_by_batch(
192-
solver: "solver.Solver", epoch_id: int, log_freq: int
193+
solver: "solver.Solver", epoch_id: Optional[int], log_freq: int
193194
) -> Tuple[float, Dict[str, Dict[str, float]]]:
194195
"""Evaluate with computing metric by batch, which is memory-efficient.
195196
@@ -199,7 +200,7 @@ def _eval_by_batch(
199200
200201
Args:
201202
solver (solver.Solver): Main Solver.
202-
epoch_id (int): Epoch id.
203+
epoch_id (Optional[int]): Epoch id.
203204
log_freq (int): Log evaluation information every `log_freq` steps.
204205
205206
Returns:
@@ -303,13 +304,13 @@ def _eval_by_batch(
303304

304305

305306
def eval_func(
306-
solver: "solver.Solver", epoch_id: int, log_freq: int
307+
solver: "solver.Solver", epoch_id: Optional[int], log_freq: int
307308
) -> Tuple[float, Dict[str, Dict[str, float]]]:
308309
"""Evaluation function.
309310
310311
Args:
311312
solver (solver.Solver): Main Solver.
312-
epoch_id (int): Epoch id.
313+
epoch_id (Optional[int]): Epoch id.
313314
log_freq (int): Log evaluation information every `log_freq` steps.
314315
315316
Returns:

ppsci/solver/printer.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,17 @@ def log_eval_info(
132132

133133
epoch_width = len(str(solver.epochs))
134134
iters_width = len(str(iters_per_epoch))
135-
logger.info(
136-
f"[Eval][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
137-
f"[Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
138-
f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
139-
)
135+
if isinstance(epoch_id, int):
136+
logger.info(
137+
f"[Eval][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
138+
f"[Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
139+
f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
140+
)
141+
else:
142+
logger.info(
143+
f"[Eval][Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
144+
f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
145+
)
140146

141147
# logger.scalar(
142148
# {

ppsci/solver/solver.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -619,11 +619,13 @@ def finetune(self, pretrained_model_path: str) -> None:
619619
self.train()
620620

621621
@misc.run_on_eval_mode
622-
def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
622+
def eval(
623+
self, epoch_id: Optional[int] = None
624+
) -> Tuple[float, Dict[str, Dict[str, float]]]:
623625
"""Evaluation.
624626
625627
Args:
626-
epoch_id (int, optional): Epoch id. Defaults to 0.
628+
epoch_id (Optional[int]): Epoch id. Defaults to None.
627629
628630
Returns:
629631
Tuple[float, Dict[str, Dict[str, float]]]: A targe metric value(float) and
@@ -636,23 +638,30 @@ def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
636638
metric_msg = ", ".join(
637639
[self.eval_output_info[key].avg_info for key in self.eval_output_info]
638640
)
639-
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")
641+
642+
if isinstance(epoch_id, int):
643+
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")
644+
else:
645+
logger.info(f"[Eval][Avg] {metric_msg}")
640646
self.eval_output_info.clear()
641647

642648
return result
643649

644650
@misc.run_on_eval_mode
645-
def visualize(self, epoch_id: int = 0):
651+
def visualize(self, epoch_id: Optional[int] = None):
646652
"""Visualization.
647653
648654
Args:
649-
epoch_id (int, optional): Epoch id. Defaults to 0.
655+
epoch_id (Optional[int]): Epoch id. Defaults to None.
650656
"""
651657
# set visualize func
652658
self.visu_func = ppsci.solver.visu.visualize_func
653659

654660
self.visu_func(self, epoch_id)
655-
logger.info(f"[Visualize][Epoch {epoch_id}] Finish visualization")
661+
if isinstance(epoch_id, int):
662+
logger.info(f"[Visualize][Epoch {epoch_id}] Finish visualization")
663+
else:
664+
logger.info("[Visualize] Finish visualization")
656665

657666
@misc.run_on_eval_mode
658667
def predict(

ppsci/solver/visu.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import os.path as osp
1919
from typing import TYPE_CHECKING
20+
from typing import Optional
2021

2122
import paddle
2223

@@ -26,12 +27,12 @@
2627
from ppsci import solver
2728

2829

29-
def visualize_func(solver: "solver.Solver", epoch_id: int):
30+
def visualize_func(solver: "solver.Solver", epoch_id: Optional[int]):
3031
"""Visualization program.
3132
3233
Args:
3334
solver (solver.Solver): Main Solver.
34-
epoch_id (int): Epoch id.
35+
epoch_id (Optional[int]): Epoch id.
3536
"""
3637
for _, _visualizer in solver.visualizer.items():
3738
all_input = misc.Prettydefaultdict(list)
@@ -87,7 +88,9 @@ def visualize_func(solver: "solver.Solver", epoch_id: int):
8788
# save visualization
8889
with misc.RankZeroOnly(solver.rank) as is_master:
8990
if is_master:
90-
visual_dir = osp.join(solver.output_dir, "visual", f"epoch_{epoch_id}")
91+
visual_dir = osp.join(solver.output_dir, "visual")
92+
if epoch_id:
93+
visual_dir = osp.join(visual_dir, f"epoch_{epoch_id}")
9194
os.makedirs(visual_dir, exist_ok=True)
9295
_visualizer.save(
9396
osp.join(visual_dir, _visualizer.prefix),

ppsci/utils/download.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def _download(url, path, md5sum=None):
157157
if chunk:
158158
f.write(chunk)
159159
shutil.move(tmp_fullname, fullname)
160+
logger.message(f"Finished downloading pretrained model and saved to {fullname}")
160161

161162
return fullname
162163

ppsci/utils/ema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(self, model: nn.Layer, decay: Optional[float] = None):
4141
self.model = model # As a quick reference to online model
4242
self.decay = decay
4343

44-
self.params_shadow: Dict[str, paddle.Tensor] = {} # ema param or bufer
44+
self.params_shadow: Dict[str, paddle.Tensor] = {} # ema param or buffer
4545
self.params_backup: Dict[str, paddle.Tensor] = {} # used for apply and restore
4646
for name, param_or_buffer in itertools.chain(
4747
self.model.named_parameters(), self.model.named_buffers()

0 commit comments

Comments
 (0)