Skip to content

Commit 45f45ea

Browse files
authored
[Upadte]update plotting of hpinn's inference (#903)
* [Upadte]update plotting of hpinn's inference * update1
1 parent 4bea2b6 commit 45f45ea

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

docs/zh/examples/hpinns.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
```
4646

4747

48-
4948
| 预训练模型 | 指标 |
5049
|:--| :--|
5150
| [hpinns_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/hPINNs/hpinns_pretrained.pdparams) | loss(opt_sup): 0.05352<br>MSE.eval_metric(opt_sup): 0.00002<br>loss(val_sup): 0.02205<br>MSE.eval_metric(val_sup): 0.00001 |

examples/hpinns/holography.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,12 +460,19 @@ def inference(cfg: DictConfig):
460460
for store_key, infer_key in zip(cfg.INFER.output_keys, output_dict.keys())
461461
}
462462

463-
ppsci.visualize.save_vtu_from_dict(
464-
"./hpinns_pred.vtu",
465-
{**input_dict, **output_dict},
466-
input_dict.keys(),
467-
cfg.INFER.output_keys,
463+
# plotting E and eps
464+
N = ((func_module.l_BOX[1] - func_module.l_BOX[0]) / 0.05).astype(int)
465+
input_eval = np.stack((input_dict["x"], input_dict["y"]), axis=-1).reshape(
466+
N[0], N[1], 2
468467
)
468+
e_re = output_dict["e_re"].reshape(N[0], N[1])
469+
e_im = output_dict["e_im"].reshape(N[0], N[1])
470+
eps = output_dict["eps"].reshape(N[0], N[1])
471+
v_visual = e_re**2 + e_im**2
472+
field_visual = np.stack((v_visual, eps), axis=-1)
473+
plot_module.field_name = ["Fig7_E", "Fig7_eps"]
474+
plot_module.FIGNAME = "hpinns_pred"
475+
plot_module.plot_field_holo(input_eval, field_visual)
469476

470477

471478
@hydra.main(version_base=None, config_path="./conf", config_name="hpinns.yaml")

examples/hpinns/plotting.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Callable
2121
from typing import Dict
2222
from typing import List
23+
from typing import Optional
2324

2425
import functions as func_module
2526
import matplotlib.pyplot as plt
@@ -114,16 +115,16 @@ def prepare_data(solver: ppsci.solver.Solver, expr_dict: Dict[str, Callable]):
114115
def plot_field_holo(
115116
coord_visual: np.ndarray,
116117
field_visual: np.ndarray,
117-
coord_lambda: np.ndarray,
118-
field_lambda: np.ndarray,
118+
coord_lambda: Optional[np.ndarray] = None,
119+
field_lambda: Optional[np.ndarray] = None,
119120
):
120121
"""Plot fields of of holography example.
121122
122123
Args:
123124
coord_visual (np.ndarray): The coord of epsilon and |E|**2.
124125
field_visual (np.ndarray): The filed of epsilon and |E|**2.
125-
coord_lambda (np.ndarray): The coord of lambda.
126-
field_lambda (np.ndarray): The filed of lambda.
126+
coord_lambda (Optional[np.ndarray], optional): The coord of lambda. Defaults to None.
127+
field_lambda (Optional[np.ndarray], optional): The filed of lambda. Defaults to None.
127128
"""
128129
fmin, fmax = np.array([0, 1.0]), np.array([0.6, 12])
129130
cmin, cmax = coord_visual.min(axis=(0, 1)), coord_visual.max(axis=(0, 1))
@@ -168,7 +169,7 @@ def plot_field_holo(
168169
cb = plt.colorbar()
169170
plt.axis((emin[0], emax[0], emin[1], emax[1]))
170171
plt.clim(vmin=fmin[fi], vmax=fmax[fi])
171-
else:
172+
elif coord_lambda is not None and field_lambda is not None:
172173
# Fig_6C_lambda_
173174
plt.figure(fi * 100 + 101, figsize=(8, 6))
174175
plt.clf()

0 commit comments

Comments
 (0)