44Base generic script for fields IO
55"""
66import os
7+ import sys
78import numpy as np
89from typing import Type , TypeVar
9- from mpi4py import MPI
10- from time import time
10+ try :
11+ from mpi4py import MPI
12+ except ImportError :
13+ pass
14+
15+
16+ from time import time , sleep
17+ from blocks import BlockDecomposition
1118
1219T = TypeVar ("T" )
1320
@@ -230,6 +237,7 @@ def setupMPI(cls, comm:MPI.Intracomm, iLocX, nLocX):
230237 cls .comm = comm
231238 cls .iLocX = iLocX
232239 cls .nLocX = nLocX
240+ cls .mpiFile = None
233241
234242 @property
235243 def MPI_ON (self ):
@@ -241,18 +249,33 @@ def MPI_ROOT(self):
241249 if self .comm is None : return True
242250 return self .comm .Get_rank () == 0
243251
244- def MPI_FILE_OPEN (self , mode )-> MPI . File :
252+ def MPI_FILE_OPEN (self , mode ):
245253 amode = {
246254 "r" : MPI .MODE_RDONLY ,
247255 "a" : MPI .MODE_WRONLY | MPI .MODE_APPEND ,
248256 }[mode ]
249- return MPI .File .Open (self .comm , self .fileName , amode )
257+ self .mpiFile = MPI .File .Open (self .comm , self .fileName , amode )
258+
259+ def MPI_WRITE (self , data ):
260+ self .mpiFile .Write (data )
261+
262+ def MPI_WRITE_AT (self , offset , data :np .ndarray ):
263+ self .mpiFile .Write_at (offset , data )
264+
265+ def MPI_READ_AT (self , offset , data ):
266+ self .mpiFile .Read_at (offset , data )
267+
268+ def MPI_FILE_CLOSE (self ):
269+ self .mpiFile .Close ()
270+ self .mpiFile = None
250271
251272 def initialize (self ):
252273 if self .MPI_ROOT :
253274 super ().initialize ()
254- self .comm .Barrier ()
255- self .initialized = True
275+ if self .MPI_ON :
276+ self .comm .Barrier ()
277+ self .initialized = True
278+
256279
257280 def addField (self , time , field ):
258281 if not self .MPI_ON : return super ().addField (time , field )
@@ -265,15 +288,15 @@ def addField(self, time, field):
265288 f"expected { (self .nVar , self .nLocX )} shape, got { field .shape } "
266289
267290 offset0 = self .fileSize
268- mpiFile = self .MPI_FILE_OPEN (mode = "a" )
291+ self .MPI_FILE_OPEN (mode = "a" )
269292 if self .MPI_ROOT :
270- mpiFile . Write (np .array (time , dtype = T_DTYPE ))
293+ self . MPI_WRITE (np .array (time , dtype = T_DTYPE ))
271294 offset0 += self .tSize
272295
273296 for iVar in range (self .nVar ):
274297 offset = offset0 + (iVar * self .nX + self .iLocX )* self .itemSize
275- mpiFile . Write_at_all (offset , field [iVar ])
276- mpiFile . Close ()
298+ self . MPI_WRITE_AT (offset , field [iVar ])
299+ self . MPI_FILE_CLOSE ()
277300
278301
279302 def readField (self , idx ):
@@ -287,11 +310,11 @@ def readField(self, idx):
287310
288311 field = np .empty ((self .nVar , self .nLocX ), dtype = self .dtype )
289312
290- mpiFile = self .MPI_FILE_OPEN (mode = "r" )
313+ self .MPI_FILE_OPEN (mode = "r" )
291314 for iVar in range (self .nVar ):
292315 offset = offset0 + (iVar * self .nX + self .iLocX )* self .itemSize
293- mpiFile . Read_at_all (offset , field [iVar ])
294- mpiFile . Close ()
316+ self . MPI_READ_AT (offset , field [iVar ])
317+ self . MPI_FILE_CLOSE ()
295318
296319 return t , field
297320
@@ -331,9 +354,7 @@ def readHeader(self, f):
331354 # -------------------------------------------------------------------------
332355 @classmethod
333356 def setupMPI (cls , comm :MPI .Intracomm , iLocX , nLocX , iLocY , nLocY ):
334- cls .comm = comm
335- cls .iLocX = iLocX
336- cls .nLocX = nLocX
357+ super ().setupMPI (comm , iLocX , nLocX )
337358 cls .iLocY = iLocY
338359 cls .nLocY = nLocY
339360
@@ -349,18 +370,18 @@ def addField(self, time, field):
349370 f"expected { (self .nVar , self .nLocX , self .nLocY )} shape, got { field .shape } "
350371
351372 offset0 = self .fileSize
352- mpiFile = self .MPI_FILE_OPEN (mode = "a" )
373+ self .MPI_FILE_OPEN (mode = "a" )
353374 if self .MPI_ROOT :
354- mpiFile . Write (np .array (time , dtype = T_DTYPE ))
375+ self . MPI_WRITE (np .array (time , dtype = T_DTYPE ))
355376 offset0 += self .tSize
356377
357378 for iVar in range (self .nVar ):
358379 for iX in range (self .nLocX ):
359380 offset = offset0 + (
360381 iVar * self .nX * self .nY + (self .iLocX + iX )* self .nY + self .iLocY
361382 )* self .itemSize
362- mpiFile . Write_at_all (offset , field [iVar , iX ])
363- mpiFile . Close ()
383+ self . MPI_WRITE_AT (offset , field [iVar , iX ])
384+ self . MPI_FILE_CLOSE ()
364385
365386
366387 def readField (self , idx ):
@@ -374,14 +395,14 @@ def readField(self, idx):
374395
375396 field = np .empty ((self .nVar , self .nLocX , self .nLocY ), dtype = self .dtype )
376397
377- mpiFile = self .MPI_FILE_OPEN (mode = "r" )
398+ self .MPI_FILE_OPEN (mode = "r" )
378399 for iVar in range (self .nVar ):
379400 for iX in range (self .nLocX ):
380401 offset = offset0 + (
381402 iVar * self .nX * self .nY + (self .iLocX + iX )* self .nY + self .iLocY
382403 )* self .itemSize
383- mpiFile . Read_at_all (offset , field [iVar , iX ])
384- mpiFile . Close ()
404+ self . MPI_READ_AT (offset , field [iVar , iX ])
405+ self . MPI_FILE_CLOSE ()
385406
386407 return t , field
387408
@@ -404,7 +425,7 @@ def readField(self, idx):
404425 y = np .linspace (0 , 1 , num = 64 , endpoint = False )
405426 nY = y .size
406427
407- dim = 2
428+ dim = 1
408429 dType = np .float64
409430
410431 if dim == 1 :
@@ -416,42 +437,40 @@ def readField(self, idx):
416437 comm = MPI .COMM_WORLD
417438 MPI_SIZE = comm .Get_size ()
418439 MPI_RANK = comm .Get_rank ()
440+
441+ gridSizes = u0 .shape [1 :]
442+ algo = sys .argv [1 ] if len (sys .argv ) > 1 else "ChatGPT"
443+ blocks = BlockDecomposition (MPI_SIZE , gridSizes , algo , MPI_RANK )
444+ bounds = blocks .localBounds
419445 if MPI_SIZE > 1 :
420446 fileName = "test_MPI.pysdc"
421- if dim == 1 :
422- pSizeX = MPI_SIZE
423- pRankX = MPI_RANK
424- if dim == 2 :
425- assert MPI_SIZE == 4
426- pSizeX = MPI_SIZE // 2
427- pRankX = MPI_RANK // 2
428- pSizeY = MPI_SIZE // 2
429- pRankY = MPI_RANK % 2
430- else :
431- pSizeX , pRankX = 1 , 0
432- pSizeY , pRankY = 1 , 0
433-
434- def decomposeDirection (nItems , pSize , pRank ):
435- n0 = nItems // pSize
436- nRest = nItems - pSize * n0
437- nLoc = n0 + 1 * (pRank < nRest )
438- iLoc = pRank * n0 + nRest * (pRank >= nRest ) + pRank * (pRank < nRest )
439- return iLoc , nLoc
440-
441-
442- iLocX , nLocX = decomposeDirection (nX , pSizeX , pRankX )
447+
448+
443449 if dim == 1 :
450+ (iLocX , ), (nLocX , ) = bounds
451+ pRankX , = blocks .ranks
444452 Cart1D .setupMPI (comm , iLocX , nLocX )
445453 u0 = u0 [:, iLocX :iLocX + nLocX ]
446454
455+ MPI .COMM_WORLD .Barrier ()
456+ sleep (0.01 * MPI_RANK )
457+ print (f"[Rank { MPI_RANK } ] pRankX={ pRankX } ({ iLocX } , { nLocX } )" )
458+ MPI .COMM_WORLD .Barrier ()
459+
447460 f1 = Cart1D (dType , fileName )
448461 f1 .setHeader (nVar = u0 .shape [0 ], gridX = x )
449462
450463 if dim == 2 :
451- iLocY , nLocY = decomposeDirection (nY , pSizeY , pRankY )
464+ (iLocX , iLocY ), (nLocX , nLocY ) = bounds
465+ pRankX , pRankY = blocks .ranks
452466 Cart2D .setupMPI (comm , iLocX , nLocX , iLocY , nLocY )
453467 u0 = u0 [:, iLocX :iLocX + nLocX , iLocY :iLocY + nLocY ]
454468
469+ MPI .COMM_WORLD .Barrier ()
470+ sleep (0.01 * MPI_RANK )
471+ print (f"[Rank { MPI_RANK } ] pRankX={ pRankX } ({ iLocX } , { nLocX } ), pRankY={ pRankY } ({ iLocY } , { nLocY } )" )
472+ MPI .COMM_WORLD .Barrier ()
473+
455474 f1 = Cart2D (dType , fileName )
456475 f1 .setHeader (nVar = u0 .shape [0 ], gridX = x , gridY = y )
457476
@@ -465,7 +484,7 @@ def decomposeDirection(nItems, pSize, pRank):
465484 for t in np .arange (nTimes )/ nTimes :
466485 f1 .addField (t , t * u0 )
467486 if MPI_RANK == 0 :
468- print (f" -> done in { time ()- tBeg :1.2f } s !" )
487+ print (f" -> done in { time ()- tBeg :1.4f } s !" )
469488
470489 f2 = FieldsIO .fromFile (fileName )
471490 t , u = f2 .readField (2 )
0 commit comments