Skip to content

Commit 617d37a

Browse files
committed
feat(ppsci): support data_efficient_nopt for training and test
1 parent 747e543 commit 617d37a

File tree

7 files changed

+321
-295
lines changed

7 files changed

+321
-295
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
default:
2+
# datapath: '/home/aistudio/data_efficient_nopt/data/possion_64/poisson_64_e5_15_test.h5'
3+
train_path: '/home/aistudio/data_efficient_nopt/data/possion_64/poisson_64_e15_50_train.h5' # pick demos
4+
test_path: '/home/aistudio/data_efficient_nopt/data/possion_64/poisson_64_e15_50_test.h5'
5+
# datapath: '/home/aistudio/data_efficient_nopt/data/possion_64/poisson_64_e20_test.h5'
6+
scales_path: '/home/aistudio/data_efficient_nopt/data/possion_64/poisson_64_e5_15_train_scale.npy'
7+
8+
num_data_workers: 1
9+
subsample: 1
10+
# num_demos: 0
11+
shuffle: False
12+
nx: 64
13+
nt: 64
14+
Lx: !!float 1.0
15+
Ly: !!float 1.0
16+
pack_data: !!bool False
17+
18+
model: 'fno'
19+
layers: [64, 64, 64, 64, 64]
20+
modes1: [65, 65, 65, 65]
21+
modes2: [65, 65, 65, 65]
22+
fc_dim: 128
23+
24+
in_dim: 4
25+
out_dim: 1
26+
mode_cut: 32
27+
embed_cut: 64
28+
fc_cut: 2
29+
dropout: 0
30+
31+
fix_backbone: True
32+
33+
loss_func: mse
34+
35+
batch_size: 1
36+
loss_style: sum
37+
38+
log_to_wandb: !!bool False
39+
logdir: ./log

examples/data_efficient_nopt/config/operators_poisson.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ default: &DEFAULT
1717
optimizer: 'adam'
1818
scheduler: 'none'
1919
learning_rate: !!float 1.0
20-
max_epochs: 2
20+
max_epochs: 500
2121
scheduler_epochs: 500
2222
weight_decay: 0
2323
batch_size: 25
2424
# misc
25-
log_to_screen: !!bool False
25+
log_to_screen: !!bool True
2626
save_checkpoint: !!bool False
2727
seed: 0
2828
plot_figs: !!bool False
@@ -45,7 +45,7 @@ default: &DEFAULT
4545
accum_grad: 1
4646
enable_amp: !!bool False
4747
log_interval: 1
48-
checkpoint_save_interval: 10
48+
checkpoint_save_interval: 1000
4949
debug_grad: False
5050

5151
poisson: &poisson

examples/data_efficient_nopt/inference_fno_helmholtz_poisson.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,22 @@
88
from collections import OrderedDict
99

1010
import numpy as np
11-
import torch
12-
import torch.distributed as dist
11+
import paddle
12+
import paddle.distributed as dist
1313
import yaml
1414

1515
# from utils.data_utils import get_data_loader
1616
from data_utils.pois_helm_datasets import get_data_loader
1717
from models.fno import build_fno
1818
from pretrain_basic import l2_err
1919
from scipy.stats import linregress
20-
21-
# from torch.utils.data import DataLoader
2220
from tqdm import tqdm
2321

2422
# from utils.loss_utils import LossMSE
2523
# from utils.YParams import YParams
2624

2725

28-
@torch.no_grad()
26+
@paddle.no_grad()
2927
def get_pred(args):
3028
with open(args.config, "r") as stream:
3129
config = yaml.load(stream, yaml.FullLoader)
@@ -38,7 +36,10 @@ def get_pred(args):
3836
save_path = os.path.join(
3937
save_dir, "fno-prediction-demo%d.pt" % (args.num_demos if args.num_demos else 0)
4038
)
41-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39+
if paddle.device.cuda.device_count() >= 1:
40+
paddle.set_device("gpu")
41+
else:
42+
paddle.set_device("cpu")
4243

4344
params = Namespace(**config["default"])
4445
if not hasattr(params, "n_demos"):
@@ -57,17 +58,18 @@ def get_pred(args):
5758
params, params.train_path, dist.is_initialized(), train=False
5859
) # , pack=data_param.pack_data)
5960
input_demos, target_demos = next(iter(dataloader_icl))
60-
input_demos = input_demos.to(device)
61-
target_demos = target_demos.to(device)
61+
input_demos = input_demos
62+
target_demos = target_demos
6263

