 
 from functools import partial
 
+from typing import Union, Tuple
 import jax.numpy as jnp
 import numpy as np
-from jax import core
+from jax import core, dtypes
 from jax.abstract_arrays import ShapedArray
 from jax.interpreters import xla, batching
 from jax.lax import scan
 from jax.lib import xla_client
 
+from .utils import GPUOperatorNotFound
+
 try:
   from . import gpu_ops
 except ImportError:
   gpu_ops = None
 
 x_shape = xla_client.Shape.array_shape
 x_ops = xla_client.ops
 
 _event_sum_prim = core.Primitive("event_sum")
 
 
-def event_sum(events, pre2post, post_num, values):
+def event_sum(events: jnp.ndarray,
+              pre2post: Tuple[jnp.ndarray, jnp.ndarray],
+              post_num: int,
+              values: Union[float, jnp.ndarray]):
   # events
   if events.dtype != jnp.bool_:
     raise ValueError(f'"events" must be a vector of bool, while we got {events.dtype}')
@@ -39,17 +45,16 @@ def event_sum(events, pre2post, post_num, values):
   if indices.dtype != indptr.dtype:
     raise ValueError(f"The dtype of pre2post[0] must be equal to that of pre2post[1], "
                      f"while we got {(indices.dtype, indptr.dtype)}")
-  if indices.dtype not in [jnp.uint32, jnp.uint64]:
-    raise ValueError(f'The dtype of pre2post must be uint32 or uint64, while we got {indices.dtype}')
+  if indices.dtype not in [jnp.uint32, jnp.uint64, jnp.int32, jnp.int64]:
+    raise ValueError(f'The dtype of pre2post must be integer, while we got {indices.dtype}')
 
   # output value
-  values = jnp.asarray([values])
-  if values.dtype not in [jnp.float32, jnp.float64]:
-    raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.')
-  if values.size not in [1, indices.size]:
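+  # for a scalar, canonicalize_dtype(type(values)) yields the dtype jnp.asarray
+  # would infer: float32, or float64 when jax_enable_x64 is enabled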
+  dtype = values.dtype if isinstance(values, jnp.ndarray) else dtypes.canonicalize_dtype(type(values))
+  if dtype not in [jnp.float32, jnp.float64]:
+    raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {dtype}.')
+  if np.size(values) not in [1, indices.size]:
     raise ValueError(f'The size of "values" must be 1 (a scalar) or len(pre2post[0]) (a vector), '
-                     f'while we got {values.size} != 1 != {indices.size}')
-  values = values.flatten()
+                     f'while we got {np.size(values)} != 1 != {indices.size}')
   # bind operator
   return _event_sum_prim.bind(events, indices, indptr, values, post_num=post_num)
 
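A minimal usage sketch of the new signature (illustrative data; it assumes the compiled CPU op is available): `events` is a boolean spike vector, `pre2post` a CSR-style `(indices, indptr)` pair, and `values` can now stay a plain Python scalar.

```python
import jax.numpy as jnp

events = jnp.array([True, False, True])               # spikes from 3 presynaptic cells
indices = jnp.array([0, 1, 1, 2], dtype=jnp.uint32)   # postsynaptic target per synapse
indptr = jnp.array([0, 2, 3, 4], dtype=jnp.uint32)    # CSR offsets, len == pre_num + 1

# a scalar weight now passes the dtype check without the jnp.asarray([values]) wrapping
out = event_sum(events, (indices, indptr), post_num=3, values=1.0)  # shape (3,)
```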
@@ -58,34 +63,27 @@ def _event_sum_abstract(events, indices, indptr, values, *, post_num):
   return ShapedArray(dtype=values.dtype, shape=(post_num,))
 
 
-_event_sum_prim.def_abstract_eval(_event_sum_abstract)
-_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim))
-
-
 def _event_sum_translation(c, events, indices, indptr, values, *, post_num, platform="cpu"):
-  # The pre/post shape
+  # The shape of pre/post
   pre_size = np.array(c.get_shape(events).dimensions()[0], dtype=np.uint32)
   _pre_shape = x_shape(np.dtype(np.uint32), (), ())
   _post_shape = x_shape(np.dtype(np.uint32), (), ())
 
   # The indices shape
   indices_shape = c.get_shape(indices)
   Itype = indices_shape.element_type()
-  assert Itype in [np.uint32, np.uint64]
 
   # The value shape
   values_shape = c.get_shape(values)
   Ftype = values_shape.element_type()
-  assert Ftype in [np.float32, np.float64]
   values_dim = values_shape.dimensions()
 
   # We dispatch a different call depending on the dtype
-  f_type = b'_f32' if Ftype == np.float32 else b'_f64'
-  i_type = b'_i32' if Itype == np.uint32 else b'_i64'
+  f_type = b'_f32' if Ftype in [np.float32] else b'_f64'
+  i_type = b'_i32' if Itype in [np.uint32, np.int32] else b'_i64'
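+  # these suffixes plus the platform prefix form the custom-call target name,
+  # e.g. b'cpu_event_sum_homo_f32_i32'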
 
-  # And then the following is what changes between the GPU and CPU
   if platform == "cpu":
-    v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter'
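+    # with the asarray([values]) wrapping removed, a scalar weight is now 0-d:
+    # it selects the homogeneous kernel, a per-synapse vector the heterogeneous one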
+    v_type = b'_event_sum_homo' if len(values_dim) == 0 else b'_event_sum_heter'
     return x_ops.CustomCallWithLayout(
       c,
       platform.encode() + v_type + f_type + i_type,
@@ -103,9 +101,12 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, platform="cpu"):
                                   c.get_shape(values)),
       shape_with_layout=x_shape(np.dtype(Ftype), (post_num,), (0,)),
     )
+
+  # GPU platform
   elif platform == 'gpu':
     if gpu_ops is None:
-      raise ValueError('Cannot find compiled gpu wheels.')
+      raise GPUOperatorNotFound('event_sum')
+
     v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter'
     opaque = gpu_ops.build_event_sum_descriptor(pre_size, post_num)
     return x_ops.CustomCallWithLayout(
@@ -127,11 +128,7 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, platform="cpu"):
   raise ValueError("Unsupported platform, we only support 'cpu' or 'gpu'")
 
 
-xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu")
-xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu")
-
-
-def _event_sum_batch(args, axes):
+def _event_sum_batch(args, axes, *, post_num):
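+  # vmap support: the custom kernel is not batch-aware, so this rule scans the
+  # primitive over the batched axis one slice at a time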
   batch_axes, batch_args, non_batch_args = [], {}, {}
   for ax_i, ax in enumerate(axes):
     if ax is None:
@@ -143,19 +140,22 @@ def _event_sum_batch(args, axes):
   def f(_, x):
     pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
                   for i in range(len(axes))])
-    return 0, _event_sum_prim.bind(*pars)
+    return 0, _event_sum_prim.bind(*pars, post_num=post_num)
+
   _, outs = scan(f, 0, batch_args)
   return outs, 0
 
 
+_event_sum_prim.def_abstract_eval(_event_sum_abstract)
+_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim))
 batching.primitive_batchers[_event_sum_prim] = _event_sum_batch
-
+xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu")
+xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu")
 
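A hedged sketch of what the registrations above enable (reusing the illustrative arrays from the earlier sketch): `jax.vmap` over the events axis is routed through `_event_sum_batch`, which now receives `post_num` as a keyword parameter.

```python
import jax

batched_events = jnp.zeros((5, 3), dtype=jnp.bool_)   # 5 independent event vectors

# each row becomes one scan step inside _event_sum_batch
outs = jax.vmap(lambda e: event_sum(e, (indices, indptr), 3, 1.0))(batched_events)
assert outs.shape == (5, 3)
```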
154155# ---------------------------
155156# event sum kernel 2
156157# ---------------------------
157158
158-
159159_event_sum2_prim = core .Primitive ("event_sum2" )
160160
161161