
Warning
-------
- To use MPI collective writing, you need to call first the class methods :class:`Rectilinear.initMPI` (cf their docstring).
+ To use MPI collective writing, you first need to call the class method :class:`Rectilinear.setupMPI` (cf. its docstring).
Also, `Rectilinear.setHeader` **must be given the global grids coordinates**, whether the code is run in parallel or not.
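+ A minimal sketch of the intended call order, assuming placeholder names ``comm``, ``iLoc``, ``nLoc``,
+ ``nVar`` and ``coords``, with ``fIO`` an already-created :class:`Rectilinear` instance::
+
+     Rectilinear.setupMPI(comm, iLoc, nLoc)  # on every rank, before any file I/O
+     fIO.setHeader(nVar, coords)             # coords must be the *global* grid coordinates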
4949"""
5050import os
5151import numpy as np
5252from typing import Type , TypeVar
5353import logging
54- import itertools
5554
5655T = TypeVar ("T" )
5756
    except ImportError:
        pass
    from mpi4py import MPI
+     from mpi4py.util.dtlib import from_numpy_dtype as MPI_DTYPE
except ImportError:

    class MPI:
        COMM_WORLD = None
        Intracomm = T
+         File = T
+         Datatype = T
+
+     def MPI_DTYPE():
+         pass


# Supported data types
@@ -412,6 +417,8 @@ def setHeader(self, nVar, coords):
        coords = self.setupCoords(*coords)
        self.header = {"nVar": int(nVar), "coords": coords}
        self.nItems = nVar * self.nDoF
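+         # build the MPI subarray datatype now that the global grid sizes are known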
+         if self.MPI_ON:
+             self.MPI_SETUP()

    @property
    def hInfos(self):
@@ -433,6 +440,8 @@ def readHeader(self, f):
        gridSizes = np.fromfile(f, dtype=np.int32, count=dim)
        coords = [np.fromfile(f, dtype=np.float64, count=n) for n in gridSizes]
        self.setHeader(nVar, coords)
+         if self.MPI_ON:
+             self.MPI_SETUP()

    def reshape(self, fields: np.ndarray):
        """Reshape the fields to a N-d array (inplace operation)"""
@@ -513,7 +522,9 @@ def setupMPI(cls, comm: MPI.Intracomm, iLoc, nLoc):
        cls.comm = comm
        cls.iLoc = iLoc
        cls.nLoc = nLoc
-         cls.mpiFile = None
+         cls.mpiFile: MPI.File = None
+         cls.mpiType: MPI.Datatype = None
+         cls.mpiFileType: MPI.Datatype = None
        cls._nCollectiveIO = None

    @property
@@ -543,6 +554,18 @@ def MPI_ROOT(self):
            return True
        return self.comm.Get_rank() == 0

+     def MPI_SETUP(self):
+         """Set up the MPI subarray datatype describing each process's local block"""
+         self.mpiType = MPI_DTYPE(self.dtype)
+         self.mpiFileType = self.mpiType.Create_subarray(
+             [self.nVar, *self.gridSizes],  # Global array sizes
+             [self.nVar, *self.nLoc],  # Local array sizes
+             [0, *self.iLoc]  # Global starting indices of local blocks
+         )
+         self.mpiFileType.Commit()
+         print("MPI_TYPE ", self.mpiType)
+         print("MPI_FILETYPE ", self.mpiFileType)
+
    def MPI_FILE_OPEN(self, mode):
        """Open the binary file in MPI mode"""
        amode = {
@@ -567,7 +590,8 @@ def MPI_WRITE_AT_ALL(self, offset, data: np.ndarray):
        data : np.ndarray
            Data to be written in the binary file.
        """
-         self.mpiFile.Write_at_all(offset, data)
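+         # set the file view so this rank only touches its local block, then write collectively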
+         self.mpiFile.Set_view(disp=offset, etype=self.mpiType, filetype=self.mpiFileType)
+         self.mpiFile.Write_all(data)

    def MPI_READ_AT_ALL(self, offset, data: np.ndarray):
        """
@@ -581,7 +605,8 @@ def MPI_READ_AT_ALL(self, offset, data: np.ndarray):
        data : np.ndarray
            Array on which to read the data from the binary file.
        """
-         self.mpiFile.Read_at_all(offset, data)
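+         # set the file view to this rank's local block, then read collectively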
+         self.mpiFile.Set_view(disp=offset, etype=self.mpiType, filetype=self.mpiFileType)
+         self.mpiFile.Read_all(data)

    def MPI_FILE_CLOSE(self):
        """Close the binary file in MPI mode"""
@@ -632,33 +657,15 @@ def addField(self, time, field):
            *self.nLoc,
        ), f"expected {(self.nVar, *self.nLoc)} shape, got {field.shape}"

-         offset0 = self.fileSize
+         offset = self.fileSize
        self.MPI_FILE_OPEN(mode="a")
-         nWrites = 0
-         nCollectiveIO = self.nCollectiveIO

        if self.MPI_ROOT:
            self.MPI_WRITE(np.array(time, dtype=T_DTYPE))
-         offset0 += self.tSize
-
-         for (iVar, *iBeg) in itertools.product(range(self.nVar), *[range(n) for n in self.nLoc[:-1]]):
-             offset = offset0 + self.iPos(iVar, iBeg) * self.itemSize
-             self.MPI_WRITE_AT_ALL(offset, field[(iVar, *iBeg)])
-             nWrites += 1
-
-         for _ in range(nCollectiveIO - nWrites):
-             # Additional collective write to catch up with other processes
-             self.MPI_WRITE_AT_ALL(offset0, field[:0])
-
+         offset += self.tSize
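+         # single collective write of the whole local block; the subarray file view handles placement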
+         self.MPI_WRITE_AT_ALL(offset, field)
        self.MPI_FILE_CLOSE()

-     def iPos(self, iVar, iX):
-         iPos = iVar * self.nDoF
-         for axis in range(self.dim - 1):
-             iPos += (self.iLoc[axis] + iX[axis]) * np.prod(self.gridSizes[axis + 1:])
-         iPos += self.iLoc[-1]
-         return iPos
-
    def readField(self, idx):
        """
        Read one field stored in the binary file, corresponding to the given
@@ -684,26 +691,15 @@ def readField(self, idx):
            return super().readField(idx)

        idx = self.formatIndex(idx)
-         offset0 = self.hSize + idx * (self.tSize + self.fSize)
+         offset = self.hSize + idx * (self.tSize + self.fSize)
        with open(self.fileName, "rb") as f:
-             t = float(np.fromfile(f, dtype=T_DTYPE, count=1, offset=offset0)[0])
-         offset0 += self.tSize
+             t = float(np.fromfile(f, dtype=T_DTYPE, count=1, offset=offset)[0])
+         offset += self.tSize

        field = np.empty((self.nVar, *self.nLoc), dtype=self.dtype)

        self.MPI_FILE_OPEN(mode="r")
-         nReads = 0
-         nCollectiveIO = self.nCollectiveIO
-
-         for (iVar, *iBeg) in itertools.product(range(self.nVar), *[range(n) for n in self.nLoc[:-1]]):
-             offset = offset0 + self.iPos(iVar, iBeg) * self.itemSize
-             self.MPI_READ_AT_ALL(offset, field[(iVar, *iBeg)])
-             nReads += 1
-
-         for _ in range(nCollectiveIO - nReads):
-             # Additional collective read to catch up with other processes
-             self.MPI_READ_AT_ALL(offset0, field[:0])
-
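+         # collective read of the local block through the subarray file view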
+         self.MPI_READ_AT_ALL(offset, field)
        self.MPI_FILE_CLOSE()

        return t, field