Skip to content

Commit dcc1559

Browse files
Merge pull request #2538 from devitocodes/thread-block-clusters
compiler: Misc extensions towards supporting thread block clustering
2 parents 3707554 + 86c5b69 commit dcc1559

File tree

14 files changed

+164
-54
lines changed

14 files changed

+164
-54
lines changed

devito/arch/archinfo.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,12 @@ def limits(self, compiler=None, language=None):
684684
'max-block-dims': sys.maxsize,
685685
}
686686

687+
def supports(self, query, language=None):
688+
"""
689+
Return True if the platform supports a given feature, False otherwise.
690+
"""
691+
return False
692+
687693

688694
class Cpu64(Platform):
689695

@@ -836,7 +842,8 @@ class Device(Platform):
836842

837843
def __init__(self, name, cores_logical=None, cores_physical=None, isa='cpp',
838844
max_threads_per_block=1024, max_threads_dimx=1024,
839-
max_threads_dimy=1024, max_threads_dimz=64):
845+
max_threads_dimy=1024, max_threads_dimz=64,
846+
max_thread_block_cluster_size=8):
840847
super().__init__(name)
841848

842849
cpu_info = get_cpu_info()
@@ -849,6 +856,7 @@ def __init__(self, name, cores_logical=None, cores_physical=None, isa='cpp',
849856
self.max_threads_dimx = max_threads_dimx
850857
self.max_threads_dimy = max_threads_dimy
851858
self.max_threads_dimz = max_threads_dimz
859+
self.max_thread_block_cluster_size = max_thread_block_cluster_size
852860

853861
@classmethod
854862
def _mro(cls):
@@ -897,12 +905,6 @@ def limits(self, compiler=None, language=None):
897905
'max-block-dims': 3,
898906
}
899907

900-
def supports(self, query, language=None):
901-
"""
902-
Check if the device supports a given feature.
903-
"""
904-
return False
905-
906908

907909
class IntelDevice(Device):
908910

@@ -939,7 +941,7 @@ def supports(self, query, language=None):
939941
if query == 'async-loads' and cc >= 80:
940942
# Asynchronous pipeline loads -- introduced in Ampere
941943
return True
942-
elif query == 'tma' and cc >= 90:
944+
elif query in ('tma', 'thread-block-cluster') and cc >= 90:
943945
# Tensor Memory Accelerator -- introduced in Hopper
944946
return True
945947
else:
@@ -953,25 +955,23 @@ class Volta(NvidiaDevice):
953955
class Ampere(Volta):
954956

955957
def supports(self, query, language=None):
956-
if language != 'cuda':
957-
return False
958-
959958
if query == 'async-loads':
960959
return True
961-
962-
return super().supports(query, language)
960+
else:
961+
return super().supports(query, language)
963962

964963

965964
class Hopper(Ampere):
966965

967-
def supports(self, query, language=None):
968-
if language != 'cuda':
969-
return False
966+
def __init__(self, *args, **kwargs):
967+
kwargs.setdefault('max_thread_block_cluster_size', 16)
968+
super().__init__(*args, **kwargs)
970969

971-
if query == 'tma':
970+
def supports(self, query, language=None):
971+
if query in ('tma', 'thread-block-cluster'):
972972
return True
973-
974-
return super().supports(query, language)
973+
else:
974+
return super().supports(query, language)
975975

976976

977977
class Blackwell(Hopper):

devito/ir/iet/efunc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ class DeviceFunction(Callable):
159159
"""
160160

161161
def __init__(self, name, body, retval='void', parameters=None,
162-
prefix='__global__', templates=None):
162+
prefix='__global__', templates=None, attributes=None):
163163
super().__init__(name, body, retval, parameters=parameters, prefix=prefix,
164-
templates=templates)
164+
templates=templates, attributes=attributes)
165165

166166

167167
class DeviceCall(Call):

devito/ir/iet/nodes.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -718,17 +718,21 @@ class Callable(Node):
718718
parameters : list of Basic, optional
719719
The objects in input to the Callable.
720720
prefix : list of str, optional
721-
Qualifiers to prepend to the Callable signature. None by defaults.
721+
Qualifiers to prepend to the Callable signature. None by default.
722722
templates : list of Basic, optional
723723
The template parameters of the Callable.
724+
attributes : list of str, optional
725+
Additional attributes to append to the Callable signature. An
726+
attributes is one or more keywords that appear in between the
727+
return type and the function name. None by default.
724728
"""
725729

726730
is_Callable = True
727731

728732
_traversable = ['body']
729733

730734
def __init__(self, name, body, retval, parameters=None, prefix=None,
731-
templates=None):
735+
templates=None, attributes=None):
732736
self.name = name
733737
if not isinstance(body, CallableBody):
734738
self.body = CallableBody(body)
@@ -738,6 +742,7 @@ def __init__(self, name, body, retval, parameters=None, prefix=None,
738742
self.prefix = as_tuple(prefix)
739743
self.parameters = as_tuple(parameters)
740744
self.templates = as_tuple(templates)
745+
self.attributes = as_tuple(attributes)
741746

742747
def __repr__(self):
743748
param_types = [ctypes_to_cstr(i._C_ctype) for i in self.parameters]

devito/ir/iet/visitors.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,14 +315,25 @@ def _args_call(self, args):
315315

316316
def _gen_signature(self, o, is_declaration=False):
317317
decls = self._args_decl(o.parameters)
318+
318319
prefix = ' '.join(o.prefix + (self._gen_rettype(o.retval),))
319-
signature = c.FunctionDeclaration(c.Value(prefix, o.name), decls)
320+
321+
if o.attributes:
322+
# NOTE: ugly, but I can't bother extending `c.FunctionDeclaration`
323+
# for such a tiny thing
324+
v = f"{' '.join(o.attributes)} {o.name}"
325+
else:
326+
v = o.name
327+
328+
signature = c.FunctionDeclaration(c.Value(prefix, v), decls)
329+
320330
if o.templates:
321331
tparams = ', '.join([i.inline() for i in self._args_decl(o.templates)])
322332
if is_declaration:
323333
signature = TemplateDecl(tparams, signature)
324334
else:
325335
signature = c.Template(tparams, signature)
336+
326337
return signature
327338

328339
def _blankline_logic(self, children):

devito/ir/support/properties.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,21 @@ def __init__(self, name, val=None):
9595
A Dimension along which shared-memory prefetching is feasible and beneficial.
9696
"""
9797

98+
INIT_CORE_SHM = Property('init-core-shm')
99+
"""
100+
A Dimension along which the shared-memory CORE data region is initialized.
101+
"""
102+
103+
INIT_HALO_LEFT_SHM = Property('init-halo-left-shm')
104+
"""
105+
A Dimension along which the shared-memory left-HALO data region is initialized.
106+
"""
107+
108+
INIT_HALO_RIGHT_SHM = Property('init-halo-right-shm')
109+
"""
110+
A Dimension along which the shared-memory right-HALO data region is initialized.
111+
"""
112+
98113

99114
# Bundles
100115
PARALLELS = {PARALLEL, PARALLEL_INDEP, PARALLEL_IF_ATOMIC, PARALLEL_IF_PVT}
@@ -122,9 +137,13 @@ def normalize_properties(*args):
122137
else:
123138
drop = set()
124139

125-
# SEPARABLE <=> all are SEPARABLE
126-
if not all(SEPARABLE in p for p in args):
127-
drop.add(SEPARABLE)
140+
# A property X must be dropped if not all set of properties in `args`
141+
# contain X. For example, if one set of properties contains SEPARABLE and
142+
# another does not, then the resulting set of properties should not contain
143+
# SEPARABLE.
144+
for i in (SEPARABLE, INIT_CORE_SHM, INIT_HALO_LEFT_SHM, INIT_HALO_RIGHT_SHM):
145+
if not all(i in p for p in args):
146+
drop.add(i)
128147

129148
properties = set()
130149
for p in args:
@@ -190,6 +209,11 @@ def update_properties(properties, exprs):
190209
else:
191210
properties = properties.drop(properties=PREFETCHABLE_SHM)
192211

212+
# Remove properties that are trivially incompatible with `exprs`
213+
if not all(e.lhs.function._mem_shared for e in as_tuple(exprs)):
214+
drop = {INIT_CORE_SHM, INIT_HALO_LEFT_SHM, INIT_HALO_RIGHT_SHM}
215+
properties = properties.drop(properties=drop)
216+
193217
return properties
194218

195219

@@ -269,10 +293,16 @@ def block(self, dims, kind='default'):
269293
return Properties(m)
270294

271295
def inbound(self, dims):
272-
m = dict(self)
273-
for d in as_tuple(dims):
274-
m[d] = set(m.get(d, [])) | {INBOUND}
275-
return Properties(m)
296+
return self.add(dims, INBOUND)
297+
298+
def init_core_shm(self, dims):
299+
return self.add(dims, INIT_CORE_SHM)
300+
301+
def init_halo_left_shm(self, dims):
302+
return self.add(dims, INIT_HALO_LEFT_SHM)
303+
304+
def init_halo_right_shm(self, dims):
305+
return self.add(dims, INIT_HALO_RIGHT_SHM)
276306

277307
def is_parallel(self, dims):
278308
return any(len(self[d] & {PARALLEL, PARALLEL_INDEP}) > 0
@@ -299,13 +329,28 @@ def is_blockable(self, d):
299329
def is_blockable_small(self, d):
300330
return TILABLE_SMALL in self.get(d, set())
301331

302-
def is_prefetchable(self, dims=None, v=PREFETCHABLE):
332+
def _is_property_any(self, dims, v):
303333
if dims is None:
304334
dims = list(self)
305335
return any(v in self.get(d, set()) for d in as_tuple(dims))
306336

337+
def is_prefetchable(self, dims=None, v=PREFETCHABLE):
338+
return self._is_property_any(dims, PREFETCHABLE)
339+
307340
def is_prefetchable_shm(self, dims=None):
308-
return self.is_prefetchable(dims, PREFETCHABLE_SHM)
341+
return self._is_property_any(dims, PREFETCHABLE_SHM)
342+
343+
def is_core_init(self, dims=None):
344+
return self._is_property_any(dims, INIT_CORE_SHM)
345+
346+
def is_halo_left_init(self, dims=None):
347+
return self._is_property_any(dims, INIT_HALO_LEFT_SHM)
348+
349+
def is_halo_right_init(self, dims=None):
350+
return self._is_property_any(dims, INIT_HALO_RIGHT_SHM)
351+
352+
def is_halo_init(self, dims=None):
353+
return self.is_halo_left_init(dims) or self.is_halo_right_init(dims)
309354

310355
@property
311356
def nblockable(self):

devito/ir/support/space.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,10 @@ def outermost(self):
10561056
def innermost(self):
10571057
return self[-1]
10581058

1059+
@cached_property
1060+
def concrete(self):
1061+
return self.project(lambda d: not d.is_Virtual)
1062+
10591063
@cached_property
10601064
def itintervals(self):
10611065
return tuple(self[d] for d in self.itdims)

devito/operator/operator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,15 @@ def _postprocess_errors(self, retval):
703703
"due to excessive register pressure in one of the Operator "
704704
"kernels. Try supplying a smaller `par-tile` value."
705705
)
706+
elif retval == error_mapper['KernelLaunchClusterConfig']:
707+
raise ExecutionError(
708+
"Kernel launch failed due to an invalid thread block cluster "
709+
"configuration. This is probably due to a `tbc-tile` value that "
710+
"does not perfectly divide the number of blocks launched for a "
711+
"kernel. This is a known, strong limitation which effectively "
712+
"prevents the use of `tbc-tile` in realistic scenarios, but it "
713+
"will be removed in future versions."
714+
)
706715
elif retval == error_mapper['KernelLaunchUnknown']:
707716
raise ExecutionError(
708717
"Kernel launch failed due to an unknown error. This might "

devito/passes/clusters/cse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
7070
min_cost = options['cse-min-cost']
7171
mode = options['cse-algo']
7272

73+
if cluster.is_fence:
74+
return cluster
75+
7376
make = lambda: CTemp(name=sregistry.make_name(), dtype=cluster.dtype)
7477

7578
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)

devito/passes/clusters/misc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,12 @@ def _key(self, c):
228228
# If there are writes to thread-shared object, make it part of the key.
229229
# This will promote fusion of non-adjacent Clusters writing to (some
230230
# form of) shared memory, which in turn will minimize the number of
231-
# necessary barriers Same story for reads from thread-shared objects
231+
# necessary barriers. Same story for reads from thread-shared objects
232232
weak.extend([
233233
any(f._mem_shared for f in c.scope.writes),
234234
any(f._mem_shared for f in c.scope.reads)
235235
])
236+
weak.append(c.properties.is_core_init())
236237

237238
# Prefetchable Clusters should get merged, if possible
238239
weak.append(c.properties.is_prefetchable_shm())

devito/passes/iet/errors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,5 +109,6 @@ class Retval(LocalObject, Expr):
109109
'Stability': 100,
110110
'KernelLaunch': 200,
111111
'KernelLaunchOutOfResources': 201,
112-
'KernelLaunchUnknown': 202,
112+
'KernelLaunchClusterConfig': 202,
113+
'KernelLaunchUnknown': 203,
113114
}

0 commit comments

Comments
 (0)