1111# Dedalus import
1212from dedalus .core .system import CoeffSystem
1313from dedalus .tools .array import apply_sparse , csr_matvecs
14+ from mpi4py import MPI
1415
1516# QMat imports
1617from qmat .qcoeff .collocation import Collocation
@@ -205,11 +206,7 @@ def setupNN(cls, nnType, nEval=1, initSweep="NN", modelIsCopy=False, **params):
205206 # -------------------------------------------------------------------------
206207 @classmethod
207208 def getMaxOrder (cls ):
208- # TODO : adapt to non-LEGENDRE node distributions
209- M = len (cls .nodes )
210- return 2 * M if cls .quadType == 'GAUSS' else \
211- 2 * M - 1 if cls .quadType .startswith ('RADAU' ) else \
212- 2 * M - 2 # LOBATTO
209+ return cls .coll .order
213210
214211 @classmethod
215212 def getInfos (cls ):
@@ -378,7 +375,6 @@ def _evalF(self, F, time, dt, wall_time):
378375 wall_time : float
379376 Current wall time.
380377 """
381-
382378 solver = self .solver
383379 # Evaluate non linear term on current state
384380 t0 = solver .sim_time
@@ -614,6 +610,7 @@ def _sweep(self, k):
614610 continue
615611
616612 tEval = t0 + dt * tau [m ]
613+
617614 # In case NN is used for initial guess (last sweep only)
618615 if self .initSweep == "NN" and k == (self .nSweeps - 1 ):
619616 # => evaluate current state with NN to be used
@@ -719,3 +716,137 @@ def step(self, dt, wall_time):
719716 self .solver .sim_time += dt
720717 self .firstEval = True
721718 self .firstStep = False
719+
720+ def initSpaceTimeMPI (nProcSpace = None , nProcTime = None , groupTime = False ):
721+
722+ gComm = MPI .COMM_WORLD
723+ gRank = gComm .Get_rank ()
724+ gSize = gComm .Get_size ()
725+
726+ if (nProcTime is None ) and (nProcSpace is None ):
727+ nProcTime = 1
728+ nProcSpace = gSize // nProcTime
729+ elif nProcSpace is None :
730+ nProcSpace = gSize // nProcTime
731+ elif nProcTime is None :
732+ nProcTime = gSize // nProcSpace
733+
734+ if gRank == 0 :
735+ print ("Starting space-time MPI initialization ..." )
736+
737+ # Check for inadequate decomposition
738+ if (gSize != nProcSpace * nProcTime ) and (gSize != 1 ):
739+ raise ValueError (f'product of nps ({ nProcSpace } ) with npt ({ nProcTime } ) is not '
740+ f'equal to the total number of processes ({ gSize } )' )
741+
742+ # Information message
743+ if gSize == 1 :
744+ print (" -- no parallelisation at all" )
745+ return gComm , None , None
746+ else :
747+ if nProcSpace != 1 :
748+ if gRank == 0 :
749+ print (" -- space parallelisation activated : {} mpi processes"
750+ .format (nProcSpace ))
751+ else :
752+ if gRank == 0 :
753+ print (" -- no space parallelisation" )
754+ if nProcTime != 1 :
755+ if gRank == 0 :
756+ print (" -- time parallelisation activated : {} mpi processes"
757+ .format (nProcTime ))
758+ else :
759+ if gRank == 0 :
760+ print (" -- no time parallelisation" )
761+ if gRank == 0 :
762+ print (' -- finished MPI initialization' )
763+
764+ # MPI decomposition -- space are close
765+ if groupTime :
766+ sColor = gRank % nProcTime
767+ sComm = gComm .Split (sColor , gRank )
768+ gComm .Barrier ()
769+ tColor = (gRank - gRank % nProcTime ) / nProcTime
770+ tComm = gComm .Split (tColor , gRank )
771+ gComm .Barrier ()
772+ else :
773+ tColor = gRank % nProcSpace
774+ tComm = gComm .Split (tColor , gRank )
775+ gComm .Barrier ()
776+ sColor = (gRank - gRank % nProcSpace ) / nProcSpace
777+ sComm = gComm .Split (sColor , gRank )
778+ gComm .Barrier ()
779+
780+ return gComm , sComm , tComm
781+
782+
783+ class SDCIMEX_MPI (SpectralDeferredCorrectionIMEX ):
784+
785+ comm :MPI .Intracomm = None
786+
787+ @classmethod
788+ def initSpaceTimeComms (cls , nProcSpace = None , groupTime = False ):
789+ gComm , sComm , cls .comm = initSpaceTimeMPI (nProcSpace , cls .getM (), groupTime )
790+ return gComm , sComm , cls .comm
791+
792+ @property
793+ def rank (self ):
794+ return self .comm .Get_rank ()
795+
796+ def __init__ (self , solver ):
797+
798+ assert isinstance (self .comm , MPI .Intracomm ), "comm is not a MPI communicator"
799+
800+ # Store class attributes as instance attributes
801+ self .infos = self .getInfos ()
802+
803+ # Store solver as attribute
804+ self .solver = solver
805+ self .subproblems = [sp for sp in solver .subproblems if sp .size ]
806+ self .stages = self .M # need this for solver.log_stats()
807+
808+ # Create coefficient systems for steps history
809+ c = lambda : CoeffSystem (solver .subproblems , dtype = solver .dtype )
810+ self .MX0 , self .RHS = c (), c ()
811+ self .LX = deque ([c () for _ in range (2 )])
812+ self .F = deque ([c () for _ in range (2 )])
813+
814+ # Attributes
815+ self .axpy = blas .get_blas_funcs ('axpy' , dtype = solver .dtype )
816+ self .dt = None
817+
818+ self .firstEval = (self .rank == 0 )
819+ self .firstStep = True
820+
821+
822+ def _updateLHS (self , dt , init = False ):
823+ """Update LHS and LHS solvers for each subproblem
824+
825+ Parameters
826+ ----------
827+ dt : float
828+ Time-step for the updated LHS.
829+ init : bool, optional
830+ Wether or not initialize the LHS_solvers attribute for each
831+ subproblem. The default is False.
832+ """
833+ # Attribute references
834+ qI = self .QDeltaI
835+ solver = self .solver
836+
837+ # Update LHS and LHS solvers for each subproblems
838+ for sp in solver .subproblems :
839+ if init :
840+ # Potentially instantiate list of solver (ony first time step)
841+ sp .LHS_solvers = [[None for _ in range (self .M )] for _ in range (self .nSweeps )]
842+ for k in range (self .nSweeps ):
843+ m = self .rank
844+ if solver .store_expanded_matrices :
845+ raise NotImplementedError ("code correction required" )
846+ np .copyto (sp .LHS .data , sp .M_exp .data )
847+ self .axpy (a = dt * qI [k , m , m ], x = sp .L_exp .data , y = sp .LHS .data )
848+ else :
849+ sp .LHS = (sp .M_min + dt * qI [k , m , m ]* sp .L_min )
850+ sp .LHS_solvers [k ][m ] = solver .matsolver (sp .LHS , solver )
851+ if self .initSweep == "QDELTA" :
852+ raise NotImplementedError ()
0 commit comments