8
8
from collections import OrderedDict
9
9
10
10
import numpy as np
11
- import torch
12
- import torch .distributed as dist
11
+ import paddle
12
+ import paddle .distributed as dist
13
13
import yaml
14
14
15
15
# from utils.data_utils import get_data_loader
16
16
from data_utils .pois_helm_datasets import get_data_loader
17
17
from models .fno import build_fno
18
18
from pretrain_basic import l2_err
19
19
from scipy .stats import linregress
20
-
21
- # from torch.utils.data import DataLoader
22
20
from tqdm import tqdm
23
21
24
22
# from utils.loss_utils import LossMSE
25
23
# from utils.YParams import YParams
26
24
27
25
28
- @torch .no_grad ()
26
+ @paddle .no_grad ()
29
27
def get_pred (args ):
30
28
with open (args .config , "r" ) as stream :
31
29
config = yaml .load (stream , yaml .FullLoader )
@@ -38,7 +36,10 @@ def get_pred(args):
38
36
save_path = os .path .join (
39
37
save_dir , "fno-prediction-demo%d.pt" % (args .num_demos if args .num_demos else 0 )
40
38
)
41
- device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
39
+ if paddle .device .cuda .device_count () >= 1 :
40
+ paddle .set_device ("gpu" )
41
+ else :
42
+ paddle .set_device ("cpu" )
42
43
43
44
params = Namespace (** config ["default" ])
44
45
if not hasattr (params , "n_demos" ):
@@ -57,17 +58,18 @@ def get_pred(args):
57
58
params , params .train_path , dist .is_initialized (), train = False
58
59
) # , pack=data_param.pack_data)
59
60
input_demos , target_demos = next (iter (dataloader_icl ))
60
- input_demos = input_demos . to ( device )
61
- target_demos = target_demos . to ( device )
61
+ input_demos = input_demos
62
+ target_demos = target_demos
62
63
63
64
# model_param = Namespace(**config['model'])
64
65
# model_param.n_demos = params.n_demos
65
- model = build_fno (params ). to ( device )
66
+ model = build_fno (params )
66
67
67
68
if args .ckpt_path :
68
- checkpoint = torch .load (args .ckpt_path )
69
+ raise NotImplementedError ("Loading checkpoint is not supported" )
70
+ checkpoint = paddle .load (args .ckpt_path )
69
71
try :
70
- model .load_state_dict (checkpoint ["model_state" ])
72
+ model .set_state_dict (checkpoint ["model_state" ])
71
73
except : # noqa
72
74
new_state_dict = OrderedDict ()
73
75
for key , val in checkpoint ["model_state" ].items ():
@@ -111,7 +113,7 @@ def get_pred(args):
111
113
# for u, a_in in dataloader:
112
114
for inputs , targets in pbar :
113
115
# if len(pred_list) > len(dataloader) // 100: break
114
- inputs , targets = inputs . to ( device ) , targets . to ( device )
116
+ inputs , targets = inputs , targets
115
117
if args .num_demos is None or args .num_demos == 0 :
116
118
u = model (inputs )
117
119
else :
@@ -124,7 +126,8 @@ def get_pred(args):
124
126
data_loss = l2_err (u .detach (), targets .detach ())
125
127
losses .append (data_loss .item ())
126
128
data_loss_normalized = l2_err (
127
- u .detach () / torch .abs (u ).max (), targets .detach () / torch .abs (targets ).max ()
129
+ u .detach () / paddle .abs (u ).max (),
130
+ targets .detach () / paddle .abs (targets ).max (),
128
131
)
129
132
losses_normalized .append (data_loss_normalized .item ())
130
133
# print(data_loss.item())
@@ -133,8 +136,8 @@ def get_pred(args):
133
136
134
137
# print(np.mean(losses))
135
138
slope , intercept , r , p , se = linregress (
136
- torch . cat (pred_list , dim = 0 ).view (- 1 ).numpy (),
137
- torch . cat (truth_list , dim = 0 ).view (- 1 ).numpy (),
139
+ paddle . concat (pred_list , axis = 0 ).view ([ - 1 ] ).numpy (),
140
+ paddle . concat (truth_list , axis = 0 ).view ([ - 1 ] ).numpy (),
138
141
)
139
142
print (
140
143
"RMSE:" ,
@@ -146,9 +149,9 @@ def get_pred(args):
146
149
"Slope:" ,
147
150
slope ,
148
151
)
149
- truth_arr = torch . cat (truth_list , dim = 0 )
150
- pred_arr = torch . cat (pred_list , dim = 0 )
151
- torch .save (
152
+ truth_arr = paddle . concat (truth_list , axis = 0 )
153
+ pred_arr = paddle . concat (pred_list , axis = 0 )
154
+ paddle .save (
152
155
{
153
156
"truth" : truth_arr ,
154
157
"pred" : pred_arr ,
@@ -162,7 +165,6 @@ def get_pred(args):
162
165
163
166
164
167
if __name__ == "__main__" :
165
- torch .backends .cudnn .benchmark = True
166
168
parser = ArgumentParser ()
167
169
parser .add_argument ("--config" , type = str , default = "config/inference_helmholtz.yaml" )
168
170
# parser.add_argument('--ckpt_path', type=str, default='/pscratch/sd/p/puren93/neuralopt/expts/helm-64-o5_15_ft0/all_mask_m6/checkpoints/ckpt.tar')
0 commit comments