Skip to content

Commit 2284bef

Browse files
committed
examples: Add tests for Border and document in examples
1 parent f9f1bd7 commit 2284bef

File tree

3 files changed

+664
-41
lines changed

3 files changed

+664
-41
lines changed

devito/types/grid.py

Lines changed: 132 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,7 @@ def __subdomain_finalize_core__(self, grid):
891891
))
892892

893893
self._dimensions = tuple(dimensions)
894+
self._subfunction = sd_func
894895

895896
def __subdomain_finalize__(self):
896897
self.__subdomain_finalize_core__(self.grid)
@@ -911,8 +912,7 @@ def bounds(self):
911912
class Border(SubDomainSet):
912913
"""
913914
A convenience class for constructing a SubDomainSet which covers specified edges
914-
of the domain to a thickness of `border`. Note that none of the subdomains in this
915-
MultiSubDomain will overlap with one another.
915+
of the domain to a thickness of `border`.
916916
917917
By default, this border covers all sides of the grid. Alternatively, it is possible
918918
to add the border selectively to specific sides by supplying, for example,
@@ -921,17 +921,28 @@ class Border(SubDomainSet):
921921
dimension, but only on the left of the y. One can also supply a single dimension on
922922
which to construct a border, using `dims=x` or similar.
923923
924+
Corners can be included, excluded, or overlapped by setting the `corners` kwarg.
925+
924926
Parameters
925927
----------
926928
grid : Grid
927929
The computational grid on which the border should be constructed
928-
border : int
929-
The thickness of the border in gridpoints
930+
border : int, tuple of int, or tuple of tuple of int
931+
The thickness of the border in gridpoints. A tuple with thickness for each
932+
dimension can be supplied if different thicknesses are required per-dimension.
933+
A tuple of tuples can also be supplied for more granular control of left and
934+
right border thicknesses for each dimension.
930935
dims : Dimension, dict, or None, optional
931936
The dimensions on which a border should be applied. Default is None, corresponding
932937
to borders on both sides of all dimensions.
938+
inset : int, tuple of int, or tuple of tuple of int, optional
939+
Inset the border from the edges by some number of gridpoints. Default is 0.
933940
name : str, optional
934941
A unique name for the SubDomainSet created. Default is 'border'.
942+
corners : str, optional
943+
Behaviour at the corners. Can be set to 'overlap' for overlapping subdomains at
944+
the corners, 'nooverlap' for non-overlapping corner subdomains, or 'nocorners'
945+
to omit the corners entirely. Default is `nooverlap`.
935946
936947
Examples
937948
--------
@@ -985,17 +996,42 @@ class Border(SubDomainSet):
985996
[1, 1, 0, 0, 0, 1, 1],
986997
[1, 1, 0, 0, 0, 1, 1]], dtype=int32)
987998
999+
which is equivalent to:
1000+
1001+
>>> border4 = Border(grid, 2, dims={y: y})
1002+
>>> i = Function(name='i', grid=grid, dtype=np.int32)
1003+
>>> eq4 = Eq(i, i+1, subdomain=border4)
1004+
>>> summary = Operator(eq4)()
1005+
>>> i.data
1006+
Data([[1, 1, 0, 0, 0, 1, 1],
1007+
[1, 1, 0, 0, 0, 1, 1],
1008+
[1, 1, 0, 0, 0, 1, 1],
1009+
[1, 1, 0, 0, 0, 1, 1],
1010+
[1, 1, 0, 0, 0, 1, 1],
1011+
[1, 1, 0, 0, 0, 1, 1],
1012+
[1, 1, 0, 0, 0, 1, 1]], dtype=int32)
1013+
9881014
"""
9891015

9901016
DimSpec = None | dict[Dimension, Dimension | str]
9911017
ParsedDimSpec = frozendict[Dimension, Dimension | str]
9921018

