@@ -445,9 +445,9 @@ def evaluate(cfg: DictConfig):
445
445
)
446
446
447
447
# visualize prediction
448
- t = np .linspace (cfg .T , cfg .T , 1 )
449
- x = np .linspace (0.0 , cfg .Lx , cfg .Nd )
450
- y = np .linspace (0.0 , cfg .Ly , cfg .Nd )
448
+ t = np .linspace (cfg .T , cfg .T , 1 , dtype = dtype )
449
+ x = np .linspace (0.0 , cfg .Lx , cfg .Nd , dtype = dtype )
450
+ y = np .linspace (0.0 , cfg .Ly , cfg .Nd , dtype = dtype )
451
451
_ , x_grid , y_grid = np .meshgrid (t , x , y )
452
452
453
453
x_test = misc .cartesian_product (t , x , y )
@@ -542,9 +542,9 @@ def inference(cfg: DictConfig):
542
542
predictor = pinn_predictor .PINNPredictor (cfg )
543
543
544
544
# visualize prediction
545
- t = np .linspace (cfg .T , cfg .T , 1 , dtype = np . float32 )
546
- x = np .linspace (0.0 , cfg .Lx , cfg .Nd , dtype = np . float32 )
547
- y = np .linspace (0.0 , cfg .Ly , cfg .Nd , dtype = np . float32 )
545
+ t = np .linspace (cfg .T , cfg .T , 1 , dtype = dtype )
546
+ x = np .linspace (0.0 , cfg .Lx , cfg .Nd , dtype = dtype )
547
+ y = np .linspace (0.0 , cfg .Ly , cfg .Nd , dtype = dtype )
548
548
_ , x_grid , y_grid = np .meshgrid (t , x , y )
549
549
550
550
x_test = misc .cartesian_product (t , x , y )
0 commit comments