1
- import inspect
2
1
from functools import singledispatch
3
2
from numbers import Number
3
+ import pickle
4
4
from textwrap import indent
5
- from typing import Any , Callable , Optional , Union
5
+ from typing import Any , Callable , Literal , Optional , Union
6
+ import base64
6
7
7
8
import numba
8
9
import numpy as np
10
+ from llvmlite import ir
11
+ from numba import TypingError , literal_unroll , types , literally
12
+ from numba .core import cgutils
13
+ from numba .cpython .unsafe .tuple import tuple_setitem
14
+ from numba .np import arrayobj
9
15
from numpy .core .numeric import normalize_axis_index , normalize_axis_tuple
10
16
11
17
from pytensor import config
16
22
create_numba_signature ,
17
23
create_tuple_creator ,
18
24
numba_funcify ,
25
+ numba_njit ,
19
26
use_optimized_cheap_pass ,
20
27
)
21
- from pytensor .link .utils import (
22
- compile_function_src ,
23
- get_name_for_object ,
24
- unique_name_generator ,
25
- )
28
+ from pytensor .link .numba .dispatch .helpers import check_broadcasting , tuple_mapper
29
+ from pytensor .link .numba .dispatch import elemwise_codegen
30
+ from pytensor .link .utils import compile_function_src , get_name_for_object
26
31
from pytensor .scalar .basic import (
27
32
AND ,
28
33
OR ,
@@ -431,6 +436,170 @@ def axis_apply_fn(x):
431
436
return axis_apply_fn
432
437
433
438
439
+ _jit_options = { # Options passed as `jit_options` to the `_vectorized` intrinsic below
440
+ "fastmath" : {
441
+ "arcp" , # Allow reciprocal: a division may be replaced by multiplication with the reciprocal
442
+ "contract" , # Allow floating-point contraction (e.g. fusing a multiply and an add)
443
+ "afn" , # Approximate functions: math functions may use approximate implementations
444
+ "reassoc" , # Allow reassociation of floating-point operations
445
+ "nsz" , # No signed zeros: the sign of a zero may be ignored. TODO Do we want this one?
446
+ }
447
+ }
448
+
449
def _decode_literal(arg_type, name):
    """Decode one pickled, base64-encoded literal argument of ``_vectorized``.

    Parameters
    ----------
    arg_type
        The numba type of the argument as seen during typing. It must be a
        ``types.Literal`` whose ``literal_value`` is a base64 string.
    name
        Argument name, used only in the error message.

    Returns
    -------
    The Python object obtained by unpickling the decoded payload.

    Raises
    ------
    TypingError
        If ``arg_type`` is not a literal type.
    """
    if not isinstance(arg_type, types.Literal):
        raise TypingError(f"{name} must be literal.")
    # NOTE(security): ``pickle.loads`` must only ever see trusted data. The
    # payload is expected to be a compile-time constant produced by this
    # module's own dispatch code; never route external input through here.
    return pickle.loads(base64.decodebytes(arg_type.literal_value.encode()))


@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
def _vectorized(
    typingctx,
    scalar_func,
    input_bc_patterns,
    output_bc_patterns,
    output_dtypes,
    inplace_pattern,
    inputs,
):
    """Numba intrinsic that types and emits code for an elementwise loop.

    During typing, the four pattern arguments are decoded from their literal
    (pickled + base64) form, the inputs are validated, and the signature of
    ``scalar_func`` is resolved for the input dtypes. Code generation is
    delegated to helpers in ``elemwise_codegen``.

    Parameters
    ----------
    typingctx
        The numba typing context.
    scalar_func
        Numba type of the scalar function; it is resolved against the dtypes
        of ``inputs``.
    input_bc_patterns, output_bc_patterns, output_dtypes, inplace_pattern
        Literal string types, each holding a base64-encoded pickle.
        After decoding, ``output_bc_patterns`` must contain one pattern of
        length ``ndim`` per output, and ``inplace_pattern`` is iterated as
        pairs whose first element indexes an output.
    inputs
        Tuple type of the input arrays. All entries must be array types of
        the same rank.

    Returns
    -------
    (sig, codegen)
        The call signature and the code generation callback. The return type
        is a single C-contiguous array type when there is exactly one output
        dtype, otherwise a tuple of such array types.

    Raises
    ------
    TypingError
        If a pattern argument is not literal, if there are no inputs or no
        outputs, if an input is not an array, if input ranks differ, or if
        an output broadcasting pattern has the wrong length.
    """
    # Capture the numba types of the arguments before the names are rebound
    # to their decoded Python values below.
    arg_types = [
        scalar_func,
        input_bc_patterns,
        output_bc_patterns,
        output_dtypes,
        inplace_pattern,
        inputs,
    ]

    input_bc_patterns = _decode_literal(input_bc_patterns, "input_bc_patterns")
    output_bc_patterns = _decode_literal(output_bc_patterns, "output_bc_patterns")
    output_dtypes = _decode_literal(output_dtypes, "output_dtypes")
    inplace_pattern = _decode_literal(inplace_pattern, "inplace_pattern")

    n_outputs = len(output_bc_patterns)

    if not len(inputs) > 0:
        raise TypingError("Empty argument list to elemwise op.")

    if not n_outputs > 0:
        raise TypingError("Empty list of outputs for elemwise op.")

    if not all(isinstance(in_type, types.Array) for in_type in inputs):
        raise TypingError("Inputs to elemwise must be arrays.")
    ndim = inputs[0].ndim

    if not all(in_type.ndim == ndim for in_type in inputs):
        raise TypingError("Inputs to elemwise must have the same rank.")

    if not all(len(pattern) == ndim for pattern in output_bc_patterns):
        raise TypingError("Invalid output broadcasting pattern.")

    scalar_signature = typingctx.resolve_function_type(
        scalar_func, [in_type.dtype for in_type in inputs], {}
    )

    # ``codegen`` rebinds the name ``inputs`` to LLVM values, so keep the
    # numba types reachable under a separate name.
    input_types = inputs

    def codegen(
        ctx,
        builder,
        sig,
        args,
    ):
        # Only the last argument carries runtime data; the others were
        # consumed as compile-time constants during typing.
        [_, _, _, _, _, inputs] = args
        inputs = cgutils.unpack_tuple(builder, inputs)
        inputs = [
            arrayobj.make_array(ty)(ctx, builder, val)
            for ty, val in zip(input_types, inputs)
        ]
        in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]

        iter_shape = elemwise_codegen.compute_itershape(
            ctx,
            builder,
            in_shapes,
            input_bc_patterns,
        )

        outputs, output_types = elemwise_codegen.make_outputs(
            ctx,
            builder,
            iter_shape,
            output_bc_patterns,
            output_dtypes,
            inplace_pattern,
            inputs,
            input_types,
        )

        def _check_input_shapes(*_):
            # TODO impl
            return

        _check_input_shapes(
            ctx,
            builder,
            iter_shape,
            inputs,
            input_bc_patterns,
        )

        elemwise_codegen.make_loop_call(
            typingctx,
            ctx,
            builder,
            scalar_func,
            scalar_signature,
            iter_shape,
            inputs,
            outputs,
            input_bc_patterns,
            output_bc_patterns,
            input_types,
            output_types,
        )

        # NOTE(review): in-place outputs get an extra incref before being
        # returned, presumably because they reuse an input's memory and the
        # returned value needs its own reference -- confirm against
        # ``elemwise_codegen.make_outputs``.
        if len(outputs) == 1:
            if inplace_pattern:
                assert inplace_pattern[0][0] == 0
                ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
            return outputs[0]._getvalue()

        # ``dict(inplace_pattern)`` iterates over the first element of each
        # pair, i.e. the output indices.
        for inplace_idx in dict(inplace_pattern):
            ctx.nrt.incref(
                builder,
                sig.return_type.types[inplace_idx],
                outputs[inplace_idx]._getvalue(),
            )
        return ctx.make_tuple(
            builder, sig.return_type, [out._getvalue() for out in outputs]
        )

    # TODO check inplace_pattern
    ret_type = types.Tuple(
        [
            types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
            for dtype in output_dtypes
        ]
    )
    if len(output_dtypes) == 1:
        ret_type = ret_type.types[0]
    sig = ret_type(*arg_types)

    return sig, codegen
601
+
602
+
434
603
@numba_funcify .register (Elemwise )
435
604
def numba_funcify_Elemwise (op , node , ** kwargs ):
436
605
# Creating a new scalar node is more involved and unnecessary
@@ -441,55 +610,42 @@ def numba_funcify_Elemwise(op, node, **kwargs):
441
610
scalar_inputs = [scalar (dtype = input .dtype ) for input in node .inputs ]
442
611
scalar_node = op .scalar_op .make_node (* scalar_inputs )
443
612
613
+ flags = {
614
+ "arcp" , # Allow Reciprocal
615
+ "contract" , # Allow floating-point contraction
616
+ "afn" , # Approximate functions
617
+ "reassoc" ,
618
+ "nsz" , # TODO Do we want this one?
619
+ }
620
+
444
621
scalar_op_fn = numba_funcify (
445
- op .scalar_op , node = scalar_node , parent_node = node , inline = "always" , ** kwargs
622
+ op .scalar_op , node = scalar_node , parent_node = node , fastmath = flags , ** kwargs
446
623
)
447
- elemwise_fn = create_vectorize_func (scalar_op_fn , node , use_signature = False )
448
- elemwise_fn_name = elemwise_fn .__name__
449
-
450
- if op .inplace_pattern :
451
- input_idx = op .inplace_pattern [0 ]
452
- sign_obj = inspect .signature (elemwise_fn .py_scalar_func )
453
- input_names = list (sign_obj .parameters .keys ())
454
-
455
- unique_names = unique_name_generator ([elemwise_fn_name , "np" ], suffix_sep = "_" )
456
- input_names = [unique_names (i , force_unique = True ) for i in input_names ]
457
624
458
- updated_input_name = input_names [input_idx ]
459
-
460
- inplace_global_env = {elemwise_fn_name : elemwise_fn , "np" : np }
461
-
462
- inplace_elemwise_fn_name = f"{ elemwise_fn_name } _inplace"
463
-
464
- input_signature_str = ", " .join (input_names )
465
-
466
- if node .inputs [input_idx ].ndim > 0 :
467
- inplace_elemwise_src = f"""
468
- def { inplace_elemwise_fn_name } ({ input_signature_str } ):
469
- return { elemwise_fn_name } ({ input_signature_str + ", " + updated_input_name } )
470
- """
471
- else :
472
- # We can't perform in-place updates on Numba scalars, so we need to
473
- # convert them to NumPy scalars.
474
- # TODO: We should really prevent the rewrites from creating
475
- # in-place updates on scalars when the Numba mode is selected (or
476
- # in general?).
477
- inplace_elemwise_src = f"""
478
- def { inplace_elemwise_fn_name } ({ input_signature_str } ):
479
- { updated_input_name } _scalar = np.asarray({ updated_input_name } )
480
- return { elemwise_fn_name } ({ input_signature_str + ", " + updated_input_name } _scalar).item()
481
- """
482
-
483
- inplace_elemwise_fn = compile_function_src (
484
- inplace_elemwise_src ,
485
- inplace_elemwise_fn_name ,
486
- {** globals (), ** inplace_global_env },
487
- )
488
- return numba_basic .numba_njit (inline = "always" , fastmath = config .numba__fastmath )(
489
- inplace_elemwise_fn
625
+ ndim = node .outputs [0 ].ndim
626
+ output_bc_patterns = tuple ([(False ,) * ndim for _ in node .outputs ])
627
+ input_bc_patterns = tuple ([input_var .broadcastable for input_var in node .inputs ])
628
+ output_dtypes = tuple (variable .dtype for variable in node .outputs )
629
+ inplace_pattern = tuple (op .inplace_pattern .items ())
630
+
631
+ # numba doesn't support nested literals right now...
632
+ input_bc_patterns = base64 .encodebytes (pickle .dumps (input_bc_patterns )).decode ()
633
+ output_bc_patterns = base64 .encodebytes (pickle .dumps (output_bc_patterns )).decode ()
634
+ output_dtypes = base64 .encodebytes (pickle .dumps (output_dtypes )).decode ()
635
+ inplace_pattern = base64 .encodebytes (pickle .dumps (inplace_pattern )).decode ()
636
+
637
+ @numba_njit
638
+ def elemwise_wrapper (* inputs ):
639
+ return _vectorized (
640
+ scalar_op_fn ,
641
+ input_bc_patterns ,
642
+ output_bc_patterns ,
643
+ output_dtypes ,
644
+ inplace_pattern ,
645
+ inputs ,
490
646
)
491
647
492
- return elemwise_fn
648
+ return elemwise_wrapper
493
649
494
650
495
651
@numba_funcify .register (CAReduce )
0 commit comments