Skip to content

Commit a5eb20a

Browse files
committed
port Subset to cuda
1 parent b72d95f commit a5eb20a

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

pyop2/gpu/cuda.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,17 @@ def _kernel_args_(self):
100100
return (m_gpu,)
101101

102102

103+
class Subset(Subset):
104+
"""
105+
Subset for GPU.
106+
"""
107+
@cached_property
108+
def _kernel_args_(self):
109+
m_gpu = cuda.mem_alloc(int(self._indices.nbytes))
110+
cuda.memcpy_htod(m_gpu, self._indices)
111+
return self._superset._kernel_args_ + (m_gpu, )
112+
113+
103114
class Dat(petsc_Dat):
104115
"""
105116
Dat for GPU.
@@ -318,8 +329,11 @@ def code_to_compile(self):
318329
builder.set_kernel(self._kernel)
319330

320331
wrapper = generate(builder)
332+
print('Compiling...', wrapper.name)
321333

322334
code, processed_program, args_to_make_global = generate_gpu_kernel(wrapper, self.args, self.argshapes)
335+
336+
print(code)
323337
for i, arg_to_make_global in enumerate(args_to_make_global):
324338
numpy.save(self.ith_added_global_arg_i(i),
325339
arg_to_make_global)
@@ -386,6 +400,7 @@ def argshapes(self):
386400
argshapes = ((), ())
387401
if self._iterset._argtypes_:
388402
# TODO: verify that this bogus value doesn't affect anyone.
403+
# raise NotImplementedError()
389404
argshapes += ((), )
390405

391406
for arg in self._args:
@@ -605,7 +620,7 @@ def insn_needs_atomic(insn):
605620
raise ValueError("gpu_strategy can be 'scpt',"
606621
" 'user_specified_tile' or 'auto_tile'.")
607622
elif program.name in ["wrap_zero", "wrap_expression_kernel",
608-
"wrap_pyop2_kernel_uniform_extrusion",
623+
"wrap_expression", "wrap_pyop2_kernel_uniform_extrusion",
609624
"wrap_form_cell_integral_otherwise",
610625
]:
611626
from pyop2.gpu.snpt import snpt_transform

0 commit comments

Comments
 (0)