993-
def __init__(self, grid: Grid, border: int | np.integer,
994-
dims: DimSpec = None, name: str = 'border') -> None:
1019+
BorderInt = int | np.integer
1020+
BorderSpec = BorderInt | tuple[BorderInt] | tuple[tuple[BorderInt]]
1021+
ParsedBorderSpec = tuple[tuple[BorderInt]]
1022+
1023+
def __init__(self, grid: Grid, border: BorderSpec,
1024+
dims: DimSpec = None, name: str = 'border',
1025+
inset: BorderSpec = 0, corners: str = 'nooverlap') -> None:
9951026

9961027
self._name = name
997-
self._border = border
1028+
self._border = Border._parse_border(border, grid)
9981029
self._border_dims = Border._parse_dims(dims, grid)
1030+
self._inset = Border._parse_border(inset, grid, mode='inset')
1031+
1032+
if corners not in ('overlap', 'nooverlap', 'nocorners'):
1033+
raise ValueError(f"Unrecognised corners option: {corners}")
1034+
self._corners = corners
9991035

10001036
ndomains, bounds = self._build_domains(grid)
10011037
super().__init__(N=ndomains, bounds=bounds, grid=grid)
@@ -1012,6 +1048,18 @@ def border_dims(self):
10121048
def name(self):
10131049
return self._name
10141050

