@@ -387,7 +387,8 @@ def entry_point_compute_mpas_transect_masks():
387387 engine = args .engine )
388388
389389
390- def compute_mpas_flood_fill_mask (dsMesh , fcSeed , logger = None , workers = - 1 ):
390+ def compute_mpas_flood_fill_mask (dsMesh , fcSeed , daGrow = None , logger = None ,
391+ workers = - 1 ):
391392 """
392393 Flood fill from the given set of seed points to create a contiguous mask.
393394 The flood fill operates using cellsOnCell, starting from the cells
@@ -401,6 +402,11 @@ def compute_mpas_flood_fill_mask(dsMesh, fcSeed, logger=None, workers=-1):
401402 fcSeed : geometric_features.FeatureCollection
402403 A feature collection containing points at which to start the flood fill
403404
405+ daGrow : xarray.DataArray, optional
406+ A data array of size ``nCells`` with a mask that is 1 anywhere the
407+ flood fill is allowed to grow. The default is that the mask is all
408+ ones.
409+
404410 logger : logging.Logger, optional
405411 A logger for the output if not stdout
406412
@@ -426,17 +432,22 @@ def compute_mpas_flood_fill_mask(dsMesh, fcSeed, logger=None, workers=-1):
426432 if logger is not None :
427433 logger .info (' Computing flood fill mask on cells:' )
428434
429- mask = _compute_seed_mask (fcSeed , lon , lat , workers )
435+ seedMask = _compute_seed_mask (fcSeed , lon , lat , workers )
430436
431437 cellsOnCell = dsMesh .cellsOnCell .values - 1
432438
433- mask = _flood_fill_mask (mask , cellsOnCell )
439+ if daGrow is not None :
440+ growMask = daGrow .values
441+ else :
442+ growMask = numpy .ones (dsMesh .sizes ['nCells' ])
443+
444+ seedMask = _flood_fill_mask (seedMask , growMask , cellsOnCell )
434445
435446 if logger is not None :
436447 logger .info (' Adding masks to dataset...' )
437448 # create a new data array for the mask
438449 masksVarName = 'cellSeedMask'
439- dsMasks [masksVarName ] = (('nCells' ,), numpy .array (mask , dtype = int ))
450+ dsMasks [masksVarName ] = (('nCells' ,), numpy .array (seedMask , dtype = int ))
440451
441452 if logger is not None :
442453 logger .info (' Done.' )
@@ -1183,30 +1194,31 @@ def _compute_seed_mask(fcSeed, lon, lat, workers):
11831194 return mask
11841195
11851196
1186- def _flood_fill_mask (mask , cellsOnCell ):
1197+ def _flood_fill_mask (seedMask , growMask , cellsOnCell ):
11871198 """
11881199 Flood fill starting with a mask of seed points
11891200 """
11901201
11911202 maxNeighbors = cellsOnCell .shape [1 ]
11921203
11931204 while True :
1194- neighbors = cellsOnCell [mask == 1 , :]
1205+ neighbors = cellsOnCell [seedMask == 1 , :]
11951206 maskCount = 0
11961207 for iNeighbor in range (maxNeighbors ):
11971208 indices = neighbors [:, iNeighbor ]
1198- # we only want to mask valid neighbors and locations that aren't
1199- # already masked
1209+ # we only want to mask valid neighbors, locations that aren't
1210+ # already masked, and locations that we're allowed to flood
12001211 indices = indices [indices >= 0 ]
1201- localMask = mask [indices ] == 0
1212+ localMask = numpy .logical_and (seedMask [indices ] == 0 ,
1213+ growMask [indices ] == 1 )
12021214 maskCount += numpy .count_nonzero (localMask )
12031215 indices = indices [localMask ]
1204- mask [indices ] = 1
1216+ seedMask [indices ] = 1
12051217
12061218 if maskCount == 0 :
12071219 break
12081220
1209- return mask
1221+ return seedMask
12101222
12111223
12121224def _compute_edge_sign (dsMesh , edgeMask , shape ):
0 commit comments