Skip to content

Commit d81b49f

Browse files
authored
[Update]update for zhongzaicanyu of IJCAI_2024 (#971)
* [Update]update for zhongzaicanyu of IJCAI_2024 * update README.txt
1 parent e580dca commit d81b49f

File tree

5 files changed

+16
-12
lines changed

5 files changed

+16
-12
lines changed

jointContribution/IJCAI_2024/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ mv ./ckpts/bju/geom/ckpt ./bju/geom/
4141
mv ./ckpts/bju/pretrained_checkpoint.pdparams ./bju/pretrained_checkpoint.pdparams
4242

4343
# zhongzaicanyu
44-
# No pretrained checkpoint yet.
44+
mv ./ckpts/zhongzaicanyu/pretrained_checkpoint.pdparams ./zhongzaicanyu/pretrained_checkpoint.pdparams
4545
```
4646

4747
## Inference
@@ -74,5 +74,5 @@ python infer.py
7474
python infer.py --train_data_dir "./Dataset/Trainset_track_B" --test_data_dir "./Dataset/Testset_track_B/Inference" --info_dir "./Dataset/Testset_track_B/Auxiliary" --ulip_ckpt "./geom/ckpt/checkpoint_pointbert.pdparams"
7575

7676
### zhongzaicanyu
77-
python infer.py # not work yet.
77+
python infer.py --data_dir "./Dataset/data_centroid_track_B_vtk" --test_data_dir "./Dataset/track_B_vtk" --save_dir "./Dataset/data_centroid_track_B_vtk_preprocessed_data"
7878
```

jointContribution/IJCAI_2024/zhongzaicanyu/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,8 @@ def __getitem__(self, idx):
457457
def collate_fn(self, batch):
458458
batch_data = [data for (data, _) in batch]
459459
batch_shape = paddle.stack([shape for (_, shape) in batch], axis=0)
460+
if len(batch_data) == 1:
461+
return batch_data[0], batch_shape
460462
return batch_data, batch_shape
461463

462464

jointContribution/IJCAI_2024/zhongzaicanyu/download_dataset.ipynb

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

jointContribution/IJCAI_2024/zhongzaicanyu/infer.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ def parse_args():
5454

5555

5656
if __name__ == "__main__":
57-
# only run it first time
57+
print(
58+
"Attention: Please run and only run `data_process()` at first time in `infer.py`. "
59+
"And change path in the file before run it."
60+
)
5861
data_process()
5962

6063
# load setting
@@ -126,8 +129,6 @@ def parse_args():
126129
print(f"Processing mesh index: {mesh_index}")
127130
else:
128131
raise ValueError(f"Invalid mesh file format: {mesh_file}")
129-
cfd_data = cfd_data.to(device)
130-
geom = geom.to(device)
131132
tic = time.time()
132133
out = model((cfd_data, geom))
133134
toc = time.time()
@@ -138,12 +139,7 @@ def parse_args():
138139
press_output = press_output * std_out[-1] + mean_out[-1]
139140
press_output = press_output.detach().cpu().numpy()
140141
np.save(
141-
"./results/"
142-
+ args.cfd_model
143-
+ "_B/"
144-
+ "press"
145-
+ "_"
146-
+ f"{mesh_index}.npy",
142+
"./results/" + "press" + "_" + f"{mesh_index}.npy",
147143
press_output,
148144
)
149145
times.append(toc - tic)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
einops
2+
numpy
3+
paddlepaddle_gpu
4+
scikit_learn
5+
tqdm
6+
vtk

0 commit comments

Comments
 (0)