Skip to content

Commit 9e79f3a

Browse files
committed
Initial version of llvm elemwise impl
1 parent 38dc6c9 commit 9e79f3a

File tree

3 files changed

+481
-51
lines changed

3 files changed

+481
-51
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 207 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1-
import inspect
21
from functools import singledispatch
32
from numbers import Number
3+
import pickle
44
from textwrap import indent
5-
from typing import Any, Callable, Optional, Union
5+
from typing import Any, Callable, Literal, Optional, Union
6+
import base64
67

78
import numba
89
import numpy as np
10+
from llvmlite import ir
11+
from numba import TypingError, literal_unroll, types, literally
12+
from numba.core import cgutils
13+
from numba.cpython.unsafe.tuple import tuple_setitem
14+
from numba.np import arrayobj
915
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
1016

1117
from pytensor import config
@@ -16,13 +22,12 @@
1622
create_numba_signature,
1723
create_tuple_creator,
1824
numba_funcify,
25+
numba_njit,
1926
use_optimized_cheap_pass,
2027
)
21-
from pytensor.link.utils import (
22-
compile_function_src,
23-
get_name_for_object,
24-
unique_name_generator,
25-
)
28+
from pytensor.link.numba.dispatch.helpers import check_broadcasting, tuple_mapper
29+
from pytensor.link.numba.dispatch import elemwise_codegen
30+
from pytensor.link.utils import compile_function_src, get_name_for_object
2631
from pytensor.scalar.basic import (
2732
AND,
2833
OR,
@@ -431,6 +436,170 @@ def axis_apply_fn(x):
431436
return axis_apply_fn
432437

433438

439+
_jit_options = {
440+
"fastmath": {
441+
"arcp", # Allow Reciprocal
442+
"contract", # Allow floating-point contraction
443+
"afn", # Approximate functions
444+
"reassoc",
445+
"nsz", # TODO Do we want this one?
446+
}
447+
}
448+
449+
@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
450+
def _vectorized(
451+
typingctx,
452+
scalar_func,
453+
input_bc_patterns,
454+
output_bc_patterns,
455+
output_dtypes,
456+
inplace_pattern,
457+
inputs,
458+
):
459+
#if not isinstance(scalar_func, types.Literal):
460+
# raise TypingError("scalar func must be literal.")
461+
#scalar_func = scalar_func.literal_value
462+
463+
arg_types = [
464+
scalar_func,
465+
input_bc_patterns,
466+
output_bc_patterns,
467+
output_dtypes,
468+
inplace_pattern,
469+
inputs,
470+
]
471+
472+
if not isinstance(input_bc_patterns, types.Literal):
473+
raise TypingError("input_bc_patterns must be literal.")
474+
input_bc_patterns = input_bc_patterns.literal_value
475+
input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode()))
476+
477+
if not isinstance(output_bc_patterns, types.Literal):
478+
raise TypeError("output_bc_patterns must be literal.")
479+
output_bc_patterns = output_bc_patterns.literal_value
480+
output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode()))
481+
482+
if not isinstance(output_dtypes, types.Literal):
483+
raise TypeError("output_dtypes must be literal.")
484+
output_dtypes = output_dtypes.literal_value
485+
output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode()))
486+
487+
if not isinstance(inplace_pattern, types.Literal):
488+
raise TypeError("inplace_pattern must be literal.")
489+
inplace_pattern = inplace_pattern.literal_value
490+
inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))
491+
492+
n_inputs = len(inputs)
493+
n_outputs = len(output_bc_patterns)
494+
495+
if not len(inputs) > 0:
496+
raise TypingError("Empty argument list to elemwise op.")
497+
498+
if not n_outputs > 0:
499+
raise TypingError("Empty list of outputs for elemwise op.")
500+
501+
if not all(isinstance(input, types.Array) for input in inputs):
502+
raise TypingError("Inputs to elemwise must be arrays.")
503+
ndim = inputs[0].ndim
504+
505+
if not all(input.ndim == ndim for input in inputs):
506+
raise TypingError("Inputs to elemwise must have the same rank.")
507+
508+
if not all(len(pattern) == ndim for pattern in output_bc_patterns):
509+
raise TypingError("Invalid output broadcasting pattern.")
510+
511+
scalar_signature = typingctx.resolve_function_type(
512+
scalar_func, [in_type.dtype for in_type in inputs], {}
513+
)
514+
515+
# So we can access the constant values in codegen...
516+
input_bc_patterns_val = input_bc_patterns
517+
output_bc_patterns_val = output_bc_patterns
518+
output_dtypes_val = output_dtypes
519+
inplace_pattern_val = inplace_pattern
520+
input_types = inputs
521+
522+
#assert not inplace_pattern_val
523+
524+
def codegen(
525+
ctx,
526+
builder,
527+
sig,
528+
args,
529+
):
530+
531+
[_, _, _, _, _, inputs] = args
532+
inputs = cgutils.unpack_tuple(builder, inputs)
533+
inputs = [arrayobj.make_array(ty)(ctx, builder, val) for ty, val in zip(input_types, inputs)]
534+
in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]
535+
536+
iter_shape = elemwise_codegen.compute_itershape(
537+
ctx,
538+
builder,
539+
in_shapes,
540+
input_bc_patterns_val,
541+
)
542+
543+
outputs, output_types = elemwise_codegen.make_outputs(
544+
ctx,
545+
builder,
546+
iter_shape,
547+
output_bc_patterns_val,
548+
output_dtypes_val,
549+
inplace_pattern_val,
550+
inputs,
551+
input_types,
552+
)
553+
554+
def _check_input_shapes(*_):
555+
# TODO impl
556+
return
557+
558+
_check_input_shapes(
559+
ctx,
560+
builder,
561+
iter_shape,
562+
inputs,
563+
input_bc_patterns_val,
564+
)
565+
566+
elemwise_codegen.make_loop_call(
567+
typingctx,
568+
ctx,
569+
builder,
570+
scalar_func,
571+
scalar_signature,
572+
iter_shape,
573+
inputs,
574+
outputs,
575+
input_bc_patterns_val,
576+
output_bc_patterns_val,
577+
input_types,
578+
output_types,
579+
)
580+
581+
if len(outputs) == 1:
582+
if inplace_pattern:
583+
assert inplace_pattern[0][0] == 0
584+
ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
585+
return outputs[0]._getvalue()
586+
587+
for inplace_idx in dict(inplace_pattern):
588+
ctx.nrt.incref(builder, sig.return_type.types[inplace_idx], outputs[inplace_idx]._get_value())
589+
return ctx.make_tuple(builder, sig.return_type, [out._getvalue() for out in outputs])
590+
591+
# TODO check inplace_pattern
592+
ret_type = types.Tuple([
593+
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
594+
for dtype in output_dtypes
595+
])
596+
if len(output_dtypes) == 1:
597+
ret_type = ret_type.types[0]
598+
sig = ret_type(*arg_types)
599+
600+
return sig, codegen
601+
602+
434603
@numba_funcify.register(Elemwise)
435604
def numba_funcify_Elemwise(op, node, **kwargs):
436605
# Creating a new scalar node is more involved and unnecessary
@@ -441,55 +610,42 @@ def numba_funcify_Elemwise(op, node, **kwargs):
441610
scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
442611
scalar_node = op.scalar_op.make_node(*scalar_inputs)
443612

