@@ -83,7 +83,7 @@ class Solver:
83
83
use_wandb (Optional[bool]): Whether use wandb to log data. Defaults to False.
84
84
use_tbd (Optional[bool]): Whether use tensorboardX to log data. Defaults to False.
85
85
wandb_config (Optional[Dict[str, str]]): Config dict of WandB. Defaults to None.
86
- device (Literal["cpu", "gpu", "xpu"], optional): Runtime device. Defaults to "gpu" .
86
+ device (Literal["cpu", "gpu", "xpu", None ], optional): Runtime device. Defaults to None, which means use default device on current platform .
87
87
equation (Optional[Dict[str, ppsci.equation.PDE]]): Equation dict. Defaults to None.
88
88
geom (Optional[Dict[str, ppsci.geometry.Geometry]]): Geometry dict. Defaults to None.
89
89
validator (Optional[Dict[str, ppsci.validate.Validator]]): Validator dict. Defaults to None.
@@ -145,7 +145,7 @@ def __init__(
145
145
use_wandb : bool = False ,
146
146
use_tbd : bool = False ,
147
147
wandb_config : Optional [Mapping ] = None ,
148
- device : Literal ["cpu" , "gpu" , "xpu" ] = "gpu" ,
148
+ device : Literal ["cpu" , "gpu" , "xpu" , None ] = None ,
149
149
equation : Optional [Dict [str , ppsci .equation .PDE ]] = None ,
150
150
geom : Optional [Dict [str , ppsci .geometry .Geometry ]] = None ,
151
151
validator : Optional [Dict [str , ppsci .validate .Validator ]] = None ,
@@ -247,7 +247,12 @@ def __init__(
247
247
# set running device
248
248
if not cfg :
249
249
self .device = device
250
+ if self .device is None :
251
+ # set to default device if not specified
252
+ self .device : str = paddle .device .get_device ()
253
+
250
254
if self .device != "cpu" and paddle .device .get_device () == "cpu" :
255
+ # fall back to cpu if no other device available
251
256
logger .warning (f"Set device({ device } ) to 'cpu' for only cpu available." )
252
257
self .device = "cpu"
253
258
self .device = paddle .device .set_device (self .device )
0 commit comments