Skip to content

Commit 8d4e550

Browse files
support enabling prim via ++prim=1 (#843)
1 parent 6dd8b66 commit 8d4e550

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

ppsci/utils/callbacks.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,12 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
9797
else None,
9898
full_cfg.log_level,
9999
)
100+
101+
# enable prim if specified
102+
if "prim" in full_cfg and bool(full_cfg.prim):
103+
# Mostly for dy2st running, will be removed in the future
104+
from paddle.framework import core
105+
106+
core.set_prim_eager_enabled(True)
107+
core._set_prim_all_enabled(True)
108+
logger.message("Prim mode is enabled.")

ppsci/utils/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ class SolverConfig(BaseModel):
219219
use_amp: bool = False
220220
amp_level: Literal["O0", "O1", "O2", "OD"] = "O1"
221221
to_static: bool = False
222+
prim: bool = False
222223
log_level: Literal["debug", "info", "warning", "error"] = "info"
223224

224225
# Training related config

0 commit comments

Comments
 (0)