@@ -909,32 +909,95 @@ def bounds(self):
909909
910910
911911class Border (SubDomainSet ):
912- # FIXME: Completely untested
913- """
914- A convenience class for constructing a SubDomainSet which
915- covers specified edges of the domain to a thickness of `border`.
916912 """
913+ A convenience class for constructing a SubDomainSet which covers specified edges
914+ of the domain to a thickness of `border`. Note that none of the subdomains in this
915+ MultiSubDomain will overlap with one another.
917916
918- def __init__ (self , grid , border , dims = None , name = 'border' ):
919- self ._border = border
917+ By default, this border covers all sides of the grid. Alternatively, it is possible
918+ to add the border selectively to specific sides by supplying, for example,
919+ `dims={y: 'left'}` to obtain only a border on the left (from index zero) side of the
920+ y dimension, or `dims={x: x, y: 'left'}` to obtain borders on both sides of the x
921+ dimension, but only on the left of the y. One can also supply a single dimension on
922+ which to construct a border, using `dims=x` or similar.
920923
921- if dims is None :
922- _border_dims = {d : d for d in grid .dimensions }
923- elif isinstance (dims , Dimension ):
924- _border_dims = {dims : dims }
925- elif isinstance (dims , dict ):
926- _border_dims = dims
927- else :
928- raise ValueError ("Dimensions should be supplied as a single dimension, or "
929- "a dict of the form `{x: x, y: 'left'}`" )
924+ Parameters
925+ ----------
926+ grid : Grid
927+ The computational grid on which the border should be constructed
928+ border : int
929+ The thickness of the border in gridpoints
930+ dims : Dimension, dict, or None, optional
931+ The dimensions on which a border should be applied. Default is None, corresponding
932+ to borders on both sides of all dimensions.
933+ name : str, optional
934+ A unique name for the SubDomainSet created. Default is 'border'.
930935
931- self ._border_dims = frozendict (_border_dims )
936+ Examples
937+ --------
938+ Set up a border surrounding the whole grid:
932939
933- self ._name = name
940+ >>> from devito import Grid, Border, Function, Eq, Operator
941+ >>> grid = Grid(shape=(7, 7))
942+ >>> x, y = grid.dimensions
943+
944+ >>> border = Border(grid, 2) # Border of thickness 2
945+ >>> f = Function(name='f', grid=grid, dtype=np.int32)
946+ >>> eq = Eq(f, f+1, subdomain=border)
947+ >>> summary = Operator(eq)()
948+ >>> f.data
949+ Data([[1, 1, 1, 1, 1, 1, 1],
950+ [1, 1, 1, 1, 1, 1, 1],
951+ [1, 1, 0, 0, 0, 1, 1],
952+ [1, 1, 0, 0, 0, 1, 1],
953+ [1, 1, 0, 0, 0, 1, 1],
954+ [1, 1, 1, 1, 1, 1, 1],
955+ [1, 1, 1, 1, 1, 1, 1]], dtype=int32)
956+
957+ Set up a border consisting of the right side of the x dimension and both sides
958+ of the y dimension:
959+
960+ >>> border2 = Border(grid, 2, dims={x: 'right', y: y})
961+ >>> g = Function(name='g', grid=grid, dtype=np.int32)
962+ >>> eq2 = Eq(g, g+1, subdomain=border2)
963+ >>> summary = Operator(eq2)()
964+ >>> g.data
965+ Data([[1, 1, 0, 0, 0, 1, 1],
966+ [1, 1, 0, 0, 0, 1, 1],
967+ [1, 1, 0, 0, 0, 1, 1],
968+ [1, 1, 0, 0, 0, 1, 1],
969+ [1, 1, 0, 0, 0, 1, 1],
970+ [1, 1, 1, 1, 1, 1, 1],
971+ [1, 1, 1, 1, 1, 1, 1]], dtype=int32)
972+
973+ Set up a border consisting of only the sides in the y dimension:
974+
975+ >>> border3 = Border(grid, 2, dims=y)
976+ >>> h = Function(name='h', grid=grid, dtype=np.int32)
977+ >>> eq3 = Eq(h, h+1, subdomain=border3)
978+ >>> summary = Operator(eq3)()
979+ >>> h.data
980+ Data([[1, 1, 0, 0, 0, 1, 1],
981+ [1, 1, 0, 0, 0, 1, 1],
982+ [1, 1, 0, 0, 0, 1, 1],
983+ [1, 1, 0, 0, 0, 1, 1],
984+ [1, 1, 0, 0, 0, 1, 1],
985+ [1, 1, 0, 0, 0, 1, 1],
986+ [1, 1, 0, 0, 0, 1, 1]], dtype=int32)
934987
935- # FIXME: Really horrible, refactor once working
936- ndomains , bounds = self ._build_domain_map (grid )
988+ """
989+
990+ DimSpec = None | dict [Dimension , Dimension | str ]
991+ ParsedDimSpec = frozendict [Dimension , Dimension | str ]
992+
993+ def __init__ (self , grid : Grid , border : int | np .integer ,
994+ dims : DimSpec = None , name : str = 'border' ) -> None :
995+
996+ self ._name = name
997+ self ._border = border
998+ self ._border_dims = Border ._parse_dims (dims , grid )
937999
1000+ ndomains , bounds = self ._build_domains (grid )
9381001 super ().__init__ (N = ndomains , bounds = bounds , grid = grid )
9391002
9401003 @property
@@ -949,23 +1012,37 @@ def border_dims(self):
9491012 def name (self ):
9501013 return self ._name
9511014
952- def _build_domain_map (self , grid ):
1015+ @staticmethod
1016+ def _parse_dims (dims : DimSpec , grid : Grid ) -> ParsedDimSpec :
1017+ if dims is None :
1018+ _border_dims = {d : d for d in grid .dimensions }
1019+ elif isinstance (dims , Dimension ):
1020+ _border_dims = {dims : dims }
1021+ elif isinstance (dims , dict ):
1022+ _border_dims = dims
1023+ else :
1024+ raise ValueError ("Dimensions should be supplied as a single dimension, or "
1025+ "a dict of the form `{x: x, y: 'left'}`" )
1026+
1027+ return frozendict (_border_dims )
1028+
1029+ def _build_domains (self , grid : Grid ) -> tuple [int , tuple [np .ndarray ]]:
9531030 """
1031+ Constructs the bounds and ndomains kwargs for the SubDomainSet.
9541032 """
9551033
956- # TODO: Really doesn't need to be a method
957- # TODO: Fix this mess
958-
959- domain_map = {}
960- interval_map = {}
1034+ domain_map = {} # Stores the side
1035+ interval_map = {} # Stores the mapping from the side to bounds
9611036
1037+ # Unpack the user-provided specification into a set of sides (on which
1038+ # a cartesian product is taken) and a mapper from those sides to a set of
1039+ # bounds for each dimension.
9621040 for d , s in zip (grid .dimensions , grid .shape ):
9631041 if d in self .border_dims :
9641042 side = self .border_dims [d ]
9651043
9661044 if isinstance (side , Dimension ):
9671045 domain_map [d ] = (LEFT , CENTER , RIGHT )
968- # Also build a mapper here
9691046 interval_map [d ] = {LEFT : (0 , s - self .border ),
9701047 CENTER : (self .border , self .border ),
9711048 RIGHT : (s - self .border , 0 )}
@@ -978,29 +1055,31 @@ def _build_domain_map(self, grid):
9781055 interval_map [d ] = {CENTER : (0 , self .border ),
9791056 RIGHT : (s - self .border , 0 )}
9801057 else :
981- raise ValueError (f"Unrecognised side { side } " )
1058+ raise ValueError (f"Unrecognised side value { side } " )
9821059 else :
9831060 domain_map [d ] = (CENTER ,)
9841061 interval_map [d ] = {CENTER : (0 , 0 )}
9851062
986- self ._domain_map = frozendict (domain_map )
987-
988- abstract_domains = list (product (* self ._domain_map .values ()))
1063+ # Get the cartesian product, then remove any which solely consist of
1064+ # the central region. The sides are used to make this step more
1065+ # straightforward.
1066+ abstract_domains = list (product (* domain_map .values ()))
9891067 for d in abstract_domains :
9901068 if all (i == CENTER for i in d ):
9911069 abstract_domains .remove (d )
9921070
993- concrete_domains = []
1071+ domains = []
9941072 for dom in abstract_domains :
995- concrete_domains .append ([interval_map [d ][i ]
996- for d , i in zip (grid .dimensions , dom )])
1073+ domains .append ([interval_map [d ][i ]
1074+ for d , i in zip (grid .dimensions , dom )])
9971075
998- concrete_domains = np .array (concrete_domains )
1076+ domains = np .array (domains )
9991077
1000- reshaped = np .reshape (concrete_domains , (concrete_domains .shape [0 ], concrete_domains .shape [1 ]* concrete_domains .shape [2 ]))
1078+ # Reshape the bounds into the format expected by the SubDomainSet init
1079+ shape = (domains .shape [0 ], domains .shape [1 ]* domains .shape [2 ])
1080+ bounds = np .reshape (domains , shape ).T
10011081
1002- bounds = reshaped .T
1003- return concrete_domains .shape [0 ], tuple (bounds )
1082+ return domains .shape [0 ], tuple (bounds )
10041083
10051084
10061085# Preset SubDomains
0 commit comments