11from abc import ABC
22from collections import namedtuple
33from functools import cached_property
4+ from itertools import product
45
56import numpy as np
67from sympy import prod
78
89from devito import configuration
9- from devito .data import LEFT , RIGHT
10+ from devito .data import LEFT , RIGHT , CENTER
1011from devito .logger import warning
1112from devito .mpi import Distributor , MPI , SubDistributor
12- from devito .tools import ReducerMap , as_tuple
13+ from devito .tools import ReducerMap , as_tuple , frozendict
1314from devito .types .args import ArgProvider
1415from devito .types .basic import Scalar
1516from devito .types .dense import Function
1920 MultiSubDimension , DefaultDimension )
2021from devito .deprecations import deprecations
2122
22- __all__ = ['Grid' , 'SubDomain' , 'SubDomainSet' ]
23+ __all__ = ['Grid' , 'SubDomain' , 'SubDomainSet' , 'Border' ]
2324
2425
2526GlobalLocal = namedtuple ('GlobalLocal' , 'glb loc' )
@@ -871,7 +872,7 @@ def __subdomain_finalize_core__(self, grid):
871872 # Dimensions with identical names hash the same, hence tag them with the
872873 # SubDomainSet ID to make them unique so they can be used to key a dictionary
873874 # of replacements without risking overwriting.
874- i_dim = Dimension ('n_%s' % str (id (self )))
875+ i_dim = Dimension (f 'n_{ str (id (self ))} ' )
875876 d_dim = DefaultDimension (name = 'd' , default_value = 2 * grid .dim )
876877 sd_func = Function (name = self .name , grid = self ._grid ,
877878 shape = (self ._n_domains , 2 * grid .dim ),
@@ -885,7 +886,7 @@ def __subdomain_finalize_core__(self, grid):
885886 sd_func .data [:, idx ] = self ._local_bounds [idx ]
886887
887888 dimensions .append (MultiSubDimension (
888- 'i%s' % d .name , d , None , functions = sd_func ,
889+ f'i { d .name } ' , d , None , functions = sd_func ,
889890 bounds_indices = (2 * i , 2 * i + 1 ), implicit_dimension = i_dim
890891 ))
891892
@@ -907,6 +908,101 @@ def bounds(self):
907908 return self ._local_bounds
908909
909910
911+ class Border (SubDomainSet ):
912+ # FIXME: Completely untested
913+ """
914+ A convenience class for constructing a SubDomainSet which
915+ covers specified edges of the domain to a thickness of `border`.
916+ """
917+
918+ def __init__ (self , grid , border , dims = None , name = 'border' ):
919+ self ._border = border
920+
921+ if dims is None :
922+ _border_dims = {d : d for d in grid .dimensions }
923+ elif isinstance (dims , Dimension ):
924+ _border_dims = {dims : dims }
925+ elif isinstance (dims , dict ):
926+ _border_dims = dims
927+ else :
928+ raise ValueError ("Dimensions should be supplied as a single dimension, or "
929+ "a dict of the form `{x: x, y: 'left'}`" )
930+
931+ self ._border_dims = frozendict (_border_dims )
932+
933+ self ._name = name
934+
935+ # FIXME: Really horrible, refactor once working
936+ ndomains , bounds = self ._build_domain_map (grid )
937+
938+ super ().__init__ (N = ndomains , bounds = bounds , grid = grid )
939+
940+ @property
941+ def border (self ):
942+ return self ._border
943+
944+ @property
945+ def border_dims (self ):
946+ return self ._border_dims
947+
948+ @property
949+ def name (self ):
950+ return self ._name
951+
952+ def _build_domain_map (self , grid ):
953+ """
954+ """
955+
956+ # TODO: Really doesn't need to be a method
957+ # TODO: Fix this mess
958+
959+ domain_map = {}
960+ interval_map = {}
961+
962+ for d , s in zip (grid .dimensions , grid .shape ):
963+ if d in self .border_dims :
964+ side = self .border_dims [d ]
965+
966+ if isinstance (side , Dimension ):
967+ domain_map [d ] = (LEFT , CENTER , RIGHT )
968+ # Also build a mapper here
969+ interval_map [d ] = {LEFT : (0 , s - self .border ),
970+ CENTER : (self .border , self .border ),
971+ RIGHT : (s - self .border , 0 )}
972+ elif side == 'left' :
973+ domain_map [d ] = (LEFT , CENTER )
974+ interval_map [d ] = {LEFT : (0 , s - self .border ),
975+ CENTER : (self .border , 0 )}
976+ elif side == 'right' :
977+ domain_map [d ] = (CENTER , RIGHT )
978+ interval_map [d ] = {CENTER : (0 , self .border ),
979+ RIGHT : (s - self .border , 0 )}
980+ else :
981+ raise ValueError (f"Unrecognised side { side } " )
982+ else :
983+ domain_map [d ] = (CENTER ,)
984+ interval_map [d ] = {CENTER : (0 , 0 )}
985+
986+ self ._domain_map = frozendict (domain_map )
987+
988+ abstract_domains = list (product (* self ._domain_map .values ()))
989+ for d in abstract_domains :
990+ if all (i == CENTER for i in d ):
991+ abstract_domains .remove (d )
992+
993+ concrete_domains = []
994+ for dom in abstract_domains :
995+ concrete_domains .append ([interval_map [d ][i ]
996+ for d , i in zip (grid .dimensions , dom )])
997+
998+ concrete_domains = np .array (concrete_domains )
999+
1000+ reshaped = np .reshape (concrete_domains , (concrete_domains .shape [0 ], concrete_domains .shape [1 ]* concrete_domains .shape [2 ]))
1001+
1002+ bounds = reshaped .T
1003+ return concrete_domains .shape [0 ], tuple (bounds )
1004+
1005+
9101006# Preset SubDomains
9111007
9121008
0 commit comments