Skip to content

Commit b2a4c2b

Browse files
committed
feat: support retrieving the function name of a lambda function
1 parent 6ae57ba commit b2a4c2b

File tree

3 files changed

+76
-71
lines changed

3 files changed

+76
-71
lines changed
Lines changed: 59 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,82 @@
1-
import sys
21
import ctypes
32
import ctypes.util
4-
from numba import cuda as ncuda
5-
from cffi import FFI
3+
import sys
64

5+
from cffi import FFI
6+
from numba import cuda
77
from numba import types
88

9+
910
class Dl_info(ctypes.Structure):
10-
"""
11-
Structure of the Dl_info returned by the CFFI of dl.dladdr
12-
"""
11+
"""
12+
Structure of the Dl_info returned by the CFFI of dl.dladdr
13+
"""
1314

14-
_fields_ = (
15-
("dli_fname", ctypes.c_char_p),
16-
("dli_fbase", ctypes.c_void_p),
17-
("dli_sname", ctypes.c_char_p),
18-
("dli_saddr", ctypes.c_void_p),
19-
)
15+
_fields_ = (
16+
("dli_fname", ctypes.c_char_p),
17+
("dli_fbase", ctypes.c_void_p),
18+
("dli_sname", ctypes.c_char_p),
19+
("dli_saddr", ctypes.c_void_p),
20+
)
2021

2122

2223
# Find the path of the dynamic linker library (libdl). Only works on Unix-like systems.
2324
libdl_path = ctypes.util.find_library("dl")
2425
if libdl_path:
25-
# Load the dynamic linker dynamically
26-
libdl = ctypes.CDLL(libdl_path)
27-
28-
# Define dladdr to get the pointer to a symbol in a shared
29-
# library already loaded.
30-
# https://man7.org/linux/man-pages/man3/dladdr.3.html
31-
libdl.dladdr.argtypes = (ctypes.c_void_p, ctypes.POINTER(Dl_info))
32-
# restype is None as it returns by reference
26+
# Load the dynamic linker dynamically
27+
libdl = ctypes.CDLL(libdl_path)
28+
29+
# Define dladdr to get the pointer to a symbol in a shared
30+
# library already loaded.
31+
# https://man7.org/linux/man-pages/man3/dladdr.3.html
32+
libdl.dladdr.argtypes = (ctypes.c_void_p, ctypes.POINTER(Dl_info))
33+
# restype is left at the ctypes default (int): dladdr returns nonzero on success and fills Dl_info by reference
3334
else:
34-
# On Windows it is nontrivial to have libdl, so we disable everything about
35-
# it and use other ways to find paths of libraries
36-
libdl = None
35+
# On Windows it is nontrivial to have libdl, so we disable everything about
36+
# it and use other ways to find paths of libraries
37+
libdl = None
3738

3839

3940
def find_path_of_symbol_in_library(symbol):
40-
if libdl is None:
41-
raise ValueError("libdl not found.")
42-
43-
info = Dl_info()
41+
if libdl is None:
42+
raise ValueError("libdl not found.")
4443

45-
result = libdl.dladdr(symbol, ctypes.byref(info))
46-
47-
if result and info.dli_fname:
48-
return info.dli_fname.decode(sys.getfilesystemencoding())
49-
else:
50-
raise ValueError("Cannot determine path of Library.")
44+
info = Dl_info()
45+
result = libdl.dladdr(symbol, ctypes.byref(info))
46+
if result and info.dli_fname:
47+
return info.dli_fname.decode(sys.getfilesystemencoding())
48+
else:
49+
raise ValueError("Cannot determine path of Library.")
5150

5251

5352
try:
54-
_libcuda = ncuda.driver.find_driver()
55-
56-
if sys.platform == "win32":
57-
libcuda_path = ctypes.util.find_library(_libcuda._name)
58-
else:
59-
libcuda_path = find_path_of_symbol_in_library(_libcuda.cuMemcpy)
60-
61-
numba_cffi_loaded = True
53+
_libcuda = cuda.driver.find_driver()
54+
if sys.platform == "win32":
55+
libcuda_path = ctypes.util.find_library(_libcuda._name)
56+
else:
57+
libcuda_path = find_path_of_symbol_in_library(_libcuda.cuMemcpy)
58+
numba_cffi_loaded = True
6259
except Exception:
63-
numba_cffi_loaded = False
60+
numba_cffi_loaded = False
6461

65-
if numba_cffi_loaded:
6662

