@@ -120,6 +120,9 @@ class IMEXSDCCore(object):
120120 dt = None
121121 axpy = None
122122
123+ # For NN use to compute initial guess, etc ...
124+ model = None
125+
123126 @classmethod
124127 def setParameters (cls , nNodes = None , nodeType = None , quadType = None ,
125128 implSweep = None , explSweep = None , initSweep = None ,
@@ -185,6 +188,15 @@ def setParameters(cls, nNodes=None, nodeType=None, quadType=None,
185188 diagonal *= np .all (np .diag (np .diag (cls .QDelta0 )) == cls .QDelta0 )
186189 cls .diagonal = diagonal
187190
191+ @classmethod
192+ def setupNN (cls , nnType , ** params ):
193+ if nnType == "FNOP-C" :
194+ from fnop .inference .inference import FNOInference as ModelClass
195+ elif nnType == "FNOP-T" :
196+ from fnop .fno import FourierNeuralOp as ModelClass
197+ cls .model = ModelClass (** params )
198+ cls .initSweep = "NN"
199+
188200 # -------------------------------------------------------------------------
189201 # Class properties
190202 # -------------------------------------------------------------------------
@@ -252,6 +264,7 @@ def __init__(self, solver):
252264 self .axpy = blas .get_blas_funcs ('axpy' , dtype = solver .dtype )
253265 self .dt = None
254266 self .firstEval = True
267+ self .firstStep = True
255268
256269 @property
257270 def M (self ):
@@ -414,6 +427,22 @@ def _presetStateCoeffSpace(self, state):
414427 for field in state :
415428 field .preset_layout ('c' )
416429
430+ def _toNumpy (self , state ):
431+ """Extract from state fields a 3D numpy array containing ux, uz, b and p,
432+ to be given to a NN model."""
433+ return np .asarray (
434+ # ux , uz , b , p
435+ [state [2 ].data [0 ], state [2 ].data [1 ], state [1 ].data , state [0 ].data ])
436+
437+ def _setStateWith (self , u , state ):
438+ """Write a 3D numpy array containing ux, uz, b and p into a dedalus state.
439+ Warning : state has to be in grid space"""
440+ np .copyto (state [2 ].data [0 ], u [0 ]) # ux
441+ np .copyto (state [2 ].data [1 ], u [1 ]) # uz
442+ np .copyto (state [1 ].data , u [2 ]) # b
443+ np .copyto (state [0 ].data , u [3 ]) # p
444+
445+
417446 def _initSweep (self ):
418447 """
419448 Initialize node terms for one given time-step
@@ -474,7 +503,7 @@ def _initSweep(self):
474503 # Evaluate and store F(X, t) with current state
475504 self ._evalF (Fk [m ], t0 + dt * tau [m ], dt , wall_time )
476505
477- elif self .initSweep == 'COPY' :
506+ elif self .initSweep == 'COPY' or ( self . initSweep == "NN" and self . firstStep ) :
478507 self ._evalLX (LXk [0 ])
479508 self ._evalF (Fk [0 ], t0 , dt , wall_time )
480509 for m in range (1 , self .M ):
@@ -525,16 +554,30 @@ def _sweep(self, k):
525554 self ._solveAndStoreState (k , m )
526555
527556 # Avoid non necessary RHS evaluations work
528- if not self .forceProl and k == self .nSweeps - 1 :
557+ if not self .forceProl and k == self .nSweeps - 1 and self . initSweep != "NN" :
529558 if self .diagonal :
530559 continue
531560 elif m == self .M - 1 :
532561 continue
533562
563+ tEval = t0 + dt * tau [m ]
564+ # In case NN is used for initial guess (last sweep only)
565+ if self .initSweep == "NN" and k == self .nSweeps - 1 :
566+ # => evaluate current state with NN to be used
567+ # for the tendencies at k=0 for the initial guess of next step
568+ uState = self ._toNumpy (solver .state )
569+ uNext = self .model (uState )
570+ self ._setStateWith (uNext , solver .state )
571+ tEval += dt
572+
534573 # Evaluate and store LX with current state
535574 self ._evalLX (LXk1 [m ])
536575 # Evaluate and store F(X, t) with current state
537- self ._evalF (Fk1 [m ], t0 + dt * tau [m ], dt , wall_time )
576+ self ._evalF (Fk1 [m ], tEval , dt , wall_time )
577+
578+ if self .initSweep == "NN" and k == self .nSweeps - 1 :
579+ # Reset state if it was used for NN initial guess
580+ self ._setStateWith (uState , solver .state )
538581
539582 # Inverse position for iterate k and k+1 in storage
540583 # ie making the new evaluation the old for next iteration
@@ -611,6 +654,7 @@ def step(self, dt, wall_time):
611654 if self .doProlongation :
612655 self ._prolongation ()
613656
614- # Update simulation time and reset evaluation tag
657+ # Update simulation time and update tags
615658 self .solver .sim_time += dt
616659 self .firstEval = True
660+ self .firstStep = False
0 commit comments