|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | + |
| 3 | +import collections.abc |
| 4 | +import ctypes |
| 5 | +from functools import partial |
| 6 | +from types import LambdaType |
| 7 | +from typing import Callable, Union, Sequence |
| 8 | + |
| 9 | +import jax.numpy as jnp |
| 10 | +import numba |
| 11 | +import numpy as np |
| 12 | +from jax import core |
| 13 | +from jax.abstract_arrays import ShapedArray |
| 14 | +from jax.interpreters import xla |
| 15 | +from jax.lib import xla_client |
| 16 | +from numba import types |
| 17 | +from numba.core.dispatcher import Dispatcher |
| 18 | + |
| 19 | +_lambda_no = 0 |
| 20 | +ctypes.pythonapi.PyCapsule_New.argtypes = [ |
| 21 | + ctypes.c_void_p, # void* pointer |
| 22 | + ctypes.c_char_p, # const char *name |
| 23 | + ctypes.c_void_p, # PyCapsule_Destructor destructor |
| 24 | +] |
| 25 | +ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object |
| 26 | + |
| 27 | + |
| 28 | +def _compile_cpu_signature(func, input_dtypes, input_shapes, |
| 29 | + output_dtypes, output_shapes): |
| 30 | + code_scope = dict( |
| 31 | + func_to_call=func, |
| 32 | + input_shapes=input_shapes, |
| 33 | + input_dtypes=input_dtypes, |
| 34 | + output_shapes=output_shapes, |
| 35 | + output_dtypes=output_dtypes, |
| 36 | + carray=numba.carray, |
| 37 | + ) |
| 38 | + |
| 39 | + args_in = [ |
| 40 | + f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])' |
| 41 | + for i in range(len(input_shapes)) |
| 42 | + ] |
| 43 | + args_out = [ |
| 44 | + f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])' |
| 45 | + for i in range(len(output_shapes)) |
| 46 | + ] |
| 47 | + |
| 48 | + code_string = ''' |
| 49 | +def xla_cpu_custom_call_target(output_ptrs, input_ptrs): |
| 50 | + args_out = ( |
| 51 | + {args_out} |
| 52 | + ) |
| 53 | + args_in = ( |
| 54 | + {args_in} |
| 55 | + ) |
| 56 | + func_to_call(args_out, args_in) |
| 57 | + '''.format(args_in=",\n ".join(args_in), |
| 58 | + args_out=",\n ".join(args_out)) |
| 59 | + # print(code_string) |
| 60 | + exec(compile(code_string.strip(), '', 'exec'), code_scope) |
| 61 | + |
| 62 | + new_f = code_scope['xla_cpu_custom_call_target'] |
| 63 | + wrapper = numba.cfunc(types.void(types.CPointer(types.voidptr), |
| 64 | + types.CPointer(types.voidptr))) |
| 65 | + xla_c_rule = wrapper(new_f) |
| 66 | + target_name = xla_c_rule.native_name.encode("ascii") |
| 67 | + capsule = ctypes.pythonapi.PyCapsule_New( |
| 68 | + xla_c_rule.address, # A CFFI pointer to a function |
| 69 | + b"xla._CUSTOM_CALL_TARGET", # A binary string |
| 70 | + None # PyCapsule object run at destruction |
| 71 | + ) |
| 72 | + xla_client.register_custom_call_target(target_name, capsule, "cpu") |
| 73 | + return target_name |
| 74 | + |
| 75 | + |
| 76 | +def _func_translation(func, abs_eval_fn, c, *inputs): |
| 77 | + input_shapes = [c.get_shape(arg) for arg in inputs] |
| 78 | + input_dtypes = tuple(shape.element_type() for shape in input_shapes) |
| 79 | + input_dimensions = tuple(shape.dimensions() for shape in input_shapes) |
| 80 | + output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) |
| 81 | + for shape in input_shapes)) |
| 82 | + output_shapes = tuple(array.shape for array in output_abstract_arrays) |
| 83 | + output_dtypes = tuple(array.dtype for array in output_abstract_arrays) |
| 84 | + output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) |
| 85 | + xla_output_shapes = [xla_client.Shape.array_shape(*arg) |
| 86 | + for arg in zip(output_dtypes, output_shapes, output_layouts)] |
| 87 | + xla_output_shape = xla_client.Shape.tuple_shape(xla_output_shapes) |
| 88 | + target_name = _compile_cpu_signature(func, |
| 89 | + input_dtypes, input_dimensions, |
| 90 | + output_dtypes, output_shapes) |
| 91 | + |
| 92 | + return xla_client.ops.CustomCallWithLayout( |
| 93 | + c, |
| 94 | + target_name, |
| 95 | + operands=inputs, |
| 96 | + operand_shapes_with_layout=input_shapes, |
| 97 | + shape_with_layout=xla_output_shape, |
| 98 | + ) |
| 99 | + |
| 100 | + |
| 101 | +def register_cpu_op( |
| 102 | + func: Callable, |
| 103 | + out_shapes: Union[Callable, ShapedArray, Sequence[ShapedArray]] |
| 104 | +): |
| 105 | + # primitive |
| 106 | + prim = core.Primitive(f'_lambda_func{_lambda_no}' |
| 107 | + if (isinstance(func, LambdaType) and func.__name__ == "<lambda>") |
| 108 | + else func.__name__) |
| 109 | + prim.multiple_results = True |
| 110 | + |
| 111 | + # user defined function |
| 112 | + if not isinstance(func, Dispatcher): |
| 113 | + func = numba.jit(fastmath=True, nopython=True)(func) |
| 114 | + |
| 115 | + # output shape evaluation function |
| 116 | + def abs_eval_rule(*input_shapes): |
| 117 | + if callable(out_shapes): |
| 118 | + shapes = out_shapes(*input_shapes) |
| 119 | + elif isinstance(out_shapes, ShapedArray): |
| 120 | + shapes = [out_shapes] |
| 121 | + elif isinstance(out_shapes, (tuple, list)): |
| 122 | + shapes = out_shapes |
| 123 | + for elem in out_shapes: |
| 124 | + if not isinstance(elem, ShapedArray): |
| 125 | + raise ValueError(f'Elements in "out_shapes" must be instances of ' |
| 126 | + f'jax.abstract_arrays.ShapedArray, but we got ' |
| 127 | + f'{type(elem)}: {elem}') |
| 128 | + else: |
| 129 | + raise ValueError(f'Unknown type {type(out_shapes)}, only ' |
| 130 | + f'supports function, ShapedArray or ' |
| 131 | + f'list/tuple of ShapedArray.') |
| 132 | + |
| 133 | + # output shapes |
| 134 | + if not isinstance(shapes, collections.abc.Collection): |
| 135 | + return [shapes] |
| 136 | + else: |
| 137 | + return shapes |
| 138 | + |
| 139 | + # output evaluation function |
| 140 | + def eval_rule(*inputs): |
| 141 | + # compute the output shapes |
| 142 | + output_shapes = abs_eval_rule(*inputs) |
| 143 | + # Preallocate the outputs |
| 144 | + outputs = tuple(np.zeros(shape.shape, dtype=shape.dtype) for shape in output_shapes) |
| 145 | + # convert inputs to a tuple |
| 146 | + inputs = tuple(np.asarray(arg) for arg in inputs) |
| 147 | + # call the kernel |
| 148 | + func(outputs, inputs) |
| 149 | + # Return the outputs |
| 150 | + return tuple(outputs) |
| 151 | + |
| 152 | + def bind_primitive(*inputs): |
| 153 | + result = prim.bind(*inputs) |
| 154 | + return result[0] if len(result) == 1 else result |
| 155 | + |
| 156 | + # binding |
| 157 | + prim.def_abstract_eval(abs_eval_rule) |
| 158 | + prim.def_impl(eval_rule) |
| 159 | + # registering |
| 160 | + xla.backend_specific_translations['cpu'][prim] = partial(_func_translation, func, abs_eval_rule) |
| 161 | + return bind_primitive |
| 162 | + |
| 163 | + |
| 164 | +if __name__ == '__main__': |
| 165 | + def abs_eval(*ins): |
| 166 | + return ins |
| 167 | + |
| 168 | + import brainpy as bp |
| 169 | + bp.math.set_platform('cpu') |
| 170 | + |
| 171 | + def custom_op(outs, ins): |
| 172 | + y, y1 = outs |
| 173 | + x, x2 = ins |
| 174 | + y[:] = x + 1 |
| 175 | + y1[:] = x2 + 2 |
| 176 | + |
| 177 | + |
| 178 | + z = jnp.ones((1, 2), dtype=jnp.float32) |
| 179 | + op = register_cpu_op(custom_op, abs_eval) |
| 180 | + |
| 181 | + from jax import jit |
| 182 | + jit_op = jit(op) |
| 183 | + |
| 184 | + print(jit_op(z, z)) |
0 commit comments