Skip to content

Commit dd7717a

Browse files
committed
port Subset to cuda
1 parent b72d95f commit dd7717a

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

pyop2/gpu/cuda.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,17 @@ def _kernel_args_(self):
100100
return (m_gpu,)
101101

102102

103+
class Subset(Subset):
104+
"""
105+
Subset for GPU.
106+
"""
107+
@cached_property
108+
def _kernel_args_(self):
109+
m_gpu = cuda.mem_alloc(int(self._indices.nbytes))
110+
cuda.memcpy_htod(m_gpu, self._indices)
111+
return self._superset._kernel_args_ + (m_gpu, )
112+
113+
103114
class Dat(petsc_Dat):
104115
"""
105116
Dat for GPU.
@@ -385,7 +396,7 @@ def argtypes(self):
385396
def argshapes(self):
386397
argshapes = ((), ())
387398
if self._iterset._argtypes_:
388-
# TODO: verify that this bogus value doesn't affect anyone.
399+
# FIXME: Do not put in a bogus value
389400
argshapes += ((), )
390401

391402
for arg in self._args:
@@ -605,7 +616,7 @@ def insn_needs_atomic(insn):
605616
raise ValueError("gpu_strategy can be 'scpt',"
606617
" 'user_specified_tile' or 'auto_tile'.")
607618
elif program.name in ["wrap_zero", "wrap_expression_kernel",
608-
"wrap_pyop2_kernel_uniform_extrusion",
619+
"wrap_expression", "wrap_pyop2_kernel_uniform_extrusion",
609620
"wrap_form_cell_integral_otherwise",
610621
]:
611622
from pyop2.gpu.snpt import snpt_transform

0 commit comments

Comments
 (0)