@@ -955,3 +955,47 @@ def step(self, dt, wall_time):
955955 # Only last rank (i.e node) will be allowed to (eventually) write outputs
956956 if self .rank != self .M - 1 :
957957 self .firstEval = False
958+
959+
960+ class SDCIMEX_MPI2 (SDCIMEX_MPI ):
961+
962+ def _broadcastState (self ):
963+ state = self .solver .state
964+ sizes = [f .data .size for f in state ]
965+ buffer = np .empty (sum (sizes ), dtype = state [0 ].data .dtype )
966+ rank , M = self .rank , self .M
967+
968+ if rank == M - 1 : # copy last rank state into buffer
969+ pos = 0
970+ for f , size in zip (state , sizes ):
971+ np .copyto (buffer [pos :size ], f .data .flat )
972+ pos += size
973+
974+ self .comm .Bcast (buffer , root = self .M - 1 )
975+
976+ if rank != M - 1 : # copy buffer data into state
977+ pos = 0
978+ for f , size in zip (state , sizes ):
979+ np .copyto (f .data , buffer [pos :size ].reshape (f .data .shape ))
980+ pos += size
981+
982+ def _computeMX0 (self , MX0 :CoeffSystem ):
983+ """
984+ Compute MX0 term used in RHS of both initStep and sweep methods
985+
986+ Update the MX0 attribute of the timestepper object.
987+ """
988+ super (SDCIMEX_MPI , self )._computeMX0 (MX0 )
989+
990+ def _initSweep (self ):
991+ t0 , dt , wall_time = self .solver .sim_time , self .dt , self .wall_time
992+ Fk , LXk = self .F [0 ], self .LX [0 ]
993+ if self .initSweep == 'COPY' :
994+ self ._evalLX (LXk )
995+ self ._evalF (Fk , t0 , dt , wall_time )
996+ else :
997+ raise NotImplementedError ()
998+
999+ def step (self , dt , wall_time ):
1000+ super ().step (dt , wall_time )
1001+ self ._broadcastState ()
0 commit comments