1051+
@property
1052+
def corners(self):
1053+
return self._corners
1054+
1055+
@property
1056+
def inset(self):
1057+
return self._inset
1058+
1059+
def _inset_flat(self, i):
1060+
"""Flattened access into self.inset"""
1061+
return self.inset[i // len(self.inset)][i % 2]
1062+
10151063
@staticmethod
10161064
def _parse_dims(dims: DimSpec, grid: Grid) -> ParsedDimSpec:
10171065
if dims is None:
@@ -1026,52 +1074,110 @@ def _parse_dims(dims: DimSpec, grid: Grid) -> ParsedDimSpec:
10261074

10271075
return frozendict(_border_dims)
10281076

1077+
@staticmethod
1078+
def _parse_border(border: BorderSpec, grid: Grid,
1079+
mode: str = 'border') -> ParsedBorderSpec:
1080+
if isinstance(border, (int, np.integer)):
1081+
return ((border, border),)*len(grid.dimensions)
1082+
1083+
else: # Tuple guaranteed by typing
1084+
if not len(border) == len(grid.dimensions):
1085+
raise ValueError(f"Length of {mode} specification should "
1086+
"match number of dimensions")
1087+
retval = []
1088+
for b, d in zip(border, grid.dimensions):
1089+
if isinstance(b, tuple):
1090+
if not len(b) == 2:
1091+
raise ValueError(f"{b}: more than two thicknesses supplied "
1092+
f"for dimension {d}")
1093+
retval.append(b)
1094+
else:
1095+
retval.append((b, b))
1096+
1097+
return tuple(retval)
1098+
10291099
def _build_domains(self, grid: Grid) -> tuple[int, tuple[np.ndarray]]:
10301100
"""
10311101
Constructs the bounds and ndomains kwargs for the SubDomainSet.
10321102
"""
1103+
if self.corners == 'overlap':
1104+
return self._build_domains_overlap(grid)
1105+
else:
1106+
return self._build_domains_nooverlap(grid)
1107+
1108+
def _build_domains_overlap(self, grid: Grid) -> tuple[int, tuple[np.ndarray]]:
1109+
1110+
bounds = []
1111+
iterations = zip(grid.dimensions, grid.shape, self.border, strict=True)
1112+
for i, (d, s, b) in enumerate(iterations):
1113+
1114+
if d in self.border_dims:
1115+
# Note: slightly counterintuitive since a left-side boundary only has
1116+
# right-side thickness
1117+
bounds_l = [s - b[0] - self.inset[i][0] if j == 2*i+1
1118+
else self._inset_flat(j)
1119+
for j in range(2*len(grid.dimensions))]
1120+
bounds_r = [s - b[1] - self.inset[i][1] if j == 2*i
1121+
else self._inset_flat(j)
1122+
for j in range(2*len(grid.dimensions))]
10331123

1124+
side = self.border_dims[d]
1125+
if isinstance(side, Dimension):
1126+
bounds.extend([bounds_l, bounds_r])
1127+
elif side == 'left':
1128+
bounds.append(bounds_l)
1129+
elif side == 'right':
1130+
bounds.append(bounds_r)
1131+
else:
1132+
raise ValueError(f"Unrecognised side value: {side}")
1133+
1134+
# Need to transpose array to fit into expected format for SubDomainSet
1135+
return len(bounds), tuple(np.array(bounds).T)
1136+
1137+
def _build_domains_nooverlap(self, grid: Grid) -> tuple[int, tuple[np.ndarray]]:
10341138
domain_map = {} # Stores the side
10351139
interval_map = {} # Stores the mapping from the side to bounds
10361140

10371141
# Unpack the user-provided specification into a set of sides (on which
10381142
# a cartesian product is taken) and a mapper from those sides to a set of
10391143
# bounds for each dimension.
1040-
for d, s in zip(grid.dimensions, grid.shape):
1144+
for d, s, b, i in zip(grid.dimensions, grid.shape, self.border, self.inset):
10411145
if d in self.border_dims:
10421146
side = self.border_dims[d]
10431147

10441148
if isinstance(side, Dimension):
10451149
domain_map[d] = (LEFT, CENTER, RIGHT)
1046-
interval_map[d] = {LEFT: (0, s-self.border),
1047-
CENTER: (self.border, self.border),
1048-
RIGHT: (s-self.border, 0)}
1150+
interval_map[d] = {LEFT: (i[0], s - b[0] - i[0]),
1151+
CENTER: (b[0] + i[0], b[1] + i[1]),
1152+
RIGHT: (s - b[1] - i[1], i[1])}
10491153
elif side == 'left':
10501154
domain_map[d] = (LEFT, CENTER)
1051-
interval_map[d] = {LEFT: (0, s-self.border),
1052-
CENTER: (self.border, 0)}
1155+
# For intuitive behaviour, 'nocorners' should always skip corners
1156+
centerval = b[1] + i[1] if self.corners == 'nocorners' else i[1]
1157+
interval_map[d] = {LEFT: (i[0], s - b[0] - i[0]),
1158+
CENTER: (b[0] + i[0], centerval)}
10531159
elif side == 'right':
10541160
domain_map[d] = (CENTER, RIGHT)
1055-
interval_map[d] = {CENTER: (0, self.border),
1056-
RIGHT: (s-self.border, 0)}
1161+
centerval = b[0] + i[0] if self.corners == 'nocorners' else i[0]
1162+
interval_map[d] = {CENTER: (centerval, b[1] + i[1]),
1163+
RIGHT: (s - b[1] - i[1], i[1])}
10571164
else:
10581165
raise ValueError(f"Unrecognised side value {side}")
10591166
else:
10601167
domain_map[d] = (CENTER,)
10611168
interval_map[d] = {CENTER: (0, 0)}
10621169

1063-
# Get the cartesian product, then remove any which solely consist of
1064-
# the central region. The sides are used to make this step more
1065-
# straightforward.
1066-
abstract_domains = list(product(*domain_map.values()))
1067-
for d in abstract_domains:
1068-
if all(i == CENTER for i in d):
1069-
abstract_domains.remove(d)
1070-
1170+
# Get the cartesian product, then select the required domains. The sides are used
1171+
# to make this step more straightforward.
1172+
maybe_domains = list(product(*domain_map.values()))
10711173
domains = []
1072-
for dom in abstract_domains:
1073-
domains.append([interval_map[d][i]
1074-
for d, i in zip(grid.dimensions, dom)])
1174+
for d in maybe_domains:
1175+
if not all(i is CENTER for i in d):
1176+
# Don't add any domains that are completely centered
1177+
if self.corners != 'nocorners' or any(i is CENTER for i in d):
1178+
# Don't add corners if 'no corners' option selected
1179+
domains.append([interval_map[dim][dom] for (dim, dom)
1180+
in zip(grid.dimensions, d)])
10751181

10761182
domains = np.array(domains)
10771183

0 commit comments

Comments
 (0)