Skip to content

Commit c08f928

Browse files
【PPSCI Export&Infer No.21】tempoGAN (#884)
* add tempoGAN export&infer * fix tempoGAN.md * fix tempoGAN.md * fix tempoGAN.py
1 parent 1af86e1 commit c08f928

File tree

3 files changed

+99
-1
lines changed

3 files changed

+99
-1
lines changed

docs/zh/examples/tempoGAN.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,22 @@
2626
python tempoGAN.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/tempoGAN/tempogan_pretrained.pdparams
2727
```
2828

29+
=== "模型导出命令"
30+
31+
``` sh
32+
python tempoGAN.py mode=export
33+
```
34+
35+
=== "模型推理命令"
36+
37+
``` sh
38+
# linux
39+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat -P datasets/tempoGAN/
40+
# windows
41+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat --output datasets/tempoGAN/2d_valid.mat
42+
python tempoGAN.py mode=infer
43+
```
44+
2945
| 预训练模型 | 指标 |
3046
|:--| :--|
3147
| [tempogan_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/tempoGAN/tempogan_pretrained.pdparams) | MSE: 4.21e-5<br>PSNR: 47.19<br>SSIM: 0.9974 |

examples/tempoGAN/conf/tempogan.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ hydra:
1111
- TRAIN.checkpoint_path
1212
- TRAIN.pretrained_model_path
1313
- EVAL.pretrained_model_path
14+
- INFER.pretrained_model_path
15+
- INFER.export_path
1416
- mode
1517
- output_dir
1618
- log_freq
@@ -92,3 +94,20 @@ TRAIN:
9294
EVAL:
9395
pretrained_model_path: null
9496
save_outs: true
97+
98+
INFER:
99+
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/tempoGAN/tempogan_pretrained.pdparams
100+
export_path: ./inference/tempoGAN
101+
pdmodel_path: ${INFER.export_path}.pdmodel
102+
pdpiparams_path: ${INFER.export_path}.pdiparams
103+
device: gpu
104+
engine: native
105+
precision: fp32
106+
onnx_path: ${INFER.export_path}.onnx
107+
ir_optim: true
108+
min_subgraph_size: 10
109+
gpu_mem: 4000
110+
gpu_id: 0
111+
max_batch_size: 16
112+
num_cpu_threads: 4
113+
batch_size: 1

examples/tempoGAN/tempoGAN.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,14 +406,77 @@ def scale(data):
406406
)
407407

408408

409+
def export(cfg: DictConfig):
410+
from paddle.static import InputSpec
411+
412+
# set models
413+
gen_funcs = func_module.GenFuncs(cfg.WEIGHT_GEN, None)
414+
model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
415+
model_gen.register_input_transform(gen_funcs.transform_in)
416+
417+
# define model_list
418+
model_list = ppsci.arch.ModelList((model_gen,))
419+
420+
# load pretrained model
421+
solver = ppsci.solver.Solver(
422+
model=model_list, pretrained_model_path=cfg.INFER.pretrained_model_path
423+
)
424+
425+
# export models
426+
input_spec = [
427+
{"density_low": InputSpec([None, 1, 128, 128], "float32", name="density_low")},
428+
]
429+
solver.export(input_spec, cfg.INFER.export_path, skip_prune_program=True)
430+
431+
432+
def inference(cfg: DictConfig):
433+
from matplotlib import image as Img
434+
435+
from deploy.python_infer import pinn_predictor
436+
437+
# set model predictor
438+
predictor = pinn_predictor.PINNPredictor(cfg)
439+
440+
# load dataset
441+
dataset_infer = {
442+
"density_low": hdf5storage.loadmat(cfg.DATASET_PATH_VALID)["density_low"]
443+
}
444+
445+
output_dict = predictor.predict(dataset_infer, cfg.INFER.batch_size)
446+
447+
# mapping data to cfg.INFER.output_keys
448+
output = [output_dict[key] for key in output_dict]
449+
450+
def scale(data):
451+
smax = np.max(data)
452+
smin = np.min(data)
453+
return (data - smin) / (smax - smin)
454+
455+
for i, img in enumerate(output[0]):
456+
img = scale(np.squeeze(img))
457+
Img.imsave(
458+
osp.join(cfg.output_dir, f"out_{i}.png"),
459+
img,
460+
vmin=0.0,
461+
vmax=1.0,
462+
cmap="gray",
463+
)
464+
465+
409466
@hydra.main(version_base=None, config_path="./conf", config_name="tempogan.yaml")
410467
def main(cfg: DictConfig):
411468
if cfg.mode == "train":
412469
train(cfg)
413470
elif cfg.mode == "eval":
414471
evaluate(cfg)
472+
elif cfg.mode == "export":
473+
export(cfg)
474+
elif cfg.mode == "infer":
475+
inference(cfg)
415476
else:
416-
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
477+
raise ValueError(
478+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
479+
)
417480

418481

419482
if __name__ == "__main__":

0 commit comments

Comments
 (0)