4444
4545Warning
4646-------
47- To use MPI collective writing, you need to call first the class methods :class:`Rectilinear.initMPI ` (cf their docstring).
47+ To use MPI collective writing, you need to call first the class methods :class:`Rectilinear.setupMPI ` (cf their docstring).
4848Also, `Rectilinear.setHeader` **must be given the global grids coordinates**, whether the code is run in parallel or not.
4949"""
5050import os
5151import numpy as np
5252from typing import Type , TypeVar
5353import logging
54- import itertools
5554
5655T = TypeVar ("T" )
5756
6160 except ImportError :
6261 pass
6362 from mpi4py import MPI
63+ from mpi4py .util .dtlib import from_numpy_dtype as MPI_DTYPE
6464except ImportError :
6565
6666 class MPI :
6767 COMM_WORLD = None
6868 Intracomm = T
69+ File = T
70+ Datatype = T
71+
72+ def MPI_DTYPE ():
73+ pass
6974
7075
7176# Supported data types
@@ -412,6 +417,8 @@ def setHeader(self, nVar, coords):
412417 coords = self .setupCoords (* coords )
413418 self .header = {"nVar" : int (nVar ), "coords" : coords }
414419 self .nItems = nVar * self .nDoF
420+ if self .MPI_ON :
421+ self .MPI_SETUP ()
415422
416423 @property
417424 def hInfos (self ):
@@ -433,6 +440,8 @@ def readHeader(self, f):
433440 gridSizes = np .fromfile (f , dtype = np .int32 , count = dim )
434441 coords = [np .fromfile (f , dtype = np .float64 , count = n ) for n in gridSizes ]
435442 self .setHeader (nVar , coords )
443+ if self .MPI_ON :
444+ self .MPI_SETUP ()
436445
437446 def reshape (self , fields : np .ndarray ):
438447 """Reshape the fields to a N-d array (inplace operation)"""
@@ -493,7 +502,6 @@ def toVTR(self, baseName, varNames, idxFormat="{:06d}"):
493502 # MPI-parallel implementation
494503 # -------------------------------------------------------------------------
495504 comm : MPI .Intracomm = None
496- _nCollectiveIO = None
497505
498506 @classmethod
499507 def setupMPI (cls , comm : MPI .Intracomm , iLoc , nLoc ):
@@ -513,21 +521,9 @@ def setupMPI(cls, comm: MPI.Intracomm, iLoc, nLoc):
513521 cls .comm = comm
514522 cls .iLoc = iLoc
515523 cls .nLoc = nLoc
516- cls .mpiFile = None
517- cls ._nCollectiveIO = None
518-
519- @property
520- def nCollectiveIO (self ):
521- """
522- Number of collective IO operations over all processes, when reading or writing a field.
523-
524- Returns:
525- --------
526- int: Number of collective IO accesses
527- """
528- if self ._nCollectiveIO is None :
529- self ._nCollectiveIO = self .comm .allreduce (self .nVar * np .prod (self .nLoc [:- 1 ]), op = MPI .MAX )
530- return self ._nCollectiveIO
524+ cls .mpiFile : MPI .File = None
525+ cls .mpiType : MPI .Datatype = None
526+ cls .mpiFileType : MPI .Datatype = None
531527
532528 @property
533529 def MPI_ON (self ):
@@ -543,6 +539,16 @@ def MPI_ROOT(self):
543539 return True
544540 return self .comm .Get_rank () == 0
545541
542+ def MPI_SETUP (self ):
543+ """Setup subarray masks for each processes"""
544+ self .mpiType = MPI_DTYPE (self .dtype )
545+ self .mpiFileType = self .mpiType .Create_subarray (
546+ [self .nVar , * self .gridSizes ], # Global array sizes
547+ [self .nVar , * self .nLoc ], # Local array sizes
548+ [0 , * self .iLoc ], # Global starting indices of local blocks
549+ )
550+ self .mpiFileType .Commit ()
551+
546552 def MPI_FILE_OPEN (self , mode ):
547553 """Open the binary file in MPI mode"""
548554 amode = {
@@ -567,7 +573,8 @@ def MPI_WRITE_AT_ALL(self, offset, data: np.ndarray):
567573 data : np.ndarray
568574 Data to be written in the binary file.
569575 """
570- self .mpiFile .Write_at_all (offset , data )
576+ self .mpiFile .Set_view (disp = offset , etype = self .mpiType , filetype = self .mpiFileType )
577+ self .mpiFile .Write_all (data )
571578
572579 def MPI_READ_AT_ALL (self , offset , data : np .ndarray ):
573580 """
@@ -581,7 +588,8 @@ def MPI_READ_AT_ALL(self, offset, data: np.ndarray):
581588 data : np.ndarray
582589 Array on which to read the data from the binary file.
583590 """
584- self .mpiFile .Read_at_all (offset , data )
591+ self .mpiFile .Set_view (disp = offset , etype = self .mpiType , filetype = self .mpiFileType )
592+ self .mpiFile .Read_all (data )
585593
586594 def MPI_FILE_CLOSE (self ):
587595 """Close the binary file in MPI mode"""
@@ -632,33 +640,15 @@ def addField(self, time, field):
632640 * self .nLoc ,
633641 ), f"expected { (self .nVar , * self .nLoc )} shape, got { field .shape } "
634642
635- offset0 = self .fileSize
643+ offset = self .fileSize
636644 self .MPI_FILE_OPEN (mode = "a" )
637- nWrites = 0
638- nCollectiveIO = self .nCollectiveIO
639645
640646 if self .MPI_ROOT :
641647 self .MPI_WRITE (np .array (time , dtype = T_DTYPE ))
642- offset0 += self .tSize
643-
644- for (iVar , * iBeg ) in itertools .product (range (self .nVar ), * [range (n ) for n in self .nLoc [:- 1 ]]):
645- offset = offset0 + self .iPos (iVar , iBeg ) * self .itemSize
646- self .MPI_WRITE_AT_ALL (offset , field [(iVar , * iBeg )])
647- nWrites += 1
648-
649- for _ in range (nCollectiveIO - nWrites ):
650- # Additional collective write to catch up with other processes
651- self .MPI_WRITE_AT_ALL (offset0 , field [:0 ])
652-
648+ offset += self .tSize
649+ self .MPI_WRITE_AT_ALL (offset , field )
653650 self .MPI_FILE_CLOSE ()
654651
655- def iPos (self , iVar , iX ):
656- iPos = iVar * self .nDoF
657- for axis in range (self .dim - 1 ):
658- iPos += (self .iLoc [axis ] + iX [axis ]) * np .prod (self .gridSizes [axis + 1 :])
659- iPos += self .iLoc [- 1 ]
660- return iPos
661-
662652 def readField (self , idx ):
663653 """
664654 Read one field stored in the binary file, corresponding to the given
@@ -684,26 +674,15 @@ def readField(self, idx):
684674 return super ().readField (idx )
685675
686676 idx = self .formatIndex (idx )
687- offset0 = self .hSize + idx * (self .tSize + self .fSize )
677+ offset = self .hSize + idx * (self .tSize + self .fSize )
688678 with open (self .fileName , "rb" ) as f :
689- t = float (np .fromfile (f , dtype = T_DTYPE , count = 1 , offset = offset0 )[0 ])
690- offset0 += self .tSize
679+ t = float (np .fromfile (f , dtype = T_DTYPE , count = 1 , offset = offset )[0 ])
680+ offset += self .tSize
691681
692682 field = np .empty ((self .nVar , * self .nLoc ), dtype = self .dtype )
693683
694684 self .MPI_FILE_OPEN (mode = "r" )
695- nReads = 0
696- nCollectiveIO = self .nCollectiveIO
697-
698- for (iVar , * iBeg ) in itertools .product (range (self .nVar ), * [range (n ) for n in self .nLoc [:- 1 ]]):
699- offset = offset0 + self .iPos (iVar , iBeg ) * self .itemSize
700- self .MPI_READ_AT_ALL (offset , field [(iVar , * iBeg )])
701- nReads += 1
702-
703- for _ in range (nCollectiveIO - nReads ):
704- # Additional collective read to catch up with other processes
705- self .MPI_READ_AT_ALL (offset0 , field [:0 ])
706-
685+ self .MPI_READ_AT_ALL (offset , field )
707686 self .MPI_FILE_CLOSE ()
708687
709688 return t , field
0 commit comments