613+
flags = {
614+
"arcp", # Allow Reciprocal
615+
"contract", # Allow floating-point contraction
616+
"afn", # Approximate functions
617+
"reassoc",
618+
"nsz", # TODO Do we want this one?
619+
}
620+
444621
scalar_op_fn = numba_funcify(
445-
op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs
622+
op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
446623
)
447-
elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
448-
elemwise_fn_name = elemwise_fn.__name__
449-
450-
if op.inplace_pattern:
451-
input_idx = op.inplace_pattern[0]
452-
sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
453-
input_names = list(sign_obj.parameters.keys())
454-
455-
unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
456-
input_names = [unique_names(i, force_unique=True) for i in input_names]
457624

458-
updated_input_name = input_names[input_idx]
459-
460-
inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}
461-
462-
inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"
463-
464-
input_signature_str = ", ".join(input_names)
465-
466-
if node.inputs[input_idx].ndim > 0:
467-
inplace_elemwise_src = f"""
468-
def {inplace_elemwise_fn_name}({input_signature_str}):
469-
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
470-
"""
471-
else:
472-
# We can't perform in-place updates on Numba scalars, so we need to
473-
# convert them to NumPy scalars.
474-
# TODO: We should really prevent the rewrites from creating
475-
# in-place updates on scalars when the Numba mode is selected (or
476-
# in general?).
477-
inplace_elemwise_src = f"""
478-
def {inplace_elemwise_fn_name}({input_signature_str}):
479-
{updated_input_name}_scalar = np.asarray({updated_input_name})
480-
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
481-
"""
482-
483-
inplace_elemwise_fn = compile_function_src(
484-
inplace_elemwise_src,
485-
inplace_elemwise_fn_name,
486-
{**globals(), **inplace_global_env},
487-
)
488-
return numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)(
489-
inplace_elemwise_fn
625+
ndim = node.outputs[0].ndim
626+
output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs])
627+
input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs])
628+
output_dtypes = tuple(variable.dtype for variable in node.outputs)
629+
inplace_pattern = tuple(op.inplace_pattern.items())
630+
631+
# numba doesn't support nested literals right now...
632+
input_bc_patterns = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
633+
output_bc_patterns = base64.encodebytes(pickle.dumps(output_bc_patterns)).decode()
634+
output_dtypes = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
635+
inplace_pattern = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
636+
637+
@numba_njit
638+
def elemwise_wrapper(*inputs):
639+
return _vectorized(
640+
scalar_op_fn,
641+
input_bc_patterns,
642+
output_bc_patterns,
643+
output_dtypes,
644+
inplace_pattern,
645+
inputs,
490646
)
491647

492-
return elemwise_fn
648+
return elemwise_wrapper
493649

494650

495651
@numba_funcify.register(CAReduce)

0 commit comments

Comments
 (0)