@@ -264,6 +264,13 @@ def wrapper(f):
     ### ========================================================================
     ### Math Operations
     ### ========================================================================
+    def handle_exp2(self, op, val):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(val,),
+            kwargs={},
+        )

     def handle_vector_constant(
         self, op, shape: Tuple[int, ...], dtype, value: int | float
@@ -278,15 +285,82 @@ def handle_vector_constant(
     ### ========================================================================
     ### Reduction Operations
     ### ========================================================================
+    def handle_vector_max(self, op, vector, axis=None, acc=None):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(vector, axis, acc),
+            kwargs={},
+        )
+
+    def handle_vector_sum(self, op, vector, axis=None, acc=None):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(vector, axis, acc),
+            kwargs={},
+        )

-    def handle_vector_dot(self, op, lhs, rhs, acc):
+    def handle_vector_dot(self, op, lhs, rhs, acc=None):
         return self.region_graph.create_proxy(
             "call_function",
             target=op,
             args=(lhs, rhs, acc),
             kwargs={},
         )

+    ### ========================================================================
+    ### Shape Manipulation Operations
+    ### ========================================================================
+    def handle_vector_broadcast(self, op, vector, leading_sizes):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(vector, leading_sizes),
+            kwargs={},
+        )
+
+    def handle_vector_broadcast_in_dim(self, op, vector, shape, broadcast_dimensions):
+        # Currently, we do not have a corresponding op in MLIR, so
+        # we trace this to broadcast + transpose.
+        # TODO: Add a vector dialect op for this in MLIR.
+
+        # Drop the dimensions listed in broadcast_dimensions from the target
+        # shape; what remains are the leading sizes for the broadcast.
+        shape_with_leading = tuple(
+            dim for i, dim in enumerate(shape) if i not in broadcast_dimensions
+        )
+
+        # Broadcast
+        broadcasted_vector = self.region_graph.create_proxy(
+            "call_function",
+            target=ops.vector_broadcast,
+            args=(vector, shape_with_leading),
+            kwargs={},
+        )
+
+        # Get the permutation for the transpose.
+        permutation = tuple(
+            i for i in range(len(shape)) if i not in broadcast_dimensions
+        )
+        permutation = permutation + tuple(broadcast_dimensions)
+
+        # Transpose
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=ops.vector_transpose,
+            args=(broadcasted_vector, permutation),
+            kwargs={},
+        )
+
+    def handle_vector_transpose(self, op, vector, permutation):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(vector, permutation),
+            kwargs={},
+        )
+

 ###############################################################################
 # Launch context
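
For reference, here is a standalone sketch (not part of the commit) of the shape and permutation arithmetic that handle_vector_broadcast_in_dim performs before handing the results to ops.vector_broadcast and ops.vector_transpose. The helper name broadcast_in_dim_plan and the example shapes are invented for illustration; only the two tuple computations mirror the handler above.

# Hypothetical, self-contained sketch of the broadcast_in_dim decomposition math.
from typing import Sequence, Tuple


def broadcast_in_dim_plan(
    shape: Sequence[int], broadcast_dimensions: Sequence[int]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Return (leading_sizes, permutation) as computed by the handler."""
    # Target dimensions the input does not map to become the leading sizes
    # that the broadcast prepends in front of the input vector's shape.
    leading_sizes = tuple(
        dim for i, dim in enumerate(shape) if i not in broadcast_dimensions
    )
    # The transpose permutation lists the non-broadcast target dimensions
    # first, followed by the broadcast dimensions.
    permutation = tuple(
        i for i in range(len(shape)) if i not in broadcast_dimensions
    ) + tuple(broadcast_dimensions)
    return leading_sizes, permutation


# A vector of shape (4,) broadcast into target shape (2, 3, 4) along dim 2:
# the broadcast already places it last, so the permutation is the identity.
assert broadcast_in_dim_plan((2, 3, 4), (2,)) == ((2, 3), (0, 1, 2))

# The same vector broadcast into (2, 4, 3) along dim 1: the broadcast result
# has shape (2, 3, 4), and (0, 2, 1) is the permutation handed to
# ops.vector_transpose.
assert broadcast_in_dim_plan((2, 4, 3), (1,)) == ((2, 3), (0, 2, 1))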