Skip to content

Commit f9f1bd7

Browse files
committed
dsl: Add doctests etc to Border convenience object
1 parent 2fe7283 commit f9f1bd7

File tree

1 file changed

+116
-37
lines changed

1 file changed

+116
-37
lines changed

devito/types/grid.py

Lines changed: 116 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -909,32 +909,95 @@ def bounds(self):
909909

910910

911911
class Border(SubDomainSet):
912-
# FIXME: Completely untested
913-
"""
914-
A convenience class for constructing a SubDomainSet which
915-
covers specified edges of the domain to a thickness of `border`.
916912
"""
913+
A convenience class for constructing a SubDomainSet which covers specified edges
914+
of the domain to a thickness of `border`. Note that none of the subdomains in this
915+
MultiSubDomain will overlap with one another.
917916
918-
def __init__(self, grid, border, dims=None, name='border'):
919-
self._border = border
917+
By default, this border covers all sides of the grid. Alternatively, it is possible
918+
to add the border selectively to specific sides by supplying, for example,
919+
`dims={y: 'left'}` to obtain only a border on the left (from index zero) side of the
920+
y dimension, or `dims={x: x, y: 'left'}` to obtain borders on both sides of the x
921+
dimension, but only on the left of the y. One can also supply a single dimension on
922+
which to construct a border, using `dims=x` or similar.
920923
921-
if dims is None:
922-
_border_dims = {d: d for d in grid.dimensions}
923-
elif isinstance(dims, Dimension):
924-
_border_dims = {dims: dims}
925-
elif isinstance(dims, dict):
926-
_border_dims = dims
927-
else:
928-
raise ValueError("Dimensions should be supplied as a single dimension, or "
929-
"a dict of the form `{x: x, y: 'left'}`")
924+
Parameters
925+
----------
926+
grid : Grid
927+
The computational grid on which the border should be constructed
928+
border : int
929+
The thickness of the border in gridpoints
930+
dims : Dimension, dict, or None, optional
931+
The dimensions on which a border should be applied. Default is None, corresponding
932+
to borders on both sides of all dimensions.
933+
name : str, optional
934+
A unique name for the SubDomainSet created. Default is 'border'.
930935
931-
self._border_dims = frozendict(_border_dims)
936+
Examples
937+
--------
938+
Set up a border surrounding the whole grid:
932939
933-
self._name = name
940+
>>> from devito import Grid, Border, Function, Eq, Operator
941+
>>> grid = Grid(shape=(7, 7))
942+
>>> x, y = grid.dimensions
943+
944+
>>> border = Border(grid, 2) # Border of thickness 2
945+
>>> f = Function(name='f', grid=grid, dtype=np.int32)
946+
>>> eq = Eq(f, f+1, subdomain=border)
947+
>>> summary = Operator(eq)()
948+
>>> f.data
949+
Data([[1, 1, 1, 1, 1, 1, 1],
950+
[1, 1, 1, 1, 1, 1, 1],
951+
[1, 1, 0, 0, 0, 1, 1],
952+
[1, 1, 0, 0, 0, 1, 1],
953+
[1, 1, 0, 0, 0, 1, 1],
954+
[1, 1, 1, 1, 1, 1, 1],
955+
[1, 1, 1, 1, 1, 1, 1]], dtype=int32)
956+
957+
Set up a border consisting of the right side of the x dimension and both sides
958+
of the y dimension:
959+
960+
>>> border2 = Border(grid, 2, dims={x: 'right', y: y})
961+
>>> g = Function(name='g', grid=grid, dtype=np.int32)
962+
>>> eq2 = Eq(g, g+1, subdomain=border2)
963+
>>> summary = Operator(eq2)()
964+
>>> g.data
965+
Data([[1, 1, 0, 0, 0, 1, 1],
966+
[1, 1, 0, 0, 0, 1, 1],
967+
[1, 1, 0, 0, 0, 1, 1],
968+
[1, 1, 0, 0, 0, 1, 1],
969+
[1, 1, 0, 0, 0, 1, 1],
970+
[1, 1, 1, 1, 1, 1, 1],
971+
[1, 1, 1, 1, 1, 1, 1]], dtype=int32)
972+
973+
Set up a border consisting of only the sides in the y dimension:
974+
975+
>>> border3 = Border(grid, 2, dims=y)
976+
>>> h = Function(name='h', grid=grid, dtype=np.int32)
977+
>>> eq3 = Eq(h, h+1, subdomain=border3)
978+
>>> summary = Operator(eq3)()
979+
>>> h.data
980+
Data([[1, 1, 0, 0, 0, 1, 1],
981+
[1, 1, 0, 0, 0, 1, 1],
982+
[1, 1, 0, 0, 0, 1, 1],
983+
[1, 1, 0, 0, 0, 1, 1],
984+
[1, 1, 0, 0, 0, 1, 1],
985+
[1, 1, 0, 0, 0, 1, 1],
986+
[1, 1, 0, 0, 0, 1, 1]], dtype=int32)
934987
935-
# FIXME: Really horrible, refactor once working
936-
ndomains, bounds = self._build_domain_map(grid)
988+
"""
989+
990+
DimSpec = None | dict[Dimension, Dimension | str]
991+
ParsedDimSpec = frozendict[Dimension, Dimension | str]
992+
993+
def __init__(self, grid: Grid, border: int | np.integer,
994+
dims: DimSpec = None, name: str = 'border') -> None:
995+
996+
self._name = name
997+
self._border = border
998+
self._border_dims = Border._parse_dims(dims, grid)
937999

1000+
ndomains, bounds = self._build_domains(grid)
9381001
super().__init__(N=ndomains, bounds=bounds, grid=grid)
9391002

9401003
@property
@@ -949,23 +1012,37 @@ def border_dims(self):
9491012
def name(self):
9501013
return self._name
9511014

952-
def _build_domain_map(self, grid):
1015+
@staticmethod
1016+
def _parse_dims(dims: DimSpec, grid: Grid) -> ParsedDimSpec:
1017+
if dims is None:
1018+
_border_dims = {d: d for d in grid.dimensions}
1019+
elif isinstance(dims, Dimension):
1020+
_border_dims = {dims: dims}
1021+
elif isinstance(dims, dict):
1022+
_border_dims = dims
1023+
else:
1024+
raise ValueError("Dimensions should be supplied as a single dimension, or "
1025+
"a dict of the form `{x: x, y: 'left'}`")
1026+
1027+
return frozendict(_border_dims)
1028+
1029+
def _build_domains(self, grid: Grid) -> tuple[int, tuple[np.ndarray]]:
9531030
"""
1031+
Constructs the bounds and ndomains kwargs for the SubDomainSet.
9541032
"""
9551033

956-
# TODO: Really doesn't need to be a method
957-
# TODO: Fix this mess
958-
959-
domain_map = {}
960-
interval_map = {}
1034+
domain_map = {} # Stores the side
1035+
interval_map = {} # Stores the mapping from the side to bounds
9611036

1037+
# Unpack the user-provided specification into a set of sides (on which
1038+
# a cartesian product is taken) and a mapper from those sides to a set of
1039+
# bounds for each dimension.
9621040
for d, s in zip(grid.dimensions, grid.shape):
9631041
if d in self.border_dims:
9641042
side = self.border_dims[d]
9651043

9661044
if isinstance(side, Dimension):
9671045
domain_map[d] = (LEFT, CENTER, RIGHT)
968-
# Also build a mapper here
9691046
interval_map[d] = {LEFT: (0, s-self.border),
9701047
CENTER: (self.border, self.border),
9711048
RIGHT: (s-self.border, 0)}
@@ -978,29 +1055,31 @@ def _build_domain_map(self, grid):
9781055
interval_map[d] = {CENTER: (0, self.border),
9791056
RIGHT: (s-self.border, 0)}
9801057
else:
981-
raise ValueError(f"Unrecognised side {side}")
1058+
raise ValueError(f"Unrecognised side value {side}")
9821059
else:
9831060
domain_map[d] = (CENTER,)
9841061
interval_map[d] = {CENTER: (0, 0)}
9851062

986-
self._domain_map = frozendict(domain_map)
987-
988-
abstract_domains = list(product(*self._domain_map.values()))
1063+
# Get the cartesian product, then remove any which solely consist of
1064+
# the central region. The sides are used to make this step more
1065+
# straightforward.
1066+
abstract_domains = list(product(*domain_map.values()))
9891067
for d in abstract_domains:
9901068
if all(i == CENTER for i in d):
9911069
abstract_domains.remove(d)
9921070

993-
concrete_domains = []
1071+
domains = []
9941072
for dom in abstract_domains:
995-
concrete_domains.append([interval_map[d][i]
996-
for d, i in zip(grid.dimensions, dom)])
1073+
domains.append([interval_map[d][i]
1074+
for d, i in zip(grid.dimensions, dom)])
9971075

998-
concrete_domains = np.array(concrete_domains)
1076+
domains = np.array(domains)
9991077

1000-
reshaped = np.reshape(concrete_domains, (concrete_domains.shape[0], concrete_domains.shape[1]*concrete_domains.shape[2]))
1078+
# Reshape the bounds into the format expected by the SubDomainSet init
1079+
shape = (domains.shape[0], domains.shape[1]*domains.shape[2])
1080+
bounds = np.reshape(domains, shape).T
10011081

1002-
bounds = reshaped.T
1003-
return concrete_domains.shape[0], tuple(bounds)
1082+
return domains.shape[0], tuple(bounds)
10041083

10051084

10061085
# Preset SubDomains

0 commit comments

Comments
 (0)