@@ -193,7 +193,7 @@ def setupNN(cls, nnType, **params):
193193 if nnType == "FNOP-1" :
194194 from fnop .inference .inference import FNOInference as ModelClass
195195 elif nnType == "FNOP-2" :
196- from fnop .fno import FourierNeuralOp as ModelClass
196+ from fnop .training . fno_pysdc import FourierNeuralOp as ModelClass
197197 cls .model = ModelClass (** params )
198198 cls .initSweep = "NN"
199199
@@ -430,13 +430,15 @@ def _presetStateCoeffSpace(self, state):
430430 def _toNumpy (self , state ):
431431 """Extract from state fields a 3D numpy array containing ux, uz, b and p,
432432 to be given to a NN model."""
433+ for field in state :
434+ field .change_scales (1 )
435+ field .require_grid_space ()
433436 return np .asarray (
434437 # ux , uz , b , p
435438 [state [2 ].data [0 ], state [2 ].data [1 ], state [1 ].data , state [0 ].data ])
436439
437440 def _setStateWith (self , u , state ):
438- """Write a 3D numpy array containing ux, uz, b and p into a dedalus state.
439- Warning : state has to be in grid space"""
441+ """Write a 3D numpy array containing ux, uz, b and p into a dedalus state."""
440442 np .copyto (state [2 ].data [0 ], u [0 ]) # ux
441443 np .copyto (state [2 ].data [1 ], u [1 ]) # uz
442444 np .copyto (state [1 ].data , u [2 ]) # b
@@ -510,6 +512,11 @@ def _initSweep(self):
510512 np .copyto (LXk [m ].data , LXk [0 ].data )
511513 np .copyto (Fk [m ].data , Fk [0 ].data )
512514
515+ elif self .initSweep == "NN" :
516+ # nothing to do, initialization of tendencies already done
517+ # during last sweep ...
518+ pass
519+
513520 else :
514521 raise NotImplementedError (f'initSweep={ self .initSweep } ' )
515522
@@ -565,19 +572,23 @@ def _sweep(self, k):
565572 if self .initSweep == "NN" and k == self .nSweeps - 1 :
566573 # => evaluate current state with NN to be used
567574 # for the tendencies at k=0 for the initial guess of next step
568- uState = self ._toNumpy (solver .state )
575+ current = solver .state
576+ state = [field .copy () for field in current ]
577+ uState = self ._toNumpy (state )
569578 uNext = self .model (uState )
570- self ._setStateWith (uNext , solver .state )
579+ np .clip (uNext [2 ], a_min = 0 , a_max = 1 , out = uNext [2 ]) # temporary : clip buoyancy between 0 and 1
580+ self ._setStateWith (uNext , state )
581+ solver .state = state
571582 tEval += dt
572583
573- # Evaluate and store LX with current state
574- self ._evalLX (LXk1 [m ])
575584 # Evaluate and store F(X, t) with current state
576585 self ._evalF (Fk1 [m ], tEval , dt , wall_time )
586+ # Evaluate and store LX with current state
587+ self ._evalLX (LXk1 [m ])
577588
578589 if self .initSweep == "NN" and k == self .nSweeps - 1 :
579590 # Reset state if it was used for NN initial guess
580- self . _setStateWith ( uState , solver .state )
591+ solver .state = current
581592
582593 # Inverse position for iterate k and k+1 in storage
583594 # ie making the new evaluation the old for next iteration
0 commit comments