Skip to content

Commit 9d84962

Browse files
[Enh] Add numpy solver in config solvers (#869)
* support registering numpy config solvers when parsing attribute of numpy from config * remove permutation code * avoid duplicated register
1 parent dd971ee commit 9d84962

File tree

3 files changed

+13
-28
lines changed

3 files changed

+13
-28
lines changed

examples/NLS-MB/NLS-MB_optical_rogue_wave.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -417,20 +417,6 @@ def inference(cfg: DictConfig):
417417
store_key: output_dict[infer_key]
418418
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
419419
}
420-
# TODO: Fix this mapping diff in dy2st
421-
(
422-
output_dict["Eu"],
423-
output_dict["Ev"],
424-
output_dict["eta"],
425-
output_dict["pu"],
426-
output_dict["pv"],
427-
) = (
428-
output_dict["Eu"],
429-
output_dict["Ev"],
430-
output_dict["pu"],
431-
output_dict["pv"],
432-
output_dict["eta"],
433-
)
434420

435421
# visualize prediction
436422
Eu_true, Ev_true, pu_true, pv_true, eta_true = analytic_solution(input_dict)

examples/NLS-MB/NLS-MB_optical_soliton.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -394,20 +394,6 @@ def inference(cfg: DictConfig):
394394
store_key: output_dict[infer_key]
395395
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
396396
}
397-
# TODO: Fix this mapping diff in dy2st
398-
(
399-
output_dict["Eu"],
400-
output_dict["Ev"],
401-
output_dict["eta"],
402-
output_dict["pu"],
403-
output_dict["pv"],
404-
) = (
405-
output_dict["Eu"],
406-
output_dict["Ev"],
407-
output_dict["pu"],
408-
output_dict["pv"],
409-
output_dict["eta"],
410-
)
411397

412398
# visualize prediction
413399
Eu_true, Ev_true, pu_true, pv_true, eta_true = analytic_solution(input_dict)

ppsci/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,16 @@
5959
"run_check_mesh",
6060
"lambdify",
6161
]
62+
63+
64+
# NOTE: Register custom solvers for parsing values from omegaconf more flexible
65+
def _register_config_solvers():
66+
import numpy as np
67+
from omegaconf import OmegaConf
68+
69+
# register solver for "${numpy: xxx}" item, e.g. pi: "${numpy: pi}"
70+
if not OmegaConf.has_resolver("numpy"):
71+
OmegaConf.register_new_resolver("numpy", lambda x: getattr(np, x))
72+
73+
74+
_register_config_solvers()

0 commit comments

Comments
 (0)