67-
# functions needed
68-
ffi = FFI()
69-
ffi.cdef("int cuMemcpy(void* dst, void* src, unsigned int len, int type);")
70-
ffi.cdef(
71-
"int cuMemcpyAsync(void* dst, void* src, unsigned int len, int type, void* stream);"
72-
)
73-
ffi.cdef("int cuStreamSynchronize(void* stream);")
74-
75-
ffi.cdef("int cudaMallocHost(void** ptr, size_t size);")
76-
ffi.cdef("int cudaFreeHost(void* ptr);")
77-
78-
# load libraray
79-
# could ncuda.driver.find_library()
80-
libcuda = ffi.dlopen(libcuda_path)
81-
cuMemcpy = libcuda.cuMemcpy
82-
cuMemcpyAsync = libcuda.cuMemcpyAsync
83-
cuStreamSynchronize = libcuda.cuStreamSynchronize
84-
85-
memcpyHostToHost = types.int32(0)
86-
memcpyHostToDevice = types.int32(1)
87-
memcpyDeviceToHost = types.int32(2)
88-
memcpyDeviceToDevice = types.int32(3)
63+
if numba_cffi_loaded:
64+
# functions needed
65+
ffi = FFI()
66+
ffi.cdef("int cuMemcpy(void* dst, void* src, unsigned int len, int type);")
67+
ffi.cdef("int cuMemcpyAsync(void* dst, void* src, unsigned int len, int type, void* stream);")
68+
ffi.cdef("int cuStreamSynchronize(void* stream);")
69+
ffi.cdef("int cudaMallocHost(void** ptr, size_t size);")
70+
ffi.cdef("int cudaFreeHost(void* ptr);")
71+
72+
# load library
73+
# could use cuda.driver.find_library()
74+
libcuda = ffi.dlopen(libcuda_path)
75+
cuMemcpy = libcuda.cuMemcpy
76+
cuMemcpyAsync = libcuda.cuMemcpyAsync
77+
cuStreamSynchronize = libcuda.cuStreamSynchronize
78+
79+
memcpyHostToHost = types.int32(0)
80+
memcpyHostToDevice = types.int32(1)
81+
memcpyDeviceToHost = types.int32(2)
82+
memcpyDeviceToDevice = types.int32(3)

extensions/brainpylib/operator/cpu_op.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import collections.abc
44
import ctypes
55
from functools import partial
6+
from types import LambdaType
67
from typing import Callable, Union, Sequence
78

89
import jax.numpy as jnp
@@ -15,6 +16,7 @@
1516
from numba import types
1617
from numba.core.dispatcher import Dispatcher
1718

19+
_lambda_no = 0
1820
ctypes.pythonapi.PyCapsule_New.argtypes = [
1921
ctypes.c_void_p, # void* pointer
2022
ctypes.c_char_p, # const char *name
@@ -52,9 +54,9 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
5254
{args_in}
5355
)
5456
func_to_call(args_out, args_in)
55-
'''.format(args_in=",\n\t".join(args_in),
56-
args_out=",\n\t".join(args_out))
57-
print(code_string)
57+
'''.format(args_in=",\n ".join(args_in),
58+
args_out=",\n ".join(args_out))
59+
# print(code_string)
5860
exec(compile(code_string.strip(), '', 'exec'), code_scope)
5961

6062
new_f = code_scope['xla_cpu_custom_call_target']
@@ -101,7 +103,9 @@ def register_cpu_op(
101103
out_shapes: Union[Callable, ShapedArray, Sequence[ShapedArray]]
102104
):
103105
# primitive
104-
prim = core.Primitive(func.__name__)
106+
prim = core.Primitive(f'_lambda_func{_lambda_no}'
107+
if (isinstance(func, LambdaType) and func.__name__ == "<lambda>")
108+
else func.__name__)
105109
prim.multiple_results = True
106110

107111
# user defined function
@@ -161,6 +165,8 @@ def bind_primitive(*inputs):
161165
def abs_eval(*ins):
162166
return ins
163167

168+
import brainpy as bp
169+
bp.math.set_platform('cpu')
164170

165171
def custom_op(outs, ins):
166172
y, y1 = outs
@@ -170,6 +176,9 @@ def custom_op(outs, ins):
170176

171177

172178
z = jnp.ones((1, 2), dtype=jnp.float32)
173-
jit_op = register_cpu_op(custom_op, abs_eval)
179+
op = register_cpu_op(custom_op, abs_eval)
180+
181+
from jax import jit
182+
jit_op = jit(op)
174183

175184
print(jit_op(z, z))

extensions/brainpylib/operator/gpu_op.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ def register_gpu_op(
141141
if not _cuda.numba_cffi_loaded:
142142
raise RuntimeError("Numba cffi could not be loaded.")
143143
# primitive
144-
prim = core.Primitive(f'_lambda_func{_lambda_no}' if isinstance(func, LambdaType) else func.__name__)
144+
prim = core.Primitive(f'_lambda_func{_lambda_no}'
145+
if (isinstance(func, LambdaType) and func.__name__ == "<lambda>")
146+
else func.__name__)
145147
prim.multiple_results = True
146148

147149
# user defined function

0 commit comments

Comments
 (0)