Skip to content

Commit 9e67024

Browse files
committed
TL: allowed chained evaluation for FNO model
1 parent f6fa9df commit 9e67024

File tree

1 file changed

+3
-2
lines changed
  • pySDC/playgrounds/dedalus

1 file changed

+3
-2
lines changed

pySDC/playgrounds/dedalus/sdc.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,13 @@ def setParameters(cls, nNodes=None, nodeType=None, quadType=None,
189189
cls.diagonal = diagonal
190190

191191
@classmethod
192-
def setupNN(cls, nnType, **params):
192+
def setupNN(cls, nnType, nEval=1, **params):
193193
if nnType == "FNOP-1":
194194
from fnop.inference.inference import FNOInference as ModelClass
195195
elif nnType == "FNOP-2":
196196
from fnop.training.fno_pysdc import FourierNeuralOp as ModelClass
197197
cls.model = ModelClass(**params)
198+
cls.nModelEval = nEval
198199
cls.initSweep = "NN"
199200

200201
# -------------------------------------------------------------------------
@@ -581,7 +582,7 @@ def _sweep(self, k):
581582
for c, f in zip(current, state):
582583
np.copyto(f.data, c.data)
583584
uState = self._toNumpy(state)
584-
uNext = self.model(uState)
585+
uNext = self.model(uState, nEval=self.nModelEval)
585586
np.clip(uNext[2], a_min=0, a_max=1, out=uNext[2]) # temporary : clip buoyancy between 0 and 1
586587
self._setStateWith(uNext, state)
587588
solver.state = state

0 commit comments

Comments
 (0)