@@ -122,6 +122,7 @@ class IMEXSDCCore(object):
122122
123123 # For NN use to compute initial guess, etc ...
124124 model = None
125+ modelIsCopy = False
125126
126127 @classmethod
127128 def setParameters (cls , nNodes = None , nodeType = None , quadType = None ,
@@ -189,14 +190,15 @@ def setParameters(cls, nNodes=None, nodeType=None, quadType=None,
189190 cls .diagonal = diagonal
190191
191192 @classmethod
192- def setupNN (cls , nnType , nEval = 1 , ** params ):
193+ def setupNN (cls , nnType , nEval = 1 , initSweep = "NN" , modelIsCopy = False , ** params ):
193194 if nnType == "FNOP-1" :
194- from fnop .inference .inference import FNOInference as ModelClass
195+ from cfno .inference .inference import FNOInference as ModelClass
195196 elif nnType == "FNOP-2" :
196- from fnop .training .fno_pysdc import FourierNeuralOp as ModelClass
197+ from cfno .training .pySDC import FourierNeuralOp as ModelClass
197198 cls .model = ModelClass (** params )
198199 cls .nModelEval = nEval
199- cls .initSweep = "NN"
200+ cls .initSweep = initSweep
201+ cls .modelIsCopy = modelIsCopy
200202
201203 # -------------------------------------------------------------------------
202204 # Class properties
@@ -268,7 +270,7 @@ def __init__(self, solver):
268270 self .firstStep = True
269271
270272 # FNO state
271- if self .initSweep == "NN" :
273+ if self .initSweep . startswith ( "NN" ) :
272274 self .stateFNO = [field .copy () for field in self .solver .state ]
273275
274276 @property
@@ -436,18 +438,31 @@ def _toNumpy(self, state):
436438 """Extract from state fields a 3D numpy array containing ux, uz, b and p,
437439 to be given to a NN model."""
438440 for field in state :
441+ old_scales = field .scales [0 ]
439442 field .change_scales (1 )
440443 field .require_grid_space ()
441- return np .asarray (
444+ u = np .asarray (
442445 # ux , uz , b , p
443446 [state [2 ].data [0 ], state [2 ].data [1 ], state [1 ].data , state [0 ].data ])
447+ for field in state :
448+ field .require_coeff_space ()
449+ field .change_scales (old_scales )
450+ return u
451+
444452
445453 def _setStateWith (self , u , state ):
446454 """Write a 3D numpy array containing ux, uz, b and p into a dedalus state."""
455+ for field in state :
456+ old_scales = field .scales [0 ]
457+ field .change_scales (1 )
458+ field .require_grid_space ()
447459 np .copyto (state [2 ].data [0 ], u [0 ]) # ux
448460 np .copyto (state [2 ].data [1 ], u [1 ]) # uz
449461 np .copyto (state [1 ].data , u [2 ]) # b
450462 np .copyto (state [0 ].data , u [3 ]) # p
463+ for field in state :
464+ field .require_coeff_space ()
465+ field .change_scales (old_scales )
451466
452467
453468 def _initSweep (self ):
@@ -517,10 +532,36 @@ def _initSweep(self):
517532 np .copyto (LXk [m ].data , LXk [0 ].data )
518533 np .copyto (Fk [m ].data , Fk [0 ].data )
519534
520- elif self .initSweep == "NN" :
535+ elif self .initSweep in "NN" :
521536 # nothing to do, initialization of tendencies already done
522537 # during last sweep ...
523- pass
538+ self ._evalLX (self .LX [1 ][0 ])
539+ self ._evalF (self .F [1 ][0 ], t0 , dt , wall_time )
540+ print (f"NN, t={ t0 :1.2f} , firstEval : { self .firstEval } " )
541+
542+ elif self .initSweep == "NNI" :
543+ self ._evalLX (self .LX [1 ][0 ])
544+ self ._evalF (self .F [1 ][0 ], t0 , dt , wall_time )
545+ print (f"NNI, t={ t0 :1.2f} , firstEval : { self .firstEval } " )
546+
547+ current = solver .state
548+ state = self .stateFNO
549+
550+ # Evaluate FNO on current state
551+ for c , f in zip (current , state ):
552+ np .copyto (f .data , c .data )
553+ u0 = self ._toNumpy (state )
554+ u1 = self .model (u0 , nEval = self .nModelEval )
555+
556+ # Evaluate RHS with interpolation between current and FNO solution
557+ solver .state = state
558+ for m in range (self .M ):
559+ tEval = t0 + dt * tau [m ]
560+ self ._setStateWith (u0 + tau [m ]* (u1 - u0 ), state )
561+ self ._evalLX (LXk [m ])
562+ self ._evalF (Fk [m ], tEval , dt , wall_time )
563+
564+ solver .state = current
524565
525566 else :
526567 raise NotImplementedError (f'initSweep={ self .initSweep } ' )
@@ -574,16 +615,18 @@ def _sweep(self, k):
574615
575616 tEval = t0 + dt * tau [m ]
576617 # In case NN is used for initial guess (last sweep only)
577- if self .initSweep == "NN" and k == self .nSweeps - 1 :
618+ if self .initSweep == "NN" and k == ( self .nSweeps - 1 ) :
578619 # => evaluate current state with NN to be used
579620 # for the tendencies at k=0 for the initial guess of next step
580621 current = solver .state
581622 state = self .stateFNO
582623 for c , f in zip (current , state ):
583624 np .copyto (f .data , c .data )
584625 uState = self ._toNumpy (state )
585- uNext = self .model (uState , nEval = self .nModelEval )
586- np .clip (uNext [2 ], a_min = 0 , a_max = 1 , out = uNext [2 ]) # temporary : clip buoyancy between 0 and 1
626+ if self .modelIsCopy :
627+ uNext = uState
628+ else :
629+ uNext = self .model (uState , nEval = self .nModelEval )
587630 self ._setStateWith (uNext , state )
588631 solver .state = state
589632 tEval += dt
@@ -593,7 +636,7 @@ def _sweep(self, k):
593636 # Evaluate and store LX with current state
594637 self ._evalLX (LXk1 [m ])
595638
596- if self .initSweep == "NN" and k == self .nSweeps - 1 :
639+ if self .initSweep == "NN" and k == ( self .nSweeps - 1 ) :
597640 # Reset state if it was used for NN initial guess
598641 solver .state = current
599642
0 commit comments