
Commit 4f18b6b (1 parent: b85fd2d)

TL: finished MPI time-parallel SDC implementation, test OK

File tree

  • pySDC/playgrounds/dedalus

1 file changed: +119 -14 lines


pySDC/playgrounds/dedalus/sdc.py

Lines changed: 119 additions & 14 deletions
@@ -286,7 +286,7 @@ def leftIsNode(self):
     def doProlongation(self):
         return not self.rightIsNode or self.forceProl
 
-    def _computeMX0(self, MX0):
+    def _computeMX0(self, MX0:CoeffSystem):
         """
         Compute MX0 term used in RHS of both initStep and sweep methods
 
@@ -380,6 +380,7 @@ def _evalF(self, F, time, dt, wall_time):
         t0 = solver.sim_time
         solver.sim_time = time
         if self.firstEval:
+            # Eventually write field in file
             solver.evaluator.evaluate_scheduled(
                 wall_time=wall_time, timestep=dt, sim_time=time,
                 iteration=solver.iteration)
@@ -596,7 +597,7 @@ def _sweep(self, k):
                 axpy(a=-dt*qE[k, m, i], x=Fk[i].data, y=RHS.data)
                 axpy(a=dt*qI[k, m, i], x=LXk[i].data, y=RHS.data)
 
-            # Add LX terms from iteration k+1 and current nodes
+            # Add LX terms from iteration k and current nodes
             axpy(a=dt*qI[k, m, m], x=LXk[m].data, y=RHS.data)
 
         # Solve system and store node solution in solver state
@@ -717,6 +718,7 @@ def step(self, dt, wall_time):
         self.firstEval = True
         self.firstStep = False
 
+
 def initSpaceTimeMPI(nProcSpace=None, nProcTime=None, groupTime=False):
 
     gComm = MPI.COMM_WORLD
@@ -776,7 +778,7 @@ def initSpaceTimeMPI(nProcSpace=None, nProcTime=None, groupTime=False):
         sColor = (gRank - gRank % nProcSpace) / nProcSpace
         sComm = gComm.Split(sColor, gRank)
     gComm.Barrier()
-
+
     return gComm, sComm, tComm
 
 
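Note: the hunk above builds the space/time layout by splitting MPI.COMM_WORLD with color-based Comm.Split calls. The standalone sketch below only illustrates that splitting pattern outside of pySDC/Dedalus; the function name split_space_time and the demo values are invented for illustration and are not part of this commit.

    # Toy sketch (not repository code): split COMM_WORLD into space and time
    # sub-communicators the same way initSpaceTimeMPI does, assuming the world
    # size is an exact multiple of nProcSpace.
    from mpi4py import MPI

    def split_space_time(nProcSpace):
        gComm = MPI.COMM_WORLD
        gRank, gSize = gComm.Get_rank(), gComm.Get_size()
        assert gSize % nProcSpace == 0, "world size must be a multiple of nProcSpace"

        # Ranks with the same remainder modulo nProcSpace share a time communicator
        tComm = gComm.Split(gRank % nProcSpace, gRank)

        # Consecutive blocks of nProcSpace ranks share a space communicator
        sColor = (gRank - gRank % nProcSpace) // nProcSpace
        sComm = gComm.Split(sColor, gRank)

        gComm.Barrier()
        return gComm, sComm, tComm

    if __name__ == "__main__":
        # e.g. mpirun -n 8 python split_demo.py -> 4 space ranks x 2 time ranks
        gComm, sComm, tComm = split_space_time(nProcSpace=4)
        print(f"world rank {gComm.Get_rank()}: "
              f"space rank {sComm.Get_rank()}/{sComm.Get_size()}, "
              f"time rank {tComm.Get_rank()}/{tComm.Get_size()}")
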
@@ -788,14 +790,16 @@ class SDCIMEX_MPI(SpectralDeferredCorrectionIMEX):
     def initSpaceTimeComms(cls, nProcSpace=None, groupTime=False):
         gComm, sComm, cls.comm = initSpaceTimeMPI(nProcSpace, cls.getM(), groupTime)
         return gComm, sComm, cls.comm
-
+
     @property
     def rank(self):
         return self.comm.Get_rank()
 
     def __init__(self, solver):
 
         assert isinstance(self.comm, MPI.Intracomm), "comm is not a MPI communicator"
+        assert self.diagonal, "MPI parallelization works only with diagonal SDC"
+        assert not self.forceProl, "MPI parallelization not implemented with forceProl"
 
         # Store class attributes as instance attributes
         self.infos = self.getInfos()
@@ -814,10 +818,9 @@ def __init__(self, solver):
         # Attributes
         self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype)
         self.dt = None
-
-        self.firstEval = (self.rank == 0)
-        self.firstStep = True
 
+        self.firstEval = (self.rank == self.M-1)
+        self.firstStep = True
 
     def _updateLHS(self, dt, init=False):
         """Update LHS and LHS solvers for each subproblem
@@ -831,22 +834,124 @@ def _updateLHS(self, dt, init=False):
             subproblem. The default is False.
         """
         # Attribute references
-        qI = self.QDeltaI
+        m = self.rank
+        qI = self.QDeltaI[:, m, m]
         solver = self.solver
 
         # Update LHS and LHS solvers for each subproblems
         for sp in solver.subproblems:
             if init:
                 # Potentially instantiate list of solver (ony first time step)
-                sp.LHS_solvers = [[None for _ in range(self.M)] for _ in range(self.nSweeps)]
+                sp.LHS_solvers = [None for _ in range(self.nSweeps)]
             for k in range(self.nSweeps):
-                m = self.rank
                 if solver.store_expanded_matrices:
                     raise NotImplementedError("code correction required")
                     np.copyto(sp.LHS.data, sp.M_exp.data)
-                    self.axpy(a=dt*qI[k, m, m], x=sp.L_exp.data, y=sp.LHS.data)
+                    self.axpy(a=dt*qI[k], x=sp.L_exp.data, y=sp.LHS.data)
                 else:
-                    sp.LHS = (sp.M_min + dt*qI[k, m, m]*sp.L_min)
-                    sp.LHS_solvers[k][m] = solver.matsolver(sp.LHS, solver)
+                    sp.LHS = (sp.M_min + dt*qI[k]*sp.L_min)
+                    sp.LHS_solvers[k] = solver.matsolver(sp.LHS, solver)
         if self.initSweep == "QDELTA":
-            raise NotImplementedError()
+            raise NotImplementedError()
+
+    def _solveAndStoreState(self, k):
+        """
+        Solve LHS * X = RHS using the LHS associated to a given node,
+        and store X into the solver state.
+        It uses the current RHS attribute of the object.
+
+        Parameters
+        ----------
+        k : int
+            Sweep index (0 for the first sweep).
+        """
+        # Attribute references
+        solver = self.solver
+        RHS = self.RHS
+
+        self._presetStateCoeffSpace(solver.state)
+
+        # Solve and store for each subproblem
+        for sp in solver.subproblems:
+            # Slice out valid subdata, skipping invalid components
+            spRHS = RHS.get_subdata(sp)
+            spX = sp.LHS_solvers[k].solve(spRHS)  # CREATES TEMPORARY
+            sp.scatter_inputs(spX, solver.state)
+
+    def _computeMX0(self, MX0:CoeffSystem):
+        """
+        Compute MX0 term used in RHS of both initStep and sweep methods
+
+        Update the MX0 attribute of the timestepper object.
+        """
+        if self.rank == self.M-1:  # only last node compute MX0
+            super()._computeMX0(MX0)
+        # Broadcast MX0 to all nodes
+        self.comm.Bcast(MX0.data, root=self.M-1)
+
+    def _initSweep(self):
+        t0, dt, wall_time = self.solver.sim_time, self.dt, self.wall_time
+        Fk, LXk = self.F[0], self.LX[0]
+        if self.initSweep == 'COPY':
+            if self.rank == self.M-1:  # only last node evaluate
+                self._evalLX(LXk)
+                self._evalF(Fk, t0, dt, wall_time)
+            # Broadcast LXk and Fk to all nodes
+            self.comm.Bcast(LXk.data, root=self.M-1)
+            self.comm.Bcast(Fk.data, root=self.M-1)
+        else:
+            raise NotImplementedError()
+
+    def _sweep(self, k):
+        """Perform a sweep for the current time-step"""
+        # Only compute for the current node
+        m = self.rank
+
+        # Attribute references
+        tau, qI, q = self.nodes[m], self.QDeltaI[:, m, m], self.Q[:, m]
+        solver = self.solver
+        t0, dt, wall_time = solver.sim_time, self.dt, self.wall_time
+        RHS, MX0 = self.RHS, self.MX0
+        Fk, LXk, Fk1, LXk1 = self.F[0], self.LX[0], self.F[1], self.LX[1]
+        axpy = self.axpy
+
+        # Build RHS
+        if RHS.data.size:
+
+            # Initialize with MX0 term
+            np.copyto(RHS.data, MX0.data)
+
+            # Add quadrature terms using reduced sum accross nodes
+            recvBuf = np.zeros_like(RHS.data)
+            sendBuf = np.zeros_like(RHS.data)
+            for i in range(self.M-1, -1, -1):  # start from last node
+                sendBuf.fill(0)
+                axpy(a=dt*q[i], x=Fk.data, y=sendBuf)
+                axpy(a=-dt*q[i], x=LXk.data, y=sendBuf)
+                self.comm.Reduce(sendBuf, recvBuf, root=i, op=MPI.SUM)
+            RHS.data += recvBuf
+
+            # Add LX terms from iteration k and current nodes
+            axpy(a=dt*qI[k], x=LXk.data, y=RHS.data)
+
+        # Solve system and store node solution in solver state
+        self._solveAndStoreState(k)
+
+        if k < self.nSweeps-1:
+            tEval = t0+dt*tau
+            # Evaluate and store F(X, t) with current state
+            self._evalF(Fk1, tEval, dt, wall_time)
+            # Evaluate and store LX with current state
+            self._evalLX(LXk1)
+
+        # Inverse position for iterate k and k+1 in storage
+        # ie making the new evaluation the old for next iteration
+        self.F.rotate()
+        self.LX.rotate()
+
+    def step(self, dt, wall_time):
+        super().step(dt, wall_time)
+
+        # Only last rank (i.e node) will be allowed to (eventually) write outputs
+        if self.rank != self.M-1:
+            self.firstEval = False
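
Note: in the new SDCIMEX_MPI methods above, each time rank owns one collocation node; _computeMX0 and _initSweep broadcast the last node's data to every rank, and _sweep assembles the quadrature sum with one Reduce per target node. The toy script below reproduces only that communication pattern with scalar payloads; the node terms and the weight matrix q are invented for illustration and are not the commit's data structures.

    # Toy sketch (not repository code) of the Bcast/Reduce pattern in SDCIMEX_MPI:
    # run with one MPI rank per collocation node, e.g. mpirun -n 4 python demo.py
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    m, M = comm.Get_rank(), comm.Get_size()    # this rank owns node m out of M

    # Value computed on the last node only, then broadcast (like MX0, Fk, LXk)
    mx0 = np.array([1.0 * M]) if m == M - 1 else np.empty(1)
    comm.Bcast(mx0, root=M - 1)

    # This node's contribution to the quadrature terms (stand-in for dt*(F - LX))
    myTerm = np.array([float(m + 1)])

    # Invented quadrature weights: q[i, j] = weight of node j in the sum for node i
    q = np.arange(1.0, M * M + 1.0).reshape(M, M)

    # Reduced sum across nodes: after the loop, recvBuf on rank i holds
    # sum_j q[i, j] * term_j, because rank i is the reduction root only when i == m
    recvBuf = np.zeros(1)
    sendBuf = np.zeros(1)
    for i in range(M - 1, -1, -1):             # start from last node, as in _sweep
        sendBuf[:] = q[i, m] * myTerm
        comm.Reduce(sendBuf, recvBuf, root=i, op=MPI.SUM)

    rhs = mx0 + recvBuf                        # each rank keeps its own node's RHS
    print(f"node {m}: RHS = {rhs[0]}")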
