Skip to content

Commit ef431ee

Browse files
fix input dtype in shock wave (#954)
1 parent adfadca commit ef431ee

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

examples/shock_wave/shock_wave.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,9 @@ def evaluate(cfg: DictConfig):
445445
)
446446

447447
# visualize prediction
448-
t = np.linspace(cfg.T, cfg.T, 1)
449-
x = np.linspace(0.0, cfg.Lx, cfg.Nd)
450-
y = np.linspace(0.0, cfg.Ly, cfg.Nd)
448+
t = np.linspace(cfg.T, cfg.T, 1, dtype=dtype)
449+
x = np.linspace(0.0, cfg.Lx, cfg.Nd, dtype=dtype)
450+
y = np.linspace(0.0, cfg.Ly, cfg.Nd, dtype=dtype)
451451
_, x_grid, y_grid = np.meshgrid(t, x, y)
452452

453453
x_test = misc.cartesian_product(t, x, y)
@@ -542,9 +542,9 @@ def inference(cfg: DictConfig):
542542
predictor = pinn_predictor.PINNPredictor(cfg)
543543

544544
# visualize prediction
545-
t = np.linspace(cfg.T, cfg.T, 1, dtype=np.float32)
546-
x = np.linspace(0.0, cfg.Lx, cfg.Nd, dtype=np.float32)
547-
y = np.linspace(0.0, cfg.Ly, cfg.Nd, dtype=np.float32)
545+
t = np.linspace(cfg.T, cfg.T, 1, dtype=dtype)
546+
x = np.linspace(0.0, cfg.Lx, cfg.Nd, dtype=dtype)
547+
y = np.linspace(0.0, cfg.Ly, cfg.Nd, dtype=dtype)
548548
_, x_grid, y_grid = np.meshgrid(t, x, y)
549549

550550
x_test = misc.cartesian_product(t, x, y)

0 commit comments

Comments
 (0)