 
 Warning
 -------
-To use MPI collective writing, you need to call first the class methods :class:`Rectilinear.initMPI` (cf their docstring).
+To use MPI collective writing, you first need to call the class method :class:`Rectilinear.setupMPI` (see its docstring).
 Also, `Rectilinear.setHeader` **must be given the global grids coordinates**, whether the code is run in parallel or not.
 """
 import os
 import numpy as np
 from typing import Type, TypeVar
 import logging
-import itertools
 
 T = TypeVar("T")
 
     except ImportError:
         pass
     from mpi4py import MPI
+    from mpi4py.util.dtlib import from_numpy_dtype as MPI_DTYPE
 except ImportError:
 
     class MPI:
         COMM_WORLD = None
         Intracomm = T
+        File = T
+        Datatype = T
+
+        def MPI_DTYPE():
+            pass
 
 
 # Supported data types
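To illustrate the Warning above, here is a minimal usage sketch of the MPI mode. It assumes a `Rectilinear(dtype, fileName)` constructor, `setHeader(nVar, coords)` keyword names, and an `initialize()` call that writes the header; none of these appear in this diff, so treat them as placeholders. The key points from the Warning are that `setupMPI` is called before any file access, that `setHeader` receives the *global* grid coordinates, and that `addField` receives only the local block.

```python
import numpy as np
from mpi4py import MPI
# Rectilinear is imported from the module modified in this diff (import path omitted here).

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# Global 2D grid, decomposed along the first axis only (illustrative layout)
nX, nY = 256, 64
x, y = np.linspace(0, 1, nX), np.linspace(0, 1, nY)

nXLoc = nX // size                           # local block size along x (assumes nX % size == 0)
iLoc, nLoc = [rank * nXLoc, 0], [nXLoc, nY]

# 1) Register the local block before any file is created or written
Rectilinear.setupMPI(comm, iLoc=iLoc, nLoc=nLoc)

# 2) The header always gets the GLOBAL grid coordinates (hypothetical constructor/keywords)
fIO = Rectilinear(np.float64, "fields.pysdc")
fIO.setHeader(nVar=2, coords=[x, y])
fIO.initialize()  # assumed to write the header to disk

# 3) Each rank then writes only its local block, collectively
u = np.zeros((2, *nLoc))
fIO.addField(time=0.0, field=u)
```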
@@ -493,7 +498,6 @@ def toVTR(self, baseName, varNames, idxFormat="{:06d}"):
     # MPI-parallel implementation
     # -------------------------------------------------------------------------
     comm: MPI.Intracomm = None
-    _nCollectiveIO = None
 
     @classmethod
     def setupMPI(cls, comm: MPI.Intracomm, iLoc, nLoc):
@@ -513,21 +517,9 @@ def setupMPI(cls, comm: MPI.Intracomm, iLoc, nLoc):
         cls.comm = comm
         cls.iLoc = iLoc
         cls.nLoc = nLoc
-        cls.mpiFile = None
-        cls._nCollectiveIO = None
-
-    @property
-    def nCollectiveIO(self):
-        """
-        Number of collective IO operations over all processes, when reading or writing a field.
-
-        Returns:
-        --------
-        int: Number of collective IO accesses
-        """
-        if self._nCollectiveIO is None:
-            self._nCollectiveIO = self.comm.allreduce(self.nVar * np.prod(self.nLoc[:-1]), op=MPI.MAX)
-        return self._nCollectiveIO
+        cls.mpiFile: MPI.File = None
+        cls.mpiType: MPI.Datatype = None
+        cls.mpiFileType: MPI.Datatype = None
 
     @property
     def MPI_ON(self):
@@ -543,13 +535,25 @@ def MPI_ROOT(self):
             return True
         return self.comm.Get_rank() == 0
 
+    def MPI_SETUP_FILETYPE(self):
+        """Setup the subarray mask for each process"""
+        self.mpiType = MPI_DTYPE(self.dtype)
+        self.mpiFileType = self.mpiType.Create_subarray(
+            [self.nVar, *self.gridSizes],  # Global array sizes
+            [self.nVar, *self.nLoc],  # Local array sizes
+            [0, *self.iLoc],  # Global starting indices of local blocks
+        )
+        self.mpiFileType.Commit()
+
     def MPI_FILE_OPEN(self, mode):
         """Open the binary file in MPI mode"""
         amode = {
             "r": MPI.MODE_RDONLY,
             "a": MPI.MODE_WRONLY | MPI.MODE_APPEND,
         }[mode]
         self.mpiFile = MPI.File.Open(self.comm, self.fileName, amode)
+        if self.mpiType is None:
+            self.MPI_SETUP_FILETYPE()
 
     def MPI_WRITE(self, data):
         """Write data (np.ndarray) in the binary file in MPI mode, at the current file cursor position."""
@@ -567,7 +571,8 @@ def MPI_WRITE_AT_ALL(self, offset, data: np.ndarray):
         data : np.ndarray
             Data to be written in the binary file.
         """
-        self.mpiFile.Write_at_all(offset, data)
+        self.mpiFile.Set_view(disp=offset, etype=self.mpiType, filetype=self.mpiFileType)
+        self.mpiFile.Write_all(data)
 
     def MPI_READ_AT_ALL(self, offset, data: np.ndarray):
         """
@@ -581,7 +586,8 @@ def MPI_READ_AT_ALL(self, offset, data: np.ndarray):
         data : np.ndarray
             Array on which to read the data from the binary file.
         """
-        self.mpiFile.Set_view(disp=offset, etype=self.mpiType, filetype=self.mpiFileType)
+        self.mpiFile.Read_all(data)
 
     def MPI_FILE_CLOSE(self):
         """Close the binary file in MPI mode"""
@@ -632,33 +638,15 @@ def addField(self, time, field):
             *self.nLoc,
         ), f"expected {(self.nVar, *self.nLoc)} shape, got {field.shape}"
 
-        offset0 = self.fileSize
+        offset = self.fileSize
         self.MPI_FILE_OPEN(mode="a")
-        nWrites = 0
-        nCollectiveIO = self.nCollectiveIO
 
         if self.MPI_ROOT:
             self.MPI_WRITE(np.array(time, dtype=T_DTYPE))
-        offset0 += self.tSize
-
-        for (iVar, *iBeg) in itertools.product(range(self.nVar), *[range(n) for n in self.nLoc[:-1]]):
-            offset = offset0 + self.iPos(iVar, iBeg) * self.itemSize
-            self.MPI_WRITE_AT_ALL(offset, field[(iVar, *iBeg)])
-            nWrites += 1
-
-        for _ in range(nCollectiveIO - nWrites):
-            # Additional collective write to catch up with other processes
-            self.MPI_WRITE_AT_ALL(offset0, field[:0])
-
+        offset += self.tSize
+        self.MPI_WRITE_AT_ALL(offset, field)
         self.MPI_FILE_CLOSE()
 
-    def iPos(self, iVar, iX):
-        iPos = iVar * self.nDoF
-        for axis in range(self.dim - 1):
-            iPos += (self.iLoc[axis] + iX[axis]) * np.prod(self.gridSizes[axis + 1 :])
-        iPos += self.iLoc[-1]
-        return iPos
-
     def readField(self, idx):
         """
         Read one field stored in the binary file, corresponding to the given
@@ -684,26 +672,15 @@ def readField(self, idx):
             return super().readField(idx)
 
         idx = self.formatIndex(idx)
-        offset0 = self.hSize + idx * (self.tSize + self.fSize)
+        offset = self.hSize + idx * (self.tSize + self.fSize)
         with open(self.fileName, "rb") as f:
-            t = float(np.fromfile(f, dtype=T_DTYPE, count=1, offset=offset0)[0])
-            offset0 += self.tSize
+            t = float(np.fromfile(f, dtype=T_DTYPE, count=1, offset=offset)[0])
+            offset += self.tSize
 
         field = np.empty((self.nVar, *self.nLoc), dtype=self.dtype)
 
         self.MPI_FILE_OPEN(mode="r")
-        nReads = 0
-        nCollectiveIO = self.nCollectiveIO
-
-        for (iVar, *iBeg) in itertools.product(range(self.nVar), *[range(n) for n in self.nLoc[:-1]]):
-            offset = offset0 + self.iPos(iVar, iBeg) * self.itemSize
-            self.MPI_READ_AT_ALL(offset, field[(iVar, *iBeg)])
-            nReads += 1
-
-        for _ in range(nCollectiveIO - nReads):
-            # Additional collective read to catch up with other processes
-            self.MPI_READ_AT_ALL(offset0, field[:0])
-
+        self.MPI_READ_AT_ALL(offset, field)
         self.MPI_FILE_CLOSE()
 
         return t, field
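Since `readField` relies on a fixed record layout, a header of `hSize` bytes followed by records of `tSize + fSize` bytes, a file written collectively can also be checked with a plain serial reader. Below is a hypothetical helper sketching this, assuming the time stamp is stored as `float64` (the module's `T_DTYPE`) and using only attribute names visible in this diff:

```python
import numpy as np

def read_record_serial(fIO, idx):
    """Read record `idx` (time stamp + full global field) without MPI, e.g. for post-processing."""
    offset = fIO.hSize + idx * (fIO.tSize + fIO.fSize)
    with open(fIO.fileName, "rb") as f:
        f.seek(offset)
        t = float(np.fromfile(f, dtype=np.float64, count=1)[0])  # time stamp, T_DTYPE assumed float64
        f.seek(offset + fIO.tSize)
        flat = np.fromfile(f, dtype=fIO.dtype, count=fIO.nVar * int(np.prod(fIO.gridSizes)))
    return t, flat.reshape((fIO.nVar, *fIO.gridSizes))
```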