Skip to content

Commit aacb5d4

Browse files
committed
TL: adapted to last changes on fnop + some debug
1 parent fdf2a3c commit aacb5d4

File tree

1 file changed

+19
-8
lines changed
  • pySDC/playgrounds/dedalus

1 file changed

+19
-8
lines changed

pySDC/playgrounds/dedalus/sdc.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def setupNN(cls, nnType, **params):
193193
if nnType == "FNOP-1":
194194
from fnop.inference.inference import FNOInference as ModelClass
195195
elif nnType == "FNOP-2":
196-
from fnop.fno import FourierNeuralOp as ModelClass
196+
from fnop.training.fno_pysdc import FourierNeuralOp as ModelClass
197197
cls.model = ModelClass(**params)
198198
cls.initSweep = "NN"
199199

@@ -430,13 +430,15 @@ def _presetStateCoeffSpace(self, state):
430430
def _toNumpy(self, state):
431431
"""Extract from state fields a 3D numpy array containing ux, uz, b and p,
432432
to be given to a NN model."""
433+
for field in state:
434+
field.change_scales(1)
435+
field.require_grid_space()
433436
return np.asarray(
434437
# ux , uz , b , p
435438
[state[2].data[0], state[2].data[1], state[1].data, state[0].data])
436439

437440
def _setStateWith(self, u, state):
438-
"""Write a 3D numpy array containing ux, uz, b and p into a dedalus state.
439-
Warning : state has to be in grid space"""
441+
"""Write a 3D numpy array containing ux, uz, b and p into a dedalus state."""
440442
np.copyto(state[2].data[0], u[0]) # ux
441443
np.copyto(state[2].data[1], u[1]) # uz
442444
np.copyto(state[1].data, u[2]) # b
@@ -510,6 +512,11 @@ def _initSweep(self):
510512
np.copyto(LXk[m].data, LXk[0].data)
511513
np.copyto(Fk[m].data, Fk[0].data)
512514

515+
elif self.initSweep == "NN":
516+
# nothing to do, initialization of tendencies already done
517+
# during last sweep ...
518+
pass
519+
513520
else:
514521
raise NotImplementedError(f'initSweep={self.initSweep}')
515522

@@ -565,19 +572,23 @@ def _sweep(self, k):
565572
if self.initSweep == "NN" and k == self.nSweeps-1:
566573
# => evaluate current state with NN to be used
567574
# for the tendencies at k=0 for the initial guess of next step
568-
uState = self._toNumpy(solver.state)
575+
current = solver.state
576+
state = [field.copy() for field in current]
577+
uState = self._toNumpy(state)
569578
uNext = self.model(uState)
570-
self._setStateWith(uNext, solver.state)
579+
np.clip(uNext[2], a_min=0, a_max=1, out=uNext[2]) # temporary : clip buoyancy between 0 and 1
580+
self._setStateWith(uNext, state)
581+
solver.state = state
571582
tEval += dt
572583

573-
# Evaluate and store LX with current state
574-
self._evalLX(LXk1[m])
575584
# Evaluate and store F(X, t) with current state
576585
self._evalF(Fk1[m], tEval, dt, wall_time)
586+
# Evaluate and store LX with current state
587+
self._evalLX(LXk1[m])
577588

578589
if self.initSweep == "NN" and k == self.nSweeps-1:
579590
# Reset state if it was used for NN initial guess
580-
self._setStateWith(uState, solver.state)
591+
solver.state = current
581592

582593
# Inverse position for iterate k and k+1 in storage
583594
# ie making the new evaluation the old for next iteration

0 commit comments

Comments
 (0)