@@ -236,7 +236,7 @@ def __init__(
236
236
if self .device != "cpu" and paddle .device .get_device () == "cpu" :
237
237
logger .warning (f"Set device({ device } ) to 'cpu' for only cpu available." )
238
238
self .device = "cpu"
239
- self .device = paddle .set_device (self .device )
239
+ self .device = paddle .device . set_device (self .device )
240
240
241
241
# set equations for physics-driven or data-physics hybrid driven task, such as PINN
242
242
self .equation = equation
@@ -790,43 +790,43 @@ def predict(
790
790
self .world_size > 1 , self .model
791
791
):
792
792
for batch_id in range (local_batch_num ):
793
- # prepare batch input dict
794
- batch_input_dict = {}
793
+ # prepare local batch input
795
794
if batch_size is not None :
796
795
st = batch_id * batch_size
797
796
ed = min (local_num_samples_pad , (batch_id + 1 ) * batch_size )
798
- for key in local_input_dict :
799
- if not paddle .is_tensor (local_input_dict [key ]):
800
- batch_input_dict [key ] = paddle .to_tensor (
801
- local_input_dict [key ][st :ed ], paddle .get_default_dtype ()
802
- )
803
- else :
804
- batch_input_dict [key ] = local_input_dict [key ][st :ed ]
805
- batch_input_dict [key ].stop_gradient = no_grad
797
+ batch_input_dict = {
798
+ k : v [st :ed ] for k , v in local_input_dict .items ()
799
+ }
806
800
else :
807
801
batch_input_dict = {** local_input_dict }
802
+ # Keep dtype unchanged as all dtype be correct when given into predict function
803
+ for key in batch_input_dict :
804
+ if not paddle .is_tensor (batch_input_dict [key ]):
805
+ batch_input_dict [key ] = paddle .to_tensor (
806
+ batch_input_dict [key ], stop_gradient = no_grad
807
+ )
808
808
809
809
# forward
810
810
with self .autocast_context_manager (self .use_amp , self .amp_level ):
811
811
batch_output_dict = self .forward_helper .visu_forward (
812
812
expr_dict , batch_input_dict , self .model
813
813
)
814
814
815
- # collect batch data
815
+ # collect local batch output
816
816
for key , batch_output in batch_output_dict .items ():
817
817
pred_dict [key ].append (
818
818
batch_output .detach () if no_grad else batch_output
819
819
)
820
820
821
- # concatenate local predictions
821
+ # concatenate local output
822
822
pred_dict = {key : paddle .concat (value ) for key , value in pred_dict .items ()}
823
823
824
824
if self .world_size > 1 :
825
- # gather global predictions from all devices if world_size > 1
825
+ # gather global output from all devices if world_size > 1
826
826
pred_dict = {
827
827
key : misc .all_gather (value ) for key , value in pred_dict .items ()
828
828
}
829
- # rearrange predictions as the same order of input_dict according
829
+ # rearrange output as the same order of input_dict according
830
830
# to inverse permutation
831
831
perm = np .arange (num_samples_pad , dtype = "int64" )
832
832
perm = np .concatenate (
@@ -837,7 +837,7 @@ def predict(
837
837
perm_inv [perm ] = np .arange (num_samples_pad , dtype = "int64" )
838
838
perm_inv = paddle .to_tensor (perm_inv )
839
839
pred_dict = {key : value [perm_inv ] for key , value in pred_dict .items ()}
840
- # then discard predictions of padding data at the end if num_pad > 0
840
+ # then discard output of padding data at the end if num_pad > 0
841
841
if num_pad > 0 :
842
842
pred_dict = {
843
843
key : value [:num_samples ] for key , value in pred_dict .items ()
0 commit comments