File tree Expand file tree Collapse file tree 5 files changed +16
-12
lines changed
jointContribution/IJCAI_2024 Expand file tree Collapse file tree 5 files changed +16
-12
lines changed Original file line number Diff line number Diff line change @@ -41,7 +41,7 @@ mv ./ckpts/bju/geom/ckpt ./bju/geom/
41
41
mv ./ckpts/bju/pretrained_checkpoint.pdparams ./bju/pretrained_checkpoint.pdparams
42
42
43
43
# zhongzaicanyu
44
- # No pretrained checkpoint yet.
44
+ mv ./ckpts/zhongzaicanyu/pretrained_checkpoint.pdparams ./zhongzaicanyu/pretrained_checkpoint.pdparams
45
45
```
46
46
47
47
## Inference
@@ -74,5 +74,5 @@ python infer.py
74
74
python infer.py -- train_data_dir " ./Dataset/Trainset_track_B" -- test_data_dir " ./Dataset/Testset_track_B/Inference" -- info_dir " ./Dataset/Testset_track_B/Auxiliary" -- ulip_ckpt " ./geom/ckpt/checkpoint_pointbert.pdparams"
75
75
76
76
# ## zhongzaicanyu
77
- python infer.py # not work yet.
77
+ python infer.py -- data_dir " ./Dataset/data_centroid_track_B_vtk " -- test_data_dir " ./Dataset/track_B_vtk " -- save_dir " ./Dataset/data_centroid_track_B_vtk_preprocessed_data "
78
78
```
Original file line number Diff line number Diff line change @@ -457,6 +457,8 @@ def __getitem__(self, idx):
457
457
def collate_fn (self , batch ):
458
458
batch_data = [data for (data , _ ) in batch ]
459
459
batch_shape = paddle .stack ([shape for (_ , shape ) in batch ], axis = 0 )
460
+ if len (batch_data ) == 1 :
461
+ return batch_data [0 ], batch_shape
460
462
return batch_data , batch_shape
461
463
462
464
Load Diff Large diffs are not rendered by default.
Original file line number Diff line number Diff line change @@ -54,7 +54,10 @@ def parse_args():
54
54
55
55
56
56
if __name__ == "__main__" :
57
- # only run it first time
57
+ print (
58
+ "Attention: Please run and only run `data_process()` at first time in `infer.py`. "
59
+ "And change path in the file before run it."
60
+ )
58
61
data_process ()
59
62
60
63
# load setting
@@ -126,8 +129,6 @@ def parse_args():
126
129
print (f"Processing mesh index: { mesh_index } " )
127
130
else :
128
131
raise ValueError (f"Invalid mesh file format: { mesh_file } " )
129
- cfd_data = cfd_data .to (device )
130
- geom = geom .to (device )
131
132
tic = time .time ()
132
133
out = model ((cfd_data , geom ))
133
134
toc = time .time ()
@@ -138,12 +139,7 @@ def parse_args():
138
139
press_output = press_output * std_out [- 1 ] + mean_out [- 1 ]
139
140
press_output = press_output .detach ().cpu ().numpy ()
140
141
np .save (
141
- "./results/"
142
- + args .cfd_model
143
- + "_B/"
144
- + "press"
145
- + "_"
146
- + f"{ mesh_index } .npy" ,
142
+ "./results/" + "press" + "_" + f"{ mesh_index } .npy" ,
147
143
press_output ,
148
144
)
149
145
times .append (toc - tic )
Original file line number Diff line number Diff line change
1
+ einops
2
+ numpy
3
+ paddlepaddle_gpu
4
+ scikit_learn
5
+ tqdm
6
+ vtk
You can’t perform that action at this time.
0 commit comments