Skip to content

Commit 9d3ef1d

Browse files
committed
TL: update on sdc solver
1 parent 9e67024 commit 9d3ef1d

File tree

1 file changed

+55
-12
lines changed
  • pySDC/playgrounds/dedalus

1 file changed

+55
-12
lines changed

pySDC/playgrounds/dedalus/sdc.py

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class IMEXSDCCore(object):
122122

123123
# For NN use to compute initial guess, etc ...
124124
model = None
125+
modelIsCopy = False
125126

126127
@classmethod
127128
def setParameters(cls, nNodes=None, nodeType=None, quadType=None,
@@ -189,14 +190,15 @@ def setParameters(cls, nNodes=None, nodeType=None, quadType=None,
189190
cls.diagonal = diagonal
190191

191192
@classmethod
192-
def setupNN(cls, nnType, nEval=1, **params):
193+
def setupNN(cls, nnType, nEval=1, initSweep="NN", modelIsCopy=False, **params):
193194
if nnType == "FNOP-1":
194-
from fnop.inference.inference import FNOInference as ModelClass
195+
from cfno.inference.inference import FNOInference as ModelClass
195196
elif nnType == "FNOP-2":
196-
from fnop.training.fno_pysdc import FourierNeuralOp as ModelClass
197+
from cfno.training.pySDC import FourierNeuralOp as ModelClass
197198
cls.model = ModelClass(**params)
198199
cls.nModelEval = nEval
199-
cls.initSweep = "NN"
200+
cls.initSweep = initSweep
201+
cls.modelIsCopy = modelIsCopy
200202

201203
# -------------------------------------------------------------------------
202204
# Class properties
@@ -268,7 +270,7 @@ def __init__(self, solver):
268270
self.firstStep = True
269271

270272
# FNO state
271-
if self.initSweep == "NN":
273+
if self.initSweep.startswith("NN"):
272274
self.stateFNO = [field.copy() for field in self.solver.state]
273275

274276
@property
@@ -436,18 +438,31 @@ def _toNumpy(self, state):
436438
"""Extract from state fields a 3D numpy array containing ux, uz, b and p,
437439
to be given to a NN model."""
438440
for field in state:
441+
old_scales = field.scales[0]
439442
field.change_scales(1)
440443
field.require_grid_space()
441-
return np.asarray(
444+
u = np.asarray(
442445
# ux , uz , b , p
443446
[state[2].data[0], state[2].data[1], state[1].data, state[0].data])
447+
for field in state:
448+
field.require_coeff_space()
449+
field.change_scales(old_scales)
450+
return u
451+
444452

445453
def _setStateWith(self, u, state):
446454
"""Write a 3D numpy array containing ux, uz, b and p into a dedalus state."""
455+
for field in state:
456+
old_scales = field.scales[0]
457+
field.change_scales(1)
458+
field.require_grid_space()
447459
np.copyto(state[2].data[0], u[0]) # ux
448460
np.copyto(state[2].data[1], u[1]) # uz
449461
np.copyto(state[1].data, u[2]) # b
450462
np.copyto(state[0].data, u[3]) # p
463+
for field in state:
464+
field.require_coeff_space()
465+
field.change_scales(old_scales)
451466

452467

453468
def _initSweep(self):
@@ -517,10 +532,36 @@ def _initSweep(self):
517532
np.copyto(LXk[m].data, LXk[0].data)
518533
np.copyto(Fk[m].data, Fk[0].data)
519534

520-
elif self.initSweep == "NN":
535+
elif self.initSweep in "NN":
521536
# nothing to do, initialization of tendencies already done
522537
# during last sweep ...
523-
pass
538+
self._evalLX(self.LX[1][0])
539+
self._evalF(self.F[1][0], t0, dt, wall_time)
540+
print(f"NN, t={t0:1.2f}, firstEval : {self.firstEval}")
541+
542+
elif self.initSweep == "NNI":
543+
self._evalLX(self.LX[1][0])
544+
self._evalF(self.F[1][0], t0, dt, wall_time)
545+
print(f"NNI, t={t0:1.2f}, firstEval : {self.firstEval}")
546+
547+
current = solver.state
548+
state = self.stateFNO
549+
550+
# Evaluate FNO on current state
551+
for c, f in zip(current, state):
552+
np.copyto(f.data, c.data)
553+
u0 = self._toNumpy(state)
554+
u1 = self.model(u0, nEval=self.nModelEval)
555+
556+
# Evaluate RHS with interpolation between current and FNO solution
557+
solver.state = state
558+
for m in range(self.M):
559+
tEval = t0 + dt*tau[m]
560+
self._setStateWith(u0 + tau[m]*(u1-u0), state)
561+
self._evalLX(LXk[m])
562+
self._evalF(Fk[m], tEval, dt, wall_time)
563+
564+
solver.state = current
524565

525566
else:
526567
raise NotImplementedError(f'initSweep={self.initSweep}')
@@ -574,16 +615,18 @@ def _sweep(self, k):
574615

575616
tEval = t0+dt*tau[m]
576617
# In case NN is used for initial guess (last sweep only)
577-
if self.initSweep == "NN" and k == self.nSweeps-1:
618+
if self.initSweep == "NN" and k == (self.nSweeps-1):
578619
# => evaluate current state with NN to be used
579620
# for the tendencies at k=0 for the initial guess of next step
580621
current = solver.state
581622
state = self.stateFNO
582623
for c, f in zip(current, state):
583624
np.copyto(f.data, c.data)
584625
uState = self._toNumpy(state)
585-
uNext = self.model(uState, nEval=self.nModelEval)
586-
np.clip(uNext[2], a_min=0, a_max=1, out=uNext[2]) # temporary : clip buoyancy between 0 and 1
626+
if self.modelIsCopy:
627+
uNext = uState
628+
else:
629+
uNext = self.model(uState, nEval=self.nModelEval)
587630
self._setStateWith(uNext, state)
588631
solver.state = state
589632
tEval += dt
@@ -593,7 +636,7 @@ def _sweep(self, k):
593636
# Evaluate and store LX with current state
594637
self._evalLX(LXk1[m])
595638

596-
if self.initSweep == "NN" and k == self.nSweeps-1:
639+
if self.initSweep == "NN" and k == (self.nSweeps-1):
597640
# Reset state if it was used for NN initial guess
598641
solver.state = current
599642

0 commit comments

Comments
 (0)