Skip to content

Commit f54bb8f

Browse files
committed
TL: coupling with fnop
1 parent cefdf0e commit f54bb8f

File tree

1 file changed

+48
-4
lines changed
  • pySDC/playgrounds/dedalus

1 file changed

+48
-4
lines changed

pySDC/playgrounds/dedalus/sdc.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ class IMEXSDCCore(object):
120120
dt = None
121121
axpy = None
122122

123+
# For NN use to compute initial guess, etc ...
124+
model = None
125+
123126
@classmethod
124127
def setParameters(cls, nNodes=None, nodeType=None, quadType=None,
125128
implSweep=None, explSweep=None, initSweep=None,
@@ -185,6 +188,15 @@ def setParameters(cls, nNodes=None, nodeType=None, quadType=None,
185188
diagonal *= np.all(np.diag(np.diag(cls.QDelta0)) == cls.QDelta0)
186189
cls.diagonal = diagonal
187190

191+
@classmethod
192+
def setupNN(cls, nnType, **params):
193+
if nnType == "FNOP-C":
194+
from fnop.inference.inference import FNOInference as ModelClass
195+
elif nnType == "FNOP-T":
196+
from fnop.fno import FourierNeuralOp as ModelClass
197+
cls.model = ModelClass(**params)
198+
cls.initSweep = "NN"
199+
188200
# -------------------------------------------------------------------------
189201
# Class properties
190202
# -------------------------------------------------------------------------
@@ -252,6 +264,7 @@ def __init__(self, solver):
252264
self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype)
253265
self.dt = None
254266
self.firstEval = True
267+
self.firstStep = True
255268

256269
@property
257270
def M(self):
@@ -414,6 +427,22 @@ def _presetStateCoeffSpace(self, state):
414427
for field in state:
415428
field.preset_layout('c')
416429

430+
def _toNumpy(self, state):
431+
"""Extract from state fields a 3D numpy array containing ux, uz, b and p,
432+
to be given to a NN model."""
433+
return np.asarray(
434+
# ux , uz , b , p
435+
[state[2].data[0], state[2].data[1], state[1].data, state[0].data])
436+
437+
def _setStateWith(self, u, state):
438+
"""Write a 3D numpy array containing ux, uz, b and p into a dedalus state.
439+
Warning : state has to be in grid space"""
440+
np.copyto(state[2].data[0], u[0]) # ux
441+
np.copyto(state[2].data[1], u[1]) # uz
442+
np.copyto(state[1].data, u[2]) # b
443+
np.copyto(state[0].data, u[3]) # p
444+
445+
417446
def _initSweep(self):
418447
"""
419448
Initialize node terms for one given time-step
@@ -474,7 +503,7 @@ def _initSweep(self):
474503
# Evaluate and store F(X, t) with current state
475504
self._evalF(Fk[m], t0+dt*tau[m], dt, wall_time)
476505

477-
elif self.initSweep == 'COPY':
506+
elif self.initSweep == 'COPY' or (self.initSweep == "NN" and self.firstStep):
478507
self._evalLX(LXk[0])
479508
self._evalF(Fk[0], t0, dt, wall_time)
480509
for m in range(1, self.M):
@@ -525,16 +554,30 @@ def _sweep(self, k):
525554
self._solveAndStoreState(k, m)
526555

527556
# Avoid non necessary RHS evaluations work
528-
if not self.forceProl and k == self.nSweeps-1:
557+
if not self.forceProl and k == self.nSweeps-1 and self.initSweep != "NN":
529558
if self.diagonal:
530559
continue
531560
elif m == self.M-1:
532561
continue
533562

563+
tEval = t0+dt*tau[m]
564+
# In case NN is used for initial guess (last sweep only)
565+
if self.initSweep == "NN" and k == self.nSweeps-1:
566+
# => evaluate current state with NN to be used
567+
# for the tendencies at k=0 for the initial guess of next step
568+
uState = self._toNumpy(solver.state)
569+
uNext = self.model(uState)
570+
self._setStateWith(uNext, solver.state)
571+
tEval += dt
572+
534573
# Evaluate and store LX with current state
535574
self._evalLX(LXk1[m])
536575
# Evaluate and store F(X, t) with current state
537-
self._evalF(Fk1[m], t0+dt*tau[m], dt, wall_time)
576+
self._evalF(Fk1[m], tEval, dt, wall_time)
577+
578+
if self.initSweep == "NN" and k == self.nSweeps-1:
579+
# Reset state if it was used for NN initial guess
580+
self._setStateWith(uState, solver.state)
538581

539582
# Inverse position for iterate k and k+1 in storage
540583
# ie making the new evaluation the old for next iteration
@@ -611,6 +654,7 @@ def step(self, dt, wall_time):
611654
if self.doProlongation:
612655
self._prolongation()
613656

614-
# Update simulation time and reset evaluation tag
657+
# Update simulation time and update tags
615658
self.solver.sim_time += dt
616659
self.firstEval = True
660+
self.firstStep = False

0 commit comments

Comments
 (0)