Skip to content

Commit 2fe7283

Browse files
committed
api: First attempt at a Border convenience class
1 parent 4ee88fb commit 2fe7283

File tree

1 file changed

+101
-5
lines changed

1 file changed

+101
-5
lines changed

devito/types/grid.py

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from abc import ABC
22
from collections import namedtuple
33
from functools import cached_property
4+
from itertools import product
45

56
import numpy as np
67
from sympy import prod
78

89
from devito import configuration
9-
from devito.data import LEFT, RIGHT
10+
from devito.data import LEFT, RIGHT, CENTER
1011
from devito.logger import warning
1112
from devito.mpi import Distributor, MPI, SubDistributor
12-
from devito.tools import ReducerMap, as_tuple
13+
from devito.tools import ReducerMap, as_tuple, frozendict
1314
from devito.types.args import ArgProvider
1415
from devito.types.basic import Scalar
1516
from devito.types.dense import Function
@@ -19,7 +20,7 @@
1920
MultiSubDimension, DefaultDimension)
2021
from devito.deprecations import deprecations
2122

22-
__all__ = ['Grid', 'SubDomain', 'SubDomainSet']
23+
__all__ = ['Grid', 'SubDomain', 'SubDomainSet', 'Border']
2324

2425

2526
GlobalLocal = namedtuple('GlobalLocal', 'glb loc')
@@ -871,7 +872,7 @@ def __subdomain_finalize_core__(self, grid):
871872
# Dimensions with identical names hash the same, hence tag them with the
872873
# SubDomainSet ID to make them unique so they can be used to key a dictionary
873874
# of replacements without risking overwriting.
874-
i_dim = Dimension('n_%s' % str(id(self)))
875+
i_dim = Dimension(f'n_{str(id(self))}')
875876
d_dim = DefaultDimension(name='d', default_value=2*grid.dim)
876877
sd_func = Function(name=self.name, grid=self._grid,
877878
shape=(self._n_domains, 2*grid.dim),
@@ -885,7 +886,7 @@ def __subdomain_finalize_core__(self, grid):
885886
sd_func.data[:, idx] = self._local_bounds[idx]
886887

887888
dimensions.append(MultiSubDimension(
888-
'i%s' % d.name, d, None, functions=sd_func,
889+
f'i{d.name}', d, None, functions=sd_func,
889890
bounds_indices=(2*i, 2*i+1), implicit_dimension=i_dim
890891
))
891892

@@ -907,6 +908,101 @@ def bounds(self):
907908
return self._local_bounds
908909

909910

911+
class Border(SubDomainSet):
912+
# FIXME: Completely untested
913+
"""
914+
A convenience class for constructing a SubDomainSet which
915+
covers specified edges of the domain to a thickness of `border`.
916+
"""
917+
918+
def __init__(self, grid, border, dims=None, name='border'):
919+
self._border = border
920+
921+
if dims is None:
922+
_border_dims = {d: d for d in grid.dimensions}
923+
elif isinstance(dims, Dimension):
924+
_border_dims = {dims: dims}
925+
elif isinstance(dims, dict):
926+
_border_dims = dims
927+
else:
928+
raise ValueError("Dimensions should be supplied as a single dimension, or "
929+
"a dict of the form `{x: x, y: 'left'}`")
930+
931+
self._border_dims = frozendict(_border_dims)
932+
933+
self._name = name
934+
935+
# FIXME: Really horrible, refactor once working
936+
ndomains, bounds = self._build_domain_map(grid)
937+
938+
super().__init__(N=ndomains, bounds=bounds, grid=grid)
939+
940+
@property
941+
def border(self):
942+
return self._border
943+
944+
@property
945+
def border_dims(self):
946+
return self._border_dims
947+
948+
@property
949+
def name(self):
950+
return self._name
951+
952+
def _build_domain_map(self, grid):
953+
"""
954+
"""
955+
956+
# TODO: Really doesn't need to be a method
957+
# TODO: Fix this mess
958+
959+
domain_map = {}
960+
interval_map = {}
961+
962+
for d, s in zip(grid.dimensions, grid.shape):
963+
if d in self.border_dims:
964+
side = self.border_dims[d]
965+
966+
if isinstance(side, Dimension):
967+
domain_map[d] = (LEFT, CENTER, RIGHT)
968+
# Also build a mapper here
969+
interval_map[d] = {LEFT: (0, s-self.border),
970+
CENTER: (self.border, self.border),
971+
RIGHT: (s-self.border, 0)}
972+
elif side == 'left':
973+
domain_map[d] = (LEFT, CENTER)
974+
interval_map[d] = {LEFT: (0, s-self.border),
975+
CENTER: (self.border, 0)}
976+
elif side == 'right':
977+
domain_map[d] = (CENTER, RIGHT)
978+
interval_map[d] = {CENTER: (0, self.border),
979+
RIGHT: (s-self.border, 0)}
980+
else:
981+
raise ValueError(f"Unrecognised side {side}")
982+
else:
983+
domain_map[d] = (CENTER,)
984+
interval_map[d] = {CENTER: (0, 0)}
985+
986+
self._domain_map = frozendict(domain_map)
987+
988+
abstract_domains = list(product(*self._domain_map.values()))
989+
for d in abstract_domains:
990+
if all(i == CENTER for i in d):
991+
abstract_domains.remove(d)
992+
993+
concrete_domains = []
994+
for dom in abstract_domains:
995+
concrete_domains.append([interval_map[d][i]
996+
for d, i in zip(grid.dimensions, dom)])
997+
998+
concrete_domains = np.array(concrete_domains)
999+
1000+
reshaped = np.reshape(concrete_domains, (concrete_domains.shape[0], concrete_domains.shape[1]*concrete_domains.shape[2]))
1001+
1002+
bounds = reshaped.T
1003+
return concrete_domains.shape[0], tuple(bounds)
1004+
1005+
9101006
# Preset SubDomains
9111007

9121008

0 commit comments

Comments
 (0)