6364
# model_param = Namespace(**config['model'])
6465
# model_param.n_demos = params.n_demos
65-
model = build_fno(params).to(device)
66+
model = build_fno(params)
6667

6768
if args.ckpt_path:
68-
checkpoint = torch.load(args.ckpt_path)
69+
raise NotImplementedError("Loading checkpoint is not supported")
70+
checkpoint = paddle.load(args.ckpt_path)
6971
try:
70-
model.load_state_dict(checkpoint["model_state"])
72+
model.set_state_dict(checkpoint["model_state"])
7173
except: # noqa
7274
new_state_dict = OrderedDict()
7375
for key, val in checkpoint["model_state"].items():
@@ -111,7 +113,7 @@ def get_pred(args):
111113
# for u, a_in in dataloader:
112114
for inputs, targets in pbar:
113115
# if len(pred_list) > len(dataloader) // 100: break
114-
inputs, targets = inputs.to(device), targets.to(device)
116+
inputs, targets = inputs, targets
115117
if args.num_demos is None or args.num_demos == 0:
116118
u = model(inputs)
117119
else:
@@ -124,7 +126,8 @@ def get_pred(args):
124126
data_loss = l2_err(u.detach(), targets.detach())
125127
losses.append(data_loss.item())
126128
data_loss_normalized = l2_err(
127-
u.detach() / torch.abs(u).max(), targets.detach() / torch.abs(targets).max()
129+
u.detach() / paddle.abs(u).max(),
130+
targets.detach() / paddle.abs(targets).max(),
128131
)
129132
losses_normalized.append(data_loss_normalized.item())
130133
# print(data_loss.item())
@@ -133,8 +136,8 @@ def get_pred(args):
133136

134137
# print(np.mean(losses))
135138
slope, intercept, r, p, se = linregress(
136-
torch.cat(pred_list, dim=0).view(-1).numpy(),
137-
torch.cat(truth_list, dim=0).view(-1).numpy(),
139+
paddle.concat(pred_list, axis=0).view([-1]).numpy(),
140+
paddle.concat(truth_list, axis=0).view([-1]).numpy(),
138141
)
139142
print(
140143
"RMSE:",
@@ -146,9 +149,9 @@ def get_pred(args):
146149
"Slope:",
147150
slope,
148151
)
149-
truth_arr = torch.cat(truth_list, dim=0)
150-
pred_arr = torch.cat(pred_list, dim=0)
151-
torch.save(
152+
truth_arr = paddle.concat(truth_list, axis=0)
153+
pred_arr = paddle.concat(pred_list, axis=0)
154+
paddle.save(
152155
{
153156
"truth": truth_arr,
154157
"pred": pred_arr,
@@ -162,7 +165,6 @@ def get_pred(args):
162165

163166

164167
if __name__ == "__main__":
165-
torch.backends.cudnn.benchmark = True
166168
parser = ArgumentParser()
167169
parser.add_argument("--config", type=str, default="config/inference_helmholtz.yaml")
168170
# parser.add_argument('--ckpt_path', type=str, default='/pscratch/sd/p/puren93/neuralopt/expts/helm-64-o5_15_ft0/all_mask_m6/checkpoints/ckpt.tar')

examples/data_efficient_nopt/models/basics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,13 @@ def __init__(self, in_channels, out_channels, modes1, modes2):
240240
)
241241

242242
def forward(self, x: paddle.Tensor):
243-
size_0 = x.size(-2)
244-
size_1 = x.size(-1)
243+
size_0 = x.shape[-2]
244+
size_1 = x.shape[-1]
245245
batchsize = x.shape[0]
246246
# dtype = x.dtype
247247

248248
# Compute Fourier coefficients up to factor of e^(- something constant)
249-
x_ft = paddle.fft.rfft2(x.float(), axes=(-2, -1), norm="ortho")
249+
x_ft = paddle.fft.rfft2(x.astype(paddle.float32), axes=(-2, -1), norm="ortho")
250250
x_ft = paddle.as_real(x_ft)
251251

252252
out_ft = paddle.zeros(

0 commit comments

Comments
 (0)