Skip to content

Commit 58fc6e0

Browse files
use default device of platform instead of gpu (#1080)
1 parent 02faa22 commit 58fc6e0

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

ppsci/solver/solver.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class Solver:
8383
use_wandb (Optional[bool]): Whether use wandb to log data. Defaults to False.
8484
use_tbd (Optional[bool]): Whether use tensorboardX to log data. Defaults to False.
8585
wandb_config (Optional[Dict[str, str]]): Config dict of WandB. Defaults to None.
86-
device (Literal["cpu", "gpu", "xpu"], optional): Runtime device. Defaults to "gpu".
86+
device (Literal["cpu", "gpu", "xpu", None], optional): Runtime device. Defaults to None, which means use default device on current platform.
8787
equation (Optional[Dict[str, ppsci.equation.PDE]]): Equation dict. Defaults to None.
8888
geom (Optional[Dict[str, ppsci.geometry.Geometry]]): Geometry dict. Defaults to None.
8989
validator (Optional[Dict[str, ppsci.validate.Validator]]): Validator dict. Defaults to None.
@@ -145,7 +145,7 @@ def __init__(
145145
use_wandb: bool = False,
146146
use_tbd: bool = False,
147147
wandb_config: Optional[Mapping] = None,
148-
device: Literal["cpu", "gpu", "xpu"] = "gpu",
148+
device: Literal["cpu", "gpu", "xpu", None] = None,
149149
equation: Optional[Dict[str, ppsci.equation.PDE]] = None,
150150
geom: Optional[Dict[str, ppsci.geometry.Geometry]] = None,
151151
validator: Optional[Dict[str, ppsci.validate.Validator]] = None,
@@ -247,7 +247,12 @@ def __init__(
247247
# set running device
248248
if not cfg:
249249
self.device = device
250+
if self.device is None:
251+
# set to default device if not specified
252+
self.device: str = paddle.device.get_device()
253+
250254
if self.device != "cpu" and paddle.device.get_device() == "cpu":
255+
# fall back to cpu if no other device available
251256
logger.warning(f"Set device({device}) to 'cpu' for only cpu available.")
252257
self.device = "cpu"
253258
self.device = paddle.device.set_device(self.device)

ppsci/utils/callbacks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
9999
if "device" in full_cfg:
100100
import paddle
101101

102-
paddle.device.set_device(full_cfg.device)
102+
if isinstance(full_cfg.device, str):
103+
paddle.device.set_device(full_cfg.device)
103104

104105
# enable prim if specified
105106
if "prim" in full_cfg and bool(full_cfg.prim):

ppsci/utils/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ class SolverConfig(BaseModel):
305305
use_tbd: bool = False
306306
wandb_config: Mapping = {}
307307
use_wandb: bool = False
308-
device: Literal["cpu", "gpu", "xpu"] = "gpu"
308+
device: Literal["cpu", "gpu", "xpu", None] = None
309309
use_amp: bool = False
310310
amp_level: Literal["O0", "O1", "O2", "OD"] = "O1"
311311
to_static: bool = False

0 commit comments

Comments
 (0)