Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 0 additions & 87 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,93 +243,6 @@ def impl_to_scalar(x):
raise TypingError(f"{x} must be a scalar compatible type.")


def enable_slice_literals():
"""Enable lowering for ``SliceLiteral``s.

TODO: This can be removed once https://github.com/numba/numba/pull/6996 is merged
and a release is made.
"""
from numba.core import types
from numba.core.datamodel.models import SliceModel
from numba.core.datamodel.registry import register_default
from numba.core.imputils import lower_cast, lower_constant
from numba.core.types.misc import SliceLiteral
from numba.cpython.slicing import get_defaults

register_default(numba.types.misc.SliceLiteral)(SliceModel)

@property
def key(self):
return self.name

SliceLiteral.key = key

def make_slice_from_constant(context, builder, ty, pyval):
sli = context.make_helper(builder, ty)
lty = context.get_value_type(types.intp)

(
default_start_pos,
default_start_neg,
default_stop_pos,
default_stop_neg,
default_step,
) = (context.get_constant(types.intp, x) for x in get_defaults(context))

step = pyval.step
if step is None:
step_is_neg = False
step = default_step
else:
step_is_neg = step < 0
step = lty(step)

start = pyval.start
if start is None:
if step_is_neg:
start = default_start_neg
else:
start = default_start_pos
else:
start = lty(start)

stop = pyval.stop
if stop is None:
if step_is_neg:
stop = default_stop_neg
else:
stop = default_stop_pos
else:
stop = lty(stop)

sli.start = start
sli.stop = stop
sli.step = step

return sli._getvalue()

@lower_constant(numba.types.SliceType)
def constant_slice(context, builder, ty, pyval):
if isinstance(ty, types.Literal):
typ = ty.literal_type
else:
typ = ty

return make_slice_from_constant(context, builder, typ, pyval)

@lower_cast(numba.types.misc.SliceLiteral, numba.types.SliceType)
def cast_from_literal(context, builder, fromty, toty, val):
return make_slice_from_constant(
context,
builder,
toty,
fromty.literal_value,
)


enable_slice_literals()


def create_tuple_creator(f, n):
"""Construct a compile-time ``tuple``-comprehension-like loop.

Expand Down