Commit adfadca
Parent: 7a8fddd

fix solver.predict and device setting (#953)

3 files changed: +25 -17 lines

deploy/python_infer/base.py (3 additions, 1 deletion)

```diff
@@ -95,7 +95,9 @@ def __init__(
         )
 
     def predict(self, input_dict):
-        raise NotImplementedError
+        raise NotImplementedError(
+            f"Method 'predict' should be implemented in {self.__class__.__name__} class."
+        )
 
     def _create_paddle_predictor(
         self,
```
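
The new message names the concrete class that failed to override the method, which makes the traceback self-explanatory. A minimal, hypothetical sketch of this abstract-method pattern (the class names here are illustrative, not from PaddleScience):

```python
# Hypothetical sketch of the pattern above; only `predict` mirrors the real code.
class PredictorBase:
    def predict(self, input_dict):
        raise NotImplementedError(
            f"Method 'predict' should be implemented in {self.__class__.__name__} class."
        )


class DummyPredictor(PredictorBase):
    def predict(self, input_dict):
        # echo inputs back as "predictions", purely for demonstration
        return {k: v for k, v in input_dict.items()}


# PredictorBase().predict({})  # raises with "... in PredictorBase class."
print(DummyPredictor().predict({"x": [1.0, 2.0]}))  # {'x': [1.0, 2.0]}
```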

ppsci/solver/solver.py (16 additions, 16 deletions)

```diff
@@ -236,7 +236,7 @@ def __init__(
         if self.device != "cpu" and paddle.device.get_device() == "cpu":
             logger.warning(f"Set device({device}) to 'cpu' for only cpu available.")
             self.device = "cpu"
-        self.device = paddle.set_device(self.device)
+        self.device = paddle.device.set_device(self.device)
 
         # set equations for physics-driven or data-physics hybrid driven task, such as PINN
         self.equation = equation
```
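
The device fix swaps the top-level `paddle.set_device` call for the namespaced `paddle.device.set_device` API. A standalone sketch of the surrounding CPU-fallback logic, assuming only that PaddlePaddle is installed:

```python
# Standalone sketch of the CPU-fallback logic above; assumes PaddlePaddle is installed.
import paddle

device = "gpu"  # requested device, e.g. taken from a config

# paddle.device.get_device() reports the currently active device ("cpu", "gpu:0", ...);
# on a CPU-only build it can only ever report "cpu"
if device != "cpu" and paddle.device.get_device() == "cpu":
    print(f"Set device({device}) to 'cpu' for only cpu available.")
    device = "cpu"

# set_device activates the device and returns the corresponding Place object
place = paddle.device.set_device(device)
print(place)  # e.g. Place(cpu)
```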

Later in the same file, the batching logic in `Solver.predict` is reworked:

```diff
@@ -790,43 +790,43 @@ def predict(
             self.world_size > 1, self.model
         ):
             for batch_id in range(local_batch_num):
-                # prepare batch input dict
-                batch_input_dict = {}
+                # prepare local batch input
                 if batch_size is not None:
                     st = batch_id * batch_size
                     ed = min(local_num_samples_pad, (batch_id + 1) * batch_size)
-                    for key in local_input_dict:
-                        if not paddle.is_tensor(local_input_dict[key]):
-                            batch_input_dict[key] = paddle.to_tensor(
-                                local_input_dict[key][st:ed], paddle.get_default_dtype()
-                            )
-                        else:
-                            batch_input_dict[key] = local_input_dict[key][st:ed]
-                        batch_input_dict[key].stop_gradient = no_grad
+                    batch_input_dict = {
+                        k: v[st:ed] for k, v in local_input_dict.items()
+                    }
                 else:
                     batch_input_dict = {**local_input_dict}
+                # keep dtype unchanged, as all dtypes should be correct when given to the predict function
+                for key in batch_input_dict:
+                    if not paddle.is_tensor(batch_input_dict[key]):
+                        batch_input_dict[key] = paddle.to_tensor(
+                            batch_input_dict[key], stop_gradient=no_grad
+                        )
 
                 # forward
                 with self.autocast_context_manager(self.use_amp, self.amp_level):
                     batch_output_dict = self.forward_helper.visu_forward(
                         expr_dict, batch_input_dict, self.model
                     )
 
-                # collect batch data
+                # collect local batch output
                 for key, batch_output in batch_output_dict.items():
                     pred_dict[key].append(
                         batch_output.detach() if no_grad else batch_output
                     )
 
-        # concatenate local predictions
+        # concatenate local output
        pred_dict = {key: paddle.concat(value) for key, value in pred_dict.items()}
 
        if self.world_size > 1:
-            # gather global predictions from all devices if world_size > 1
+            # gather global output from all devices if world_size > 1
            pred_dict = {
                key: misc.all_gather(value) for key, value in pred_dict.items()
            }
-            # rearrange predictions as the same order of input_dict according
+            # rearrange output as the same order of input_dict according
            # to inverse permutation
            perm = np.arange(num_samples_pad, dtype="int64")
            perm = np.concatenate(
```
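
The refactor separates slicing from tensor conversion: each batch is first sliced with a single dict comprehension, then only non-tensor entries are converted, with `stop_gradient` set at construction rather than mutated in a second pass, and the incoming dtype left untouched. A self-contained sketch of that pattern (the variable names mirror the diff but are defined locally here):

```python
# Self-contained sketch of the new batching pattern; `no_grad`, `st`, `ed`, and
# `local_input_dict` mirror names from the diff but are defined here for illustration.
import numpy as np
import paddle

local_input_dict = {"x": np.linspace(0, 1, 8, dtype="float32")}
no_grad, st, ed = True, 0, 4

# slice first (works for both numpy arrays and paddle tensors) ...
batch_input_dict = {k: v[st:ed] for k, v in local_input_dict.items()}

# ... then convert non-tensor entries once, setting stop_gradient at construction
# and keeping the incoming dtype unchanged (dtype=None preserves the source dtype)
for key in batch_input_dict:
    if not paddle.is_tensor(batch_input_dict[key]):
        batch_input_dict[key] = paddle.to_tensor(
            batch_input_dict[key], stop_gradient=no_grad
        )

print(batch_input_dict["x"].shape, batch_input_dict["x"].stop_gradient)  # [4] True
```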

The final hunk in this file updates the matching comment:

```diff
@@ -837,7 +837,7 @@ def predict(
            perm_inv[perm] = np.arange(num_samples_pad, dtype="int64")
            perm_inv = paddle.to_tensor(perm_inv)
            pred_dict = {key: value[perm_inv] for key, value in pred_dict.items()}
-            # then discard predictions of padding data at the end if num_pad > 0
+            # then discard output of padding data at the end if num_pad > 0
            if num_pad > 0:
                pred_dict = {
                    key: value[:num_samples] for key, value in pred_dict.items()
```
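
The `perm_inv[perm] = arange(n)` idiom above builds the inverse of a permutation, so gathered outputs can be restored to input order before the padded tail is dropped. A small numpy-only demonstration:

```python
# Minimal numpy demonstration of the inverse-permutation trick used above.
import numpy as np

num_samples, num_pad = 5, 1
num_samples_pad = num_samples + num_pad

# suppose samples were dealt out to ranks in this (known) order
perm = np.array([0, 2, 4, 1, 3, 5], dtype="int64")

# build the inverse permutation: perm_inv[perm[i]] = i
perm_inv = np.empty_like(perm)
perm_inv[perm] = np.arange(num_samples_pad, dtype="int64")

gathered = np.array([10, 30, 50, 20, 40, 99])  # outputs in permuted order (99 = padding)
restored = gathered[perm_inv]                  # back to input order
print(restored[:num_samples])                  # [10 20 30 40 50], padding discarded
```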

ppsci/utils/callbacks.py (6 additions, 0 deletions)

```diff
@@ -95,6 +95,12 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
             full_cfg.log_level,
         )
 
+        # set device before running the example function
+        if "device" in full_cfg:
+            import paddle
+
+            paddle.device.set_device(full_cfg.device)
+
         # enable prim if specified
         if "prim" in full_cfg and bool(full_cfg.prim):
             # Mostly for compiler running with dy2st.
```
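
A hedged sketch of what the new callback block does with a config at job start; it assumes `omegaconf` and `paddle` are installed, and the config keys here are illustrative rather than copied from a real PaddleScience example:

```python
# Hedged sketch of the callback's config handling; keys below are illustrative.
from omegaconf import OmegaConf
import paddle

full_cfg = OmegaConf.create({"device": "cpu", "log_level": "info"})

# same guard as in on_job_start: only touch the device when the key is present
if "device" in full_cfg:
    paddle.device.set_device(full_cfg.device)

print(paddle.device.get_device())  # cpu
```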
