Commit 6ae57ba

bug: fix gpu op custom bug
1 parent 3e1600a commit 6ae57ba

1 file changed: +33 -26 lines

extensions/brainpylib/operator/gpu_op.py (33 additions, 26 deletions)

@@ -3,6 +3,7 @@
 import collections.abc
 import ctypes
 from functools import partial
+from types import LambdaType
 from typing import Callable, Union, Sequence
 
 import jax.numpy as jnp
@@ -14,8 +15,10 @@
 from jax.lib import xla_client
 from numba import types
 from numba.core.dispatcher import Dispatcher
+
 import _cuda
 
+_lambda_no = 0
 ctypes.pythonapi.PyCapsule_New.argtypes = [
   ctypes.c_void_p,  # void* pointer
   ctypes.c_char_p,  # const char *name
@@ -26,13 +29,13 @@
 
 def _compile_gpu_signature(func, input_dtypes, input_shapes,
                            output_dtypes, output_shapes):
-  from _cuda import (
-    cuMemcpyAsync,
-    cuStreamSynchronize,
-    memcpyHostToHost,
-    memcpyHostToDevice,
-    memcpyDeviceToHost,
-    memcpyDeviceToDevice,
+  input_byte_size = tuple(
+    np.prod(shape) * dtype.itemsize
+    for (shape, dtype) in zip(input_shapes, input_dtypes)
+  )
+  output_byte_size = tuple(
+    np.prod(shape) * dtype.itemsize
+    for (shape, dtype) in zip(output_shapes, output_dtypes)
   )
 
   code_scope = dict(
@@ -41,28 +44,28 @@ def _compile_gpu_signature(func, input_dtypes, input_shapes,
     input_dtypes=input_dtypes,
     output_shapes=output_shapes,
     output_dtypes=output_dtypes,
-    carray=numba.carray,
-  )
-
-  input_byte_size = tuple(
-    np.prod(shape) * dtype.itemsize
-    for (shape, dtype) in zip(input_shapes, input_dtypes)
-  )
-  output_byte_size = tuple(
-    np.prod(shape) * dtype.itemsize
-    for (shape, dtype) in zip(output_shapes, output_dtypes)
+    empty=np.empty,
+    input_byte_size=input_byte_size,
+    output_byte_size=output_byte_size,
+    cuMemcpyAsync=_cuda.cuMemcpyAsync,
+    cuStreamSynchronize=_cuda.cuStreamSynchronize,
+    memcpyHostToHost=_cuda.memcpyHostToHost,
+    memcpyHostToDevice=_cuda.memcpyHostToDevice,
+    memcpyDeviceToHost=_cuda.memcpyDeviceToHost,
+    memcpyDeviceToDevice=_cuda.memcpyDeviceToDevice,
+    n_in=len(input_shapes),
   )
 
   args_in = [
-    f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
+    f'empty(input_shapes[{i}], dtype=input_dtypes[{i}])'
     for i in range(len(input_shapes))
   ]
   cuMemcpyAsync_in = [
     f'cuMemcpyAsync(args_in[{i}].ctypes.data, inout_gpu_ptrs[{i}], input_byte_size[{i}], memcpyDeviceToHost, stream)'
     for i in range(len(input_shapes))
   ]
   args_out = [
-    f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
+    f'empty(output_shapes[{i}], dtype=output_dtypes[{i}])'
     for i in range(len(output_shapes))
   ]
   cuMemcpyAsync_out = [
@@ -83,11 +86,11 @@ def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len):
     cuStreamSynchronize(stream)
     func_to_call(args_out, args_in)
     {cuMemcpyAsync_out}
-  '''.format(args_in=",\n\t".join(args_in),
-             args_out=",\n\t".join(args_out),
-             cuMemcpyAsync_in=",\n\t".join(cuMemcpyAsync_in),
-             cuMemcpyAsync_out=",\n\t".join(cuMemcpyAsync_out))
-  print(code_string)
+  '''.format(args_in=",\n ".join(args_in),
+             args_out=",\n ".join(args_out),
+             cuMemcpyAsync_in="\n ".join(cuMemcpyAsync_in),
+             cuMemcpyAsync_out="\n ".join(cuMemcpyAsync_out))
+  # print(code_string)
   exec(compile(code_string.strip(), '', 'exec'), code_scope)
 
   new_f = code_scope['xla_gpu_custom_call_target']
@@ -138,7 +141,7 @@ def register_gpu_op(
   if not _cuda.numba_cffi_loaded:
     raise RuntimeError("Numba cffi could not be loaded.")
   # primitive
-  prim = core.Primitive(func.__name__)
+  prim = core.Primitive(f'_lambda_func{_lambda_no}' if isinstance(func, LambdaType) else func.__name__)
   prim.multiple_results = True
 
   # user defined function
@@ -207,6 +210,10 @@ def custom_op(outs, ins):
 
 
 z = jnp.ones((1, 2), dtype=jnp.float32)
-jit_op = register_gpu_op(custom_op, abs_eval)
+op = register_gpu_op(custom_op, abs_eval)
+
+from jax import jit
+
+jit_op = jit(op)
 
 print(jit_op(z, z))
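
Note: after this change the generated callback stages data through np.empty host
buffers instead of numba.carray views, while the callback source is still rendered
as a string and compiled into code_scope via exec(compile(...)). Below is a
minimal, self-contained sketch of that templating pattern. It is not the code in
this file: the CUDA copies are replaced by plain NumPy copies so it runs without
the _cuda extension, and the names target and host_inputs, the shapes/dtypes, and
the custom_op body are illustrative assumptions.

import numpy as np

# Illustrative shapes, dtypes, and kernel body (assumed, not taken from the file).
input_shapes, input_dtypes = ((1, 2),), (np.dtype(np.float32),)
output_shapes, output_dtypes = ((1, 2),), (np.dtype(np.float32),)

def custom_op(outs, ins):
  outs[0][...] = ins[0] + 1.0

code_scope = dict(
  empty=np.empty,
  input_shapes=input_shapes, input_dtypes=input_dtypes,
  output_shapes=output_shapes, output_dtypes=output_dtypes,
  func_to_call=custom_op,
)

args_in = [f'empty(input_shapes[{i}], dtype=input_dtypes[{i}])'
           for i in range(len(input_shapes))]
args_out = [f'empty(output_shapes[{i}], dtype=output_dtypes[{i}])'
            for i in range(len(output_shapes))]

code_string = '''
def target(host_inputs):
  args_in = [
    {args_in}
  ]
  for dst, src in zip(args_in, host_inputs):  # stand-in for the device-to-host cuMemcpyAsync copies
    dst[...] = src
  args_out = [
    {args_out}
  ]
  func_to_call(args_out, args_in)  # run the user-defined kernel on the host buffers
  return args_out                  # stand-in for the copies back to the device output buffers
'''.format(args_in=",\n    ".join(args_in),
           args_out=",\n    ".join(args_out))

exec(compile(code_string.strip(), '', 'exec'), code_scope)
print(code_scope['target']([np.ones((1, 2), np.float32)]))  # [array([[2., 2.]], dtype=float32)]

In the real target, the device/host transfers are _cuda.cuMemcpyAsync calls sized by
the input_byte_size/output_byte_size tuples computed above, with
cuStreamSynchronize(stream) before the host call, as in the @@ -83,11 +86,11 @@ hunk.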
