
Commit b85fd2d

TL: started MPI time parallel SDC implementation
1 parent 6d64bd1 commit b85fd2d

File tree

3 files changed: +153 -102 lines

pySDC/playgrounds/dedalus/mpi.py

Lines changed: 0 additions & 96 deletions
This file was deleted.
Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+from time import sleep
+from pySDC.playgrounds.dedalus.sdc import initSpaceTimeMPI
+
+gComm, sComm, tComm = initSpaceTimeMPI(nProcTime=4)
+
+gRank = gComm.Get_rank()
+gSize = gComm.Get_size()
+
+sRank = sComm.Get_rank()
+sSize = sComm.Get_size()
+
+tRank = tComm.Get_rank()
+tSize = tComm.Get_size()
+
+sleep(gRank*0.01)
+print(f"Rank {gRank}/{gSize} : sRank {sRank}/{sSize}, tRank {tRank}/{tSize}")

pySDC/playgrounds/dedalus/sdc.py

Lines changed: 137 additions & 6 deletions
@@ -11,6 +11,7 @@
 # Dedalus import
 from dedalus.core.system import CoeffSystem
 from dedalus.tools.array import apply_sparse, csr_matvecs
+from mpi4py import MPI

 # QMat imports
 from qmat.qcoeff.collocation import Collocation
@@ -205,11 +206,7 @@ def setupNN(cls, nnType, nEval=1, initSweep="NN", modelIsCopy=False, **params):
     # -------------------------------------------------------------------------
     @classmethod
     def getMaxOrder(cls):
-        # TODO : adapt to non-LEGENDRE node distributions
-        M = len(cls.nodes)
-        return 2*M if cls.quadType == 'GAUSS' else \
-            2*M-1 if cls.quadType.startswith('RADAU') else \
-            2*M-2 # LOBATTO
+        return cls.coll.order

     @classmethod
     def getInfos(cls):
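
The rewritten getMaxOrder delegates to the order attribute of the underlying qmat collocation generator, which already accounts for the node distribution and quadrature type, so the TODO above becomes obsolete. A minimal sketch of the idea (the Collocation constructor arguments follow qmat's usual interface and are an assumption, not part of this commit):

    from qmat.qcoeff.collocation import Collocation

    # 4 Radau-right Legendre nodes : quadrature order 2*M-1 = 7,
    # matching the hand-coded formula removed above
    coll = Collocation(nNodes=4, nodeType="LEGENDRE", quadType="RADAU-RIGHT")
    print(coll.order)  # expected : 7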
@@ -378,7 +375,6 @@ def _evalF(self, F, time, dt, wall_time):
         wall_time : float
             Current wall time.
         """
-
         solver = self.solver
         # Evaluate non linear term on current state
         t0 = solver.sim_time
@@ -614,6 +610,7 @@ def _sweep(self, k):
                 continue

             tEval = t0+dt*tau[m]
+
             # In case NN is used for initial guess (last sweep only)
             if self.initSweep == "NN" and k == (self.nSweeps-1):
                 # => evaluate current state with NN to be used
@@ -719,3 +716,137 @@ def step(self, dt, wall_time):
         self.solver.sim_time += dt
         self.firstEval = True
         self.firstStep = False
+
+
+def initSpaceTimeMPI(nProcSpace=None, nProcTime=None, groupTime=False):
+
+    gComm = MPI.COMM_WORLD
+    gRank = gComm.Get_rank()
+    gSize = gComm.Get_size()
+
+    # Deduce any unspecified direction from the global communicator size
+    if (nProcTime is None) and (nProcSpace is None):
+        nProcTime = 1
+        nProcSpace = gSize // nProcTime
+    elif nProcSpace is None:
+        nProcSpace = gSize // nProcTime
+    elif nProcTime is None:
+        nProcTime = gSize // nProcSpace
+
+    if gRank == 0:
+        print("Starting space-time MPI initialization ...")
+
+    # Check for inadequate decomposition
+    if (gSize != nProcSpace*nProcTime) and (gSize != 1):
+        raise ValueError(
+            f"product of nProcSpace ({nProcSpace}) with nProcTime ({nProcTime}) "
+            f"is not equal to the total number of processes ({gSize})")
+
+    # Information message
+    if gSize == 1:
+        print(" -- no parallelisation at all")
+        return gComm, None, None
+    if gRank == 0:
+        if nProcSpace != 1:
+            print(f" -- space parallelisation activated : {nProcSpace} mpi processes")
+        else:
+            print(" -- no space parallelisation")
+        if nProcTime != 1:
+            print(f" -- time parallelisation activated : {nProcTime} mpi processes")
+        else:
+            print(" -- no time parallelisation")
+        print(" -- finished MPI initialization")
+
+    # MPI decomposition : by default, ranks close in the global communicator
+    # share the same space communicator; with groupTime=True, close ranks
+    # share the same time communicator instead
+    if groupTime:
+        sColor = gRank % nProcTime
+        sComm = gComm.Split(sColor, gRank)
+        gComm.Barrier()
+        # Integer division : Split expects an integer color
+        tColor = (gRank - gRank % nProcTime) // nProcTime
+        tComm = gComm.Split(tColor, gRank)
+        gComm.Barrier()
+    else:
+        tColor = gRank % nProcSpace
+        tComm = gComm.Split(tColor, gRank)
+        gComm.Barrier()
+        # Integer division : Split expects an integer color
+        sColor = (gRank - gRank % nProcSpace) // nProcSpace
+        sComm = gComm.Split(sColor, gRank)
+        gComm.Barrier()
+
+    return gComm, sComm, tComm
+
+
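Aside: the Split colors above can be checked without launching MPI. A minimal sketch for the default groupTime=False, with 8 ranks decomposed as nProcSpace=2, nProcTime=4 (the setup of the test script above):

    # Reproduce the color arithmetic of initSpaceTimeMPI on paper :
    # ranks sharing a color end up in the same sub-communicator
    nProcSpace, nProcTime, gSize = 2, 4, 8
    for gRank in range(gSize):
        tColor = gRank % nProcSpace                          # time comms : {0,2,4,6}, {1,3,5,7}
        sColor = (gRank - gRank % nProcSpace) // nProcSpace  # space comms : {0,1}, {2,3}, {4,5}, {6,7}
        print(gRank, sColor, tColor)
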
+class SDCIMEX_MPI(SpectralDeferredCorrectionIMEX):
+
+    # Time communicator shared by all instances : one rank per collocation node
+    comm: MPI.Intracomm = None
+
+    @classmethod
+    def initSpaceTimeComms(cls, nProcSpace=None, groupTime=False):
+        gComm, sComm, cls.comm = initSpaceTimeMPI(nProcSpace, cls.getM(), groupTime)
+        return gComm, sComm, cls.comm
+
+    @property
+    def rank(self):
+        return self.comm.Get_rank()
+
+    def __init__(self, solver):
+
+        assert isinstance(self.comm, MPI.Intracomm), "comm is not an MPI communicator"
+
+        # Store class attributes as instance attributes
+        self.infos = self.getInfos()
+
+        # Store solver as attribute
+        self.solver = solver
+        self.subproblems = [sp for sp in solver.subproblems if sp.size]
+        self.stages = self.M  # need this for solver.log_stats()
+
+        # Create coefficient systems for steps history
+        c = lambda: CoeffSystem(solver.subproblems, dtype=solver.dtype)
+        self.MX0, self.RHS = c(), c()
+        self.LX = deque([c() for _ in range(2)])
+        self.F = deque([c() for _ in range(2)])
+
+        # Attributes
+        self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype)
+        self.dt = None
+
+        # firstEval is True only on the first time rank
+        self.firstEval = (self.rank == 0)
+        self.firstStep = True
+
+
+    def _updateLHS(self, dt, init=False):
+        """Update LHS and LHS solvers for each subproblem
+
+        Parameters
+        ----------
+        dt : float
+            Time-step for the updated LHS.
+        init : bool, optional
+            Whether or not to initialize the LHS_solvers attribute for each
+            subproblem. The default is False.
+        """
+        # Attribute references
+        qI = self.QDeltaI
+        solver = self.solver
+
+        # Update LHS and LHS solvers for each subproblem : each time rank only
+        # builds the solver for its own collocation node (m = self.rank)
+        for sp in solver.subproblems:
+            if init:
+                # Potentially instantiate list of solvers (only first time step)
+                sp.LHS_solvers = [[None for _ in range(self.M)] for _ in range(self.nSweeps)]
+            for k in range(self.nSweeps):
+                m = self.rank
+                if solver.store_expanded_matrices:
+                    raise NotImplementedError("code correction required")
+                    np.copyto(sp.LHS.data, sp.M_exp.data)
+                    self.axpy(a=dt*qI[k, m, m], x=sp.L_exp.data, y=sp.LHS.data)
+                else:
+                    sp.LHS = (sp.M_min + dt*qI[k, m, m]*sp.L_min)
+                sp.LHS_solvers[k][m] = solver.matsolver(sp.LHS, solver)
+        if self.initSweep == "QDELTA":
+            raise NotImplementedError()
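
In _updateLHS, each time rank assembles and factorizes only the LHS attached to its own collocation node (m = self.rank, using the diagonal coefficient qI[k, m, m] of QDeltaI), which is what makes the sweep parallel across the nodes. A hypothetical usage sketch combining the pieces above (how the space communicator is passed to the Dedalus problem is an assumption, not something this commit shows):

    from pySDC.playgrounds.dedalus.sdc import SDCIMEX_MPI

    # One time rank per collocation node : initSpaceTimeComms calls
    # initSpaceTimeMPI with nProcTime = cls.getM()
    gComm, sComm, tComm = SDCIMEX_MPI.initSpaceTimeComms(nProcSpace=2)

    # Build the Dedalus problem on sComm (assumption), then e.g. :
    # solver = problem.build_solver(SDCIMEX_MPI)
    # while solver.proceed:
    #     solver.step(dt)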
