Skip to content

Commit 78bd343

Browse files
authored
[Example]fix dtype error of nsfnet example (#1076)
1 parent 025f937 commit 78bd343

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

examples/nsfnet/VP_NSFNet3.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -445,11 +445,11 @@ def evaluate(cfg: DictConfig):
445445
for i in [0, 0.25, 0.5, 0.75, 1.0]:
446446
x_star, y_star, z_star = np.mgrid[-1.0:1.0:100j, -1.0:1.0:100j, -1.0:1.0:100j]
447447
x_star, y_star, z_star = (
448-
x_star.reshape(-1, 1),
449-
y_star.reshape(-1, 1),
450-
z_star.reshape(-1, 1),
448+
x_star.reshape(-1, 1).astype(np.float32),
449+
y_star.reshape(-1, 1).astype(np.float32),
450+
z_star.reshape(-1, 1).astype(np.float32),
451451
)
452-
t_star = i * np.ones(x_star.shape)
452+
t_star = i * np.ones(x_star.shape, dtype=x_star.dtype)
453453
u_star, v_star, w_star, p_star = analytic_solution_generate(
454454
x_star, y_star, z_star, t_star
455455
)
@@ -474,12 +474,12 @@ def evaluate(cfg: DictConfig):
474474

475475
## plot vorticity
476476
grid_x, grid_y = np.mgrid[-1.0:1.0:1000j, -1.0:1.0:1000j]
477-
grid_x = grid_x.reshape(-1, 1)
478-
grid_y = grid_y.reshape(-1, 1)
479-
grid_z = np.zeros(grid_x.shape)
477+
grid_x = grid_x.reshape(-1, 1).astype(np.float32)
478+
grid_y = grid_y.reshape(-1, 1).astype(np.float32)
479+
grid_z = np.zeros(grid_x.shape).astype(np.float32)
480480
T = np.linspace(0, 1, 101)
481481
for i in T:
482-
t_star = i * np.ones(x_star.shape)
482+
t_star = i * np.ones(x_star.shape, dtype=x_star.dtype)
483483
u_star, v_star, w_star, p_star = analytic_solution_generate(
484484
grid_x, grid_y, grid_z, t_star
485485
)
@@ -533,6 +533,7 @@ def evaluate(cfg: DictConfig):
533533
ax[2, 0].set_title("w_exact")
534534
ax[2, 1].set_title("w_pred")
535535
time = "%.3f" % i
536+
logger.info(f"saving velocity_t={str(time)}.png")
536537
fig.savefig(OUTPUT_DIR + f"/velocity_t={str(time)}.png")
537538

538539

0 commit comments

Comments
 (0)