@@ -286,7 +286,7 @@ def leftIsNode(self):
286286 def doProlongation (self ):
287287 return not self .rightIsNode or self .forceProl
288288
289- def _computeMX0 (self , MX0 ):
289+ def _computeMX0 (self , MX0 : CoeffSystem ):
290290 """
291291 Compute MX0 term used in RHS of both initStep and sweep methods
292292
@@ -380,6 +380,7 @@ def _evalF(self, F, time, dt, wall_time):
380380 t0 = solver .sim_time
381381 solver .sim_time = time
382382 if self .firstEval :
383+ # Eventually write field in file
383384 solver .evaluator .evaluate_scheduled (
384385 wall_time = wall_time , timestep = dt , sim_time = time ,
385386 iteration = solver .iteration )
@@ -596,7 +597,7 @@ def _sweep(self, k):
596597 axpy (a = - dt * qE [k , m , i ], x = Fk [i ].data , y = RHS .data )
597598 axpy (a = dt * qI [k , m , i ], x = LXk [i ].data , y = RHS .data )
598599
599- # Add LX terms from iteration k+1 and current nodes
600+ # Add LX terms from iteration k and current nodes
600601 axpy (a = dt * qI [k , m , m ], x = LXk [m ].data , y = RHS .data )
601602
602603 # Solve system and store node solution in solver state
@@ -717,6 +718,7 @@ def step(self, dt, wall_time):
717718 self .firstEval = True
718719 self .firstStep = False
719720
721+
720722def initSpaceTimeMPI (nProcSpace = None , nProcTime = None , groupTime = False ):
721723
722724 gComm = MPI .COMM_WORLD
@@ -776,7 +778,7 @@ def initSpaceTimeMPI(nProcSpace=None, nProcTime=None, groupTime=False):
776778 sColor = (gRank - gRank % nProcSpace ) / nProcSpace
777779 sComm = gComm .Split (sColor , gRank )
778780 gComm .Barrier ()
779-
781+
780782 return gComm , sComm , tComm
781783
782784
@@ -788,14 +790,16 @@ class SDCIMEX_MPI(SpectralDeferredCorrectionIMEX):
788790 def initSpaceTimeComms (cls , nProcSpace = None , groupTime = False ):
789791 gComm , sComm , cls .comm = initSpaceTimeMPI (nProcSpace , cls .getM (), groupTime )
790792 return gComm , sComm , cls .comm
791-
793+
792794 @property
793795 def rank (self ):
794796 return self .comm .Get_rank ()
795797
796798 def __init__ (self , solver ):
797799
798800 assert isinstance (self .comm , MPI .Intracomm ), "comm is not a MPI communicator"
801+ assert self .diagonal , "MPI parallelization works only with diagonal SDC"
802+ assert not self .forceProl , "MPI parallelization not implemented with forceProl"
799803
800804 # Store class attributes as instance attributes
801805 self .infos = self .getInfos ()
@@ -814,10 +818,9 @@ def __init__(self, solver):
814818 # Attributes
815819 self .axpy = blas .get_blas_funcs ('axpy' , dtype = solver .dtype )
816820 self .dt = None
817-
818- self .firstEval = (self .rank == 0 )
819- self .firstStep = True
820821
822+ self .firstEval = (self .rank == self .M - 1 )
823+ self .firstStep = True
821824
822825 def _updateLHS (self , dt , init = False ):
823826 """Update LHS and LHS solvers for each subproblem
@@ -831,22 +834,124 @@ def _updateLHS(self, dt, init=False):
831834 subproblem. The default is False.
832835 """
833836 # Attribute references
834- qI = self .QDeltaI
837+ m = self .rank
838+ qI = self .QDeltaI [:, m , m ]
835839 solver = self .solver
836840
837841 # Update LHS and LHS solvers for each subproblems
838842 for sp in solver .subproblems :
839843 if init :
840844 # Potentially instantiate list of solver (ony first time step)
841- sp .LHS_solvers = [[ None for _ in range ( self . M )] for _ in range (self .nSweeps )]
845+ sp .LHS_solvers = [None for _ in range (self .nSweeps )]
842846 for k in range (self .nSweeps ):
843- m = self .rank
844847 if solver .store_expanded_matrices :
845848 raise NotImplementedError ("code correction required" )
846849 np .copyto (sp .LHS .data , sp .M_exp .data )
847- self .axpy (a = dt * qI [k , m , m ], x = sp .L_exp .data , y = sp .LHS .data )
850+ self .axpy (a = dt * qI [k ], x = sp .L_exp .data , y = sp .LHS .data )
848851 else :
849- sp .LHS = (sp .M_min + dt * qI [k , m , m ]* sp .L_min )
850- sp .LHS_solvers [k ][ m ] = solver .matsolver (sp .LHS , solver )
852+ sp .LHS = (sp .M_min + dt * qI [k ]* sp .L_min )
853+ sp .LHS_solvers [k ] = solver .matsolver (sp .LHS , solver )
851854 if self .initSweep == "QDELTA" :
852- raise NotImplementedError ()
855+ raise NotImplementedError ()
856+
857+ def _solveAndStoreState (self , k ):
858+ """
859+ Solve LHS * X = RHS using the LHS associated to a given node,
860+ and store X into the solver state.
861+ It uses the current RHS attribute of the object.
862+
863+ Parameters
864+ ----------
865+ k : int
866+ Sweep index (0 for the first sweep).
867+ """
868+ # Attribute references
869+ solver = self .solver
870+ RHS = self .RHS
871+
872+ self ._presetStateCoeffSpace (solver .state )
873+
874+ # Solve and store for each subproblem
875+ for sp in solver .subproblems :
876+ # Slice out valid subdata, skipping invalid components
877+ spRHS = RHS .get_subdata (sp )
878+ spX = sp .LHS_solvers [k ].solve (spRHS ) # CREATES TEMPORARY
879+ sp .scatter_inputs (spX , solver .state )
880+
881+ def _computeMX0 (self , MX0 :CoeffSystem ):
882+ """
883+ Compute MX0 term used in RHS of both initStep and sweep methods
884+
885+ Update the MX0 attribute of the timestepper object.
886+ """
887+ if self .rank == self .M - 1 : # only last node compute MX0
888+ super ()._computeMX0 (MX0 )
889+ # Broadcast MX0 to all nodes
890+ self .comm .Bcast (MX0 .data , root = self .M - 1 )
891+
892+ def _initSweep (self ):
893+ t0 , dt , wall_time = self .solver .sim_time , self .dt , self .wall_time
894+ Fk , LXk = self .F [0 ], self .LX [0 ]
895+ if self .initSweep == 'COPY' :
896+ if self .rank == self .M - 1 : # only last node evaluate
897+ self ._evalLX (LXk )
898+ self ._evalF (Fk , t0 , dt , wall_time )
899+ # Broadcast LXk and Fk to all nodes
900+ self .comm .Bcast (LXk .data , root = self .M - 1 )
901+ self .comm .Bcast (Fk .data , root = self .M - 1 )
902+ else :
903+ raise NotImplementedError ()
904+
905+ def _sweep (self , k ):
906+ """Perform a sweep for the current time-step"""
907+ # Only compute for the current node
908+ m = self .rank
909+
910+ # Attribute references
911+ tau , qI , q = self .nodes [m ], self .QDeltaI [:, m , m ], self .Q [:, m ]
912+ solver = self .solver
913+ t0 , dt , wall_time = solver .sim_time , self .dt , self .wall_time
914+ RHS , MX0 = self .RHS , self .MX0
915+ Fk , LXk , Fk1 , LXk1 = self .F [0 ], self .LX [0 ], self .F [1 ], self .LX [1 ]
916+ axpy = self .axpy
917+
918+ # Build RHS
919+ if RHS .data .size :
920+
921+ # Initialize with MX0 term
922+ np .copyto (RHS .data , MX0 .data )
923+
924+ # Add quadrature terms using reduced sum accross nodes
925+ recvBuf = np .zeros_like (RHS .data )
926+ sendBuf = np .zeros_like (RHS .data )
927+ for i in range (self .M - 1 , - 1 , - 1 ): # start from last node
928+ sendBuf .fill (0 )
929+ axpy (a = dt * q [i ], x = Fk .data , y = sendBuf )
930+ axpy (a = - dt * q [i ], x = LXk .data , y = sendBuf )
931+ self .comm .Reduce (sendBuf , recvBuf , root = i , op = MPI .SUM )
932+ RHS .data += recvBuf
933+
934+ # Add LX terms from iteration k and current nodes
935+ axpy (a = dt * qI [k ], x = LXk .data , y = RHS .data )
936+
937+ # Solve system and store node solution in solver state
938+ self ._solveAndStoreState (k )
939+
940+ if k < self .nSweeps - 1 :
941+ tEval = t0 + dt * tau
942+ # Evaluate and store F(X, t) with current state
943+ self ._evalF (Fk1 , tEval , dt , wall_time )
944+ # Evaluate and store LX with current state
945+ self ._evalLX (LXk1 )
946+
947+ # Inverse position for iterate k and k+1 in storage
948+ # ie making the new evaluation the old for next iteration
949+ self .F .rotate ()
950+ self .LX .rotate ()
951+
952+ def step (self , dt , wall_time ):
953+ super ().step (dt , wall_time )
954+
955+ # Only last rank (i.e node) will be allowed to (eventually) write outputs
956+ if self .rank != self .M - 1 :
957+ self .firstEval = False
0 commit comments