Commit ef9ab20

Merge pull request #122 from PKU-NIP-Lab/numba4jax

Provide custom operators written in Numba for JAX JIT

2 parents 918a0ca + b2a4c2b

File tree: 5 files changed, +491 -0 lines changed

extensions/brainpylib/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,4 +10,6 @@
 from .event_prod import *
 from .atomic_sum import *
 from .atomic_prod import *
+from .operator.cpu_op import register_op_cpu
+from .operator.gpu_op import register_op_gpu


extensions/brainpylib/operator/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .cpu_op import *
from .gpu_op import *
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
import ctypes
import ctypes.util
import sys

from cffi import FFI
from numba import cuda
from numba import types


class Dl_info(ctypes.Structure):
    """
    Structure of the Dl_info filled in by dl.dladdr.
    """

    _fields_ = (
        ("dli_fname", ctypes.c_char_p),
        ("dli_fbase", ctypes.c_void_p),
        ("dli_sname", ctypes.c_char_p),
        ("dli_saddr", ctypes.c_void_p),
    )


# Find the dynamic linker library path. Only works on unix-like OSes.
libdl_path = ctypes.util.find_library("dl")
if libdl_path:
    # Load the dynamic linker dynamically
    libdl = ctypes.CDLL(libdl_path)

    # Define dladdr to get the pointer to a symbol in a shared
    # library already loaded.
    # https://man7.org/linux/man-pages/man3/dladdr.3.html
    libdl.dladdr.argtypes = (ctypes.c_void_p, ctypes.POINTER(Dl_info))
    # restype is left at its default since the result is returned by reference
else:
    # On Windows it is nontrivial to have libdl, so we disable everything
    # based on it and use other ways to find the paths of libraries
    libdl = None


def find_path_of_symbol_in_library(symbol):
    if libdl is None:
        raise ValueError("libdl not found.")

    info = Dl_info()
    result = libdl.dladdr(symbol, ctypes.byref(info))
    if result and info.dli_fname:
        return info.dli_fname.decode(sys.getfilesystemencoding())
    else:
        raise ValueError("Cannot determine path of library.")


try:
    _libcuda = cuda.driver.find_driver()
    if sys.platform == "win32":
        libcuda_path = ctypes.util.find_library(_libcuda._name)
    else:
        libcuda_path = find_path_of_symbol_in_library(_libcuda.cuMemcpy)
    numba_cffi_loaded = True
except Exception:
    numba_cffi_loaded = False


if numba_cffi_loaded:
    # declare the driver functions we need
    ffi = FFI()
    ffi.cdef("int cuMemcpy(void* dst, void* src, unsigned int len, int type);")
    ffi.cdef("int cuMemcpyAsync(void* dst, void* src, unsigned int len, int type, void* stream);")
    ffi.cdef("int cuStreamSynchronize(void* stream);")
    ffi.cdef("int cudaMallocHost(void** ptr, size_t size);")
    ffi.cdef("int cudaFreeHost(void* ptr);")

    # load the library (alternatively, cuda.driver.find_library() could be used)
    libcuda = ffi.dlopen(libcuda_path)
    cuMemcpy = libcuda.cuMemcpy
    cuMemcpyAsync = libcuda.cuMemcpyAsync
    cuStreamSynchronize = libcuda.cuStreamSynchronize

    # cudaMemcpyKind values for the `type` argument above
    memcpyHostToHost = types.int32(0)
    memcpyHostToDevice = types.int32(1)
    memcpyDeviceToHost = types.int32(2)
    memcpyDeviceToDevice = types.int32(3)
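
For orientation, a minimal hypothetical sketch (not part of the commit) of how these driver bindings could be used once numba_cffi_loaded is True. The helper name _sync_device_to_host_copy and the integer-address arguments are assumptions for illustration; the kind value 2 mirrors memcpyDeviceToHost above.

def _sync_device_to_host_copy(dst_ptr, src_ptr, nbytes, stream_ptr):
    # cffi expects cdata pointers, so integer addresses are cast to void*
    dst = ffi.cast("void*", dst_ptr)
    src = ffi.cast("void*", src_ptr)
    stream = ffi.cast("void*", stream_ptr)
    # enqueue an asynchronous device->host copy on the stream ...
    err = cuMemcpyAsync(dst, src, nbytes, 2, stream)
    if err != 0:  # 0 is CUDA_SUCCESS
        return err
    # ... then block until the stream has drained
    return cuStreamSynchronize(stream)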
Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-

import collections.abc
import ctypes
from functools import partial
from types import LambdaType
from typing import Callable, Union, Sequence

import jax.numpy as jnp
import numba
import numpy as np
from jax import core
from jax.abstract_arrays import ShapedArray
from jax.interpreters import xla
from jax.lib import xla_client
from numba import types
from numba.core.dispatcher import Dispatcher

_lambda_no = 0
ctypes.pythonapi.PyCapsule_New.argtypes = [
    ctypes.c_void_p,  # void* pointer
    ctypes.c_char_p,  # const char *name
    ctypes.c_void_p,  # PyCapsule_Destructor destructor
]
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object


def _compile_cpu_signature(func, input_dtypes, input_shapes,
                           output_dtypes, output_shapes):
    code_scope = dict(
        func_to_call=func,
        input_shapes=input_shapes,
        input_dtypes=input_dtypes,
        output_shapes=output_shapes,
        output_dtypes=output_dtypes,
        carray=numba.carray,
    )

    args_in = [
        f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
        for i in range(len(input_shapes))
    ]
    args_out = [
        f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
        for i in range(len(output_shapes))
    ]

    # the trailing commas keep `args_out` and `args_in` tuples even when
    # there is only a single output or input
    code_string = '''
def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
    args_out = (
        {args_out},
    )
    args_in = (
        {args_in},
    )
    func_to_call(args_out, args_in)
'''.format(args_in=",\n        ".join(args_in),
           args_out=",\n        ".join(args_out))
    # print(code_string)
    exec(compile(code_string.strip(), '', 'exec'), code_scope)

    new_f = code_scope['xla_cpu_custom_call_target']
    wrapper = numba.cfunc(types.void(types.CPointer(types.voidptr),
                                     types.CPointer(types.voidptr)))
    xla_c_rule = wrapper(new_f)
    target_name = xla_c_rule.native_name.encode("ascii")
    capsule = ctypes.pythonapi.PyCapsule_New(
        xla_c_rule.address,  # address of the compiled cfunc
        b"xla._CUSTOM_CALL_TARGET",  # capsule name expected by XLA
        None  # no destructor is needed
    )
    xla_client.register_custom_call_target(target_name, capsule, "cpu")
    return target_name

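# For illustration (not part of the commit): with two inputs and two
# outputs, the `code_string` generated above expands to roughly
#
#   def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
#       args_out = (
#           carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
#           carray(output_ptrs[1], output_shapes[1], dtype=output_dtypes[1]),
#       )
#       args_in = (
#           carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0]),
#           carray(input_ptrs[1], input_shapes[1], dtype=input_dtypes[1]),
#       )
#       func_to_call(args_out, args_in)
#
# which numba.cfunc then compiles into a C-callable XLA custom-call target.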


def _func_translation(func, abs_eval_fn, c, *inputs):
    input_shapes = [c.get_shape(arg) for arg in inputs]
    input_dtypes = tuple(shape.element_type() for shape in input_shapes)
    input_dimensions = tuple(shape.dimensions() for shape in input_shapes)
    output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type())
                                                for shape in input_shapes))
    output_shapes = tuple(array.shape for array in output_abstract_arrays)
    output_dtypes = tuple(array.dtype for array in output_abstract_arrays)
    output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes)
    xla_output_shapes = [xla_client.Shape.array_shape(*arg)
                         for arg in zip(output_dtypes, output_shapes, output_layouts)]
    xla_output_shape = xla_client.Shape.tuple_shape(xla_output_shapes)
    target_name = _compile_cpu_signature(func,
                                         input_dtypes, input_dimensions,
                                         output_dtypes, output_shapes)

    return xla_client.ops.CustomCallWithLayout(
        c,
        target_name,
        operands=inputs,
        operand_shapes_with_layout=input_shapes,
        shape_with_layout=xla_output_shape,
    )


def register_cpu_op(
        func: Callable,
        out_shapes: Union[Callable, ShapedArray, Sequence[ShapedArray]]
):
    global _lambda_no

    # primitive (anonymous lambdas get a generated name; the counter is
    # advanced so that successive lambdas do not collide)
    if isinstance(func, LambdaType) and func.__name__ == "<lambda>":
        prim = core.Primitive(f'_lambda_func{_lambda_no}')
        _lambda_no += 1
    else:
        prim = core.Primitive(func.__name__)
    prim.multiple_results = True

    # user-defined function
    if not isinstance(func, Dispatcher):
        func = numba.jit(fastmath=True, nopython=True)(func)

    # output shape evaluation function
    def abs_eval_rule(*input_shapes):
        if callable(out_shapes):
            shapes = out_shapes(*input_shapes)
        elif isinstance(out_shapes, ShapedArray):
            shapes = [out_shapes]
        elif isinstance(out_shapes, (tuple, list)):
            shapes = out_shapes
            for elem in out_shapes:
                if not isinstance(elem, ShapedArray):
                    raise ValueError(f'Elements in "out_shapes" must be instances of '
                                     f'jax.abstract_arrays.ShapedArray, but we got '
                                     f'{type(elem)}: {elem}')
        else:
            raise ValueError(f'Unknown type {type(out_shapes)}, only '
                             f'supports function, ShapedArray or '
                             f'list/tuple of ShapedArray.')

        # output shapes
        if not isinstance(shapes, collections.abc.Collection):
            return [shapes]
        else:
            return shapes

    # output evaluation function
    def eval_rule(*inputs):
        # compute the output shapes
        output_shapes = abs_eval_rule(*inputs)
        # preallocate the outputs
        outputs = tuple(np.zeros(shape.shape, dtype=shape.dtype) for shape in output_shapes)
        # convert inputs to a tuple
        inputs = tuple(np.asarray(arg) for arg in inputs)
        # call the kernel
        func(outputs, inputs)
        # return the outputs
        return tuple(outputs)

    def bind_primitive(*inputs):
        result = prim.bind(*inputs)
        return result[0] if len(result) == 1 else result

    # binding
    prim.def_abstract_eval(abs_eval_rule)
    prim.def_impl(eval_rule)
    # registering
    xla.backend_specific_translations['cpu'][prim] = partial(_func_translation, func, abs_eval_rule)
    return bind_primitive


if __name__ == '__main__':
    def abs_eval(*ins):
        return ins

    import brainpy as bp
    bp.math.set_platform('cpu')

    def custom_op(outs, ins):
        y, y1 = outs
        x, x2 = ins
        y[:] = x + 1
        y1[:] = x2 + 2


    z = jnp.ones((1, 2), dtype=jnp.float32)
    op = register_cpu_op(custom_op, abs_eval)

    from jax import jit
    jit_op = jit(op)

    print(jit_op(z, z))
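
A further usage sketch (not part of the commit): per the branches in abs_eval_rule above, out_shapes can also be given statically as a ShapedArray (or a list/tuple of them) instead of a callable. The operator double_op below is a hypothetical example.

from jax import jit
import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray


def double_op(outs, ins):
    # single-output, single-input kernel: out = 2 * x
    out, = outs
    x, = ins
    out[:] = x * 2.0


op2 = register_cpu_op(double_op, ShapedArray((1, 2), jnp.float32))
z = jnp.ones((1, 2), dtype=jnp.float32)
print(jit(op2)(z))  # expected: [[2. 2.]]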

0 commit comments