@@ -3,6 +3,7 @@
 import collections.abc
 import ctypes
 from functools import partial
+from types import LambdaType
 from typing import Callable, Union, Sequence
 
 import jax.numpy as jnp
@@ -14,8 +15,10 @@
 from jax.lib import xla_client
 from numba import types
 from numba.core.dispatcher import Dispatcher
+
 import _cuda
 
+_lambda_no = 0
 ctypes.pythonapi.PyCapsule_New.argtypes = [
   ctypes.c_void_p,  # void* pointer
   ctypes.c_char_p,  # const char *name
@@ -26,13 +29,13 @@
 
 def _compile_gpu_signature(func, input_dtypes, input_shapes,
                            output_dtypes, output_shapes):
-  from _cuda import (
-    cuMemcpyAsync,
-    cuStreamSynchronize,
-    memcpyHostToHost,
-    memcpyHostToDevice,
-    memcpyDeviceToHost,
-    memcpyDeviceToDevice,
+  input_byte_size = tuple(
+    np.prod(shape) * dtype.itemsize
+    for (shape, dtype) in zip(input_shapes, input_dtypes)
+  )
+  output_byte_size = tuple(
+    np.prod(shape) * dtype.itemsize
+    for (shape, dtype) in zip(output_shapes, output_dtypes)
   )
 
   code_scope = dict(
@@ -41,28 +44,28 @@ def _compile_gpu_signature(func, input_dtypes, input_shapes,
     input_dtypes=input_dtypes,
     output_shapes=output_shapes,
     output_dtypes=output_dtypes,
-    carray=numba.carray,
-  )
-
-  input_byte_size = tuple(
-    np.prod(shape) * dtype.itemsize
-    for (shape, dtype) in zip(input_shapes, input_dtypes)
-  )
-  output_byte_size = tuple(
-    np.prod(shape) * dtype.itemsize
-    for (shape, dtype) in zip(output_shapes, output_dtypes)
+    empty=np.empty,
+    input_byte_size=input_byte_size,
+    output_byte_size=output_byte_size,
+    cuMemcpyAsync=_cuda.cuMemcpyAsync,
+    cuStreamSynchronize=_cuda.cuStreamSynchronize,
+    memcpyHostToHost=_cuda.memcpyHostToHost,
+    memcpyHostToDevice=_cuda.memcpyHostToDevice,
+    memcpyDeviceToHost=_cuda.memcpyDeviceToHost,
+    memcpyDeviceToDevice=_cuda.memcpyDeviceToDevice,
+    n_in=len(input_shapes),
   )
 
   args_in = [
-    f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
+    f'empty(input_shapes[{i}], dtype=input_dtypes[{i}])'
     for i in range(len(input_shapes))
   ]
   cuMemcpyAsync_in = [
     f'cuMemcpyAsync(args_in[{i}].ctypes.data, inout_gpu_ptrs[{i}], input_byte_size[{i}], memcpyDeviceToHost, stream)'
     for i in range(len(input_shapes))
   ]
   args_out = [
-    f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
+    f'empty(output_shapes[{i}], dtype=output_dtypes[{i}])'
     for i in range(len(output_shapes))
   ]
   cuMemcpyAsync_out = [
@@ -83,11 +86,11 @@ def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len):
   cuStreamSynchronize(stream)
   func_to_call(args_out, args_in)
   {cuMemcpyAsync_out}
-  '''.format(args_in=",\n\t".join(args_in),
-             args_out=",\n\t".join(args_out),
-             cuMemcpyAsync_in=",\n\t".join(cuMemcpyAsync_in),
-             cuMemcpyAsync_out=",\n\t".join(cuMemcpyAsync_out))
-  print(code_string)
+  '''.format(args_in=",\n  ".join(args_in),
+             args_out=",\n  ".join(args_out),
+             cuMemcpyAsync_in="\n  ".join(cuMemcpyAsync_in),
+             cuMemcpyAsync_out="\n  ".join(cuMemcpyAsync_out))
+  # print(code_string)
   exec(compile(code_string.strip(), '', 'exec'), code_scope)
 
   new_f = code_scope['xla_gpu_custom_call_target']
@@ -138,7 +141,7 @@ def register_gpu_op(
   if not _cuda.numba_cffi_loaded:
     raise RuntimeError("Numba cffi could not be loaded.")
   # primitive
-  prim = core.Primitive(func.__name__)
+  prim = core.Primitive(f'_lambda_func{_lambda_no}' if isinstance(func, LambdaType) else func.__name__)
   prim.multiple_results = True
 
   # user defined function
@@ -207,6 +210,10 @@ def custom_op(outs, ins):
 
 
 z = jnp.ones((1, 2), dtype=jnp.float32)
-jit_op = register_gpu_op(custom_op, abs_eval)
+op = register_gpu_op(custom_op, abs_eval)
+
+from jax import jit
+
+jit_op = jit(op)
 
 print(jit_op(z, z))
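For orientation, the code string assembled by _compile_gpu_signature after this change expands, for one input and one output, to roughly the wrapper shown below. This is only a sketch: the parts of the template that lie outside the hunks above are not visible in this diff, and in particular the output-copy line and its n_in pointer offset are assumptions inferred from the n_in and memcpyHostToDevice entries added to code_scope. The snippet itself just prints the assumed generated source.

# Hedged illustration of the generated custom-call target for a single
# float32 input and a single output. Every name inside the string
# (empty, cuMemcpyAsync, memcpy* constants, input_shapes, n_in, ...) is an
# entry that _compile_gpu_signature places into code_scope before exec'ing.
expected_code_string = '''
def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len):
  args_in = [
    empty(input_shapes[0], dtype=input_dtypes[0])
  ]
  args_out = [
    empty(output_shapes[0], dtype=output_dtypes[0])
  ]
  cuMemcpyAsync(args_in[0].ctypes.data, inout_gpu_ptrs[0], input_byte_size[0], memcpyDeviceToHost, stream)
  cuStreamSynchronize(stream)
  func_to_call(args_out, args_in)
  cuMemcpyAsync(inout_gpu_ptrs[n_in + 0], args_out[0].ctypes.data, output_byte_size[0], memcpyHostToDevice, stream)
'''
print(expected_code_string)

In other words, instead of wrapping the device pointers with numba.carray, the new wrapper stages data through host-side np.empty buffers: inputs are copied device-to-host, the user-supplied function runs on the host once the stream has been synchronized, and the results are copied back to the device buffers that XLA handed to the custom call.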