Skip to content

Commit b72d95f

Browse files
committed
do not fail on extruded meshes
1 parent 6b665db commit b72d95f

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

pyop2/gpu/cuda.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,21 @@ def _kernel_args_(self):
8585

8686
class Arg(Arg):
8787
"""
88-
Arg for GPU
88+
Arg for GPU.
8989
"""
9090

9191

92+
class ExtrudedSet(ExtrudedSet):
93+
"""
94+
ExtrudedSet for GPU.
95+
"""
96+
@cached_property
97+
def _kernel_args_(self):
98+
m_gpu = cuda.mem_alloc(int(self.layers_array.nbytes))
99+
cuda.memcpy_htod(m_gpu, self.layers_array)
100+
return (m_gpu,)
101+
102+
92103
class Dat(petsc_Dat):
93104
"""
94105
Dat for GPU.
@@ -373,11 +384,9 @@ def argtypes(self):
373384
@cached_property
374385
def argshapes(self):
375386
argshapes = ((), ())
376-
# argtypes += self._iterset._argtypes_
377387
if self._iterset._argtypes_:
378-
raise NotImplementedError("Do not know what to do when"
379-
" self._iterset._argtypes is not empty, is this the case"
380-
" when we have extruded mesh")
388+
# TODO: verify that this bogus value doesn't affect anyone.
389+
argshapes += ((), )
381390

382391
for arg in self._args:
383392
argshapes += (arg.data.shape, )

pyop2/gpu/snpt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
def snpt_transform(kernel, block_size):
55
"""
6-
SNPT := Single 'n' Per Thread transformation.
6+
SNPT := Single 'n' Per Thread.
7+
8+
Implements outer-loop parallelization strategy.
79
810
PyOP2 uses 'n' as the outer loop iname. In Firedrake 'n' might denote
911
either a cell or a DOF.

0 commit comments

Comments
 (0)