
Remove uses of numba_basic.global_numba_func #1535


Merged (6 commits), Jul 14, 2025
Changes from 5 commits
21 changes: 10 additions & 11 deletions pytensor/link/numba/dispatch/basic.py
@@ -25,6 +25,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.type import Type
 from pytensor.ifelse import IfElse
+from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
 from pytensor.link.utils import (
     compile_function_src,

Review comment from a Member on the added import: basic is the file where we are?

@@ -402,24 +403,22 @@
     return deepcopyop


-@numba_njit
-def makeslice(*x):
-    return slice(*x)
-
-
 @numba_funcify.register(MakeSlice)
 def numba_funcify_MakeSlice(op, **kwargs):
-    return global_numba_func(makeslice)
+    @numba_basic.numba_njit
+    def makeslice(*x):
+        return slice(*x)
+
+    return makeslice

Review comment from a Member on the @numba_basic.numba_njit decorator: numba_njit is defined in this module
Reply from the Contributor (author): Yep, cutting and pasting from the other file without reading 😓

Codecov / codecov/patch check warnings on pytensor/link/numba/dispatch/basic.py: added lines 408-410 and added line 412 were not covered by tests.

-@numba_njit
-def shape(x):
-    return np.asarray(np.shape(x))
-
-
 @numba_funcify.register(Shape)
 def numba_funcify_Shape(op, **kwargs):
-    return global_numba_func(shape)
+    @numba_basic.numba_njit
+    def shape(x):
+        return np.asarray(np.shape(x))
+
+    return shape

Review comment from a Member on the @numba_basic.numba_njit decorator: same here
Reply from the Contributor (author): Removed, same as above


 @numba_funcify.register(Shape_i)
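The pattern above repeats through the rest of the PR: each numba_funcify_* dispatcher now defines its njit helper inside its own body and returns the compiled closure directly, instead of wrapping a module-level global with global_numba_func. Below is a minimal, self-contained sketch of that shape, assuming only that numba is installed; fake_funcify_switch is an illustrative stand-in rather than pytensor API, and it mirrors the Switch conversion shown in scalar.py further down.

import numba


def fake_funcify_switch(op=None, **kwargs):
    # Hypothetical stand-in for a numba_funcify_* registration: build the
    # jitted helper inside the dispatcher and return it directly.
    @numba.njit
    def switch(condition, x, y):
        if condition:
            return x
        else:
            return y

    return switch


fn = fake_funcify_switch()
assert fn(True, 1.0, 2.0) == 1.0
assert fn(False, 1.0, 2.0) == 2.0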
130 changes: 60 additions & 70 deletions pytensor/link/numba/dispatch/scalar.py
@@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
     )(scalar_op_fn)


-@numba_basic.numba_njit
-def switch(condition, x, y):
-    if condition:
-        return x
-    else:
-        return y
-
-
 @numba_funcify.register(Switch)
 def numba_funcify_Switch(op, node, **kwargs):
-    return numba_basic.global_numba_func(switch)
+    @numba_basic.numba_njit
+    def switch(condition, x, y):
+        if condition:
+            return x
+        else:
+            return y
+
+    return switch


 def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):

@@ -197,34 +196,32 @@ def cast(x):
     return cast


-@numba_basic.numba_njit
-def identity(x):
-    return x
-
-
 @numba_funcify.register(Identity)
 @numba_funcify.register(TypeCastingOp)
 def numba_funcify_type_casting(op, **kwargs):
-    return numba_basic.global_numba_func(identity)
+    @numba_basic.numba_njit
+    def identity(x):
+        return x
+
+    return identity


-@numba_basic.numba_njit
-def clip(_x, _min, _max):
-    x = numba_basic.to_scalar(_x)
-    _min_scalar = numba_basic.to_scalar(_min)
-    _max_scalar = numba_basic.to_scalar(_max)
-
-    if x < _min_scalar:
-        return _min_scalar
-    elif x > _max_scalar:
-        return _max_scalar
-    else:
-        return x
-
-
 @numba_funcify.register(Clip)
 def numba_funcify_Clip(op, **kwargs):
-    return numba_basic.global_numba_func(clip)
+    @numba_basic.numba_njit
+    def clip(x, min_val, max_val):
+        x = numba_basic.to_scalar(x)
+        min_scalar = numba_basic.to_scalar(min_val)
+        max_scalar = numba_basic.to_scalar(max_val)
+
+        if x < min_scalar:
+            return min_scalar
+        elif x > max_scalar:
+            return max_scalar
+        else:
+            return x
+
+    return clip


 @numba_funcify.register(Composite)
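The Clip rewrite above only renames the parameters (_x, _min, _max become x, min_val, max_val); the comparison chain itself is unchanged and, for scalar inputs, matches numpy's clip. A quick standalone check of that equivalence, using a plain float() conversion as a hypothetical stand-in for the pytensor-internal numba_basic.to_scalar:

import numba
import numpy as np


@numba.njit
def clip(x, min_val, max_val):
    # float() stands in for numba_basic.to_scalar in this sketch.
    x = float(x)
    min_scalar = float(min_val)
    max_scalar = float(max_val)
    if x < min_scalar:
        return min_scalar
    elif x > max_scalar:
        return max_scalar
    else:
        return x


for value in (-2.0, 0.5, 3.0):
    assert clip(value, 0.0, 1.0) == np.clip(value, 0.0, 1.0)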
@@ -239,79 +236,72 @@ def numba_funcify_Composite(op, node, **kwargs):
     return composite_fn


-@numba_basic.numba_njit
-def second(x, y):
-    return y
-
-
 @numba_funcify.register(Second)
 def numba_funcify_Second(op, node, **kwargs):
-    return numba_basic.global_numba_func(second)
+    @numba_basic.numba_njit
+    def second(x, y):
+        return y
+
+    return second


-@numba_basic.numba_njit
-def reciprocal(x):
-    # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
-    # `x` is an `int`
-    return 1 / x
-
-
 @numba_funcify.register(Reciprocal)
 def numba_funcify_Reciprocal(op, node, **kwargs):
-    return numba_basic.global_numba_func(reciprocal)
+    @numba_basic.numba_njit
+    def reciprocal(x):
+        # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
+        # `x` is an `int`
+        return 1 / x
+
+    return reciprocal


-@numba_basic.numba_njit
-def sigmoid(x):
-    return 1 / (1 + np.exp(-x))
-
-
 @numba_funcify.register(Sigmoid)
 def numba_funcify_Sigmoid(op, node, **kwargs):
-    return numba_basic.global_numba_func(sigmoid)
+    @numba_basic.numba_njit
+    def sigmoid(x):
+        return 1 / (1 + np.exp(-x))
+
+    return sigmoid


-@numba_basic.numba_njit
-def gammaln(x):
-    return math.lgamma(x)
-
-
 @numba_funcify.register(GammaLn)
 def numba_funcify_GammaLn(op, node, **kwargs):
-    return numba_basic.global_numba_func(gammaln)
+    @numba_basic.numba_njit
+    def gammaln(x):
+        return math.lgamma(x)
+
+    return gammaln


-@numba_basic.numba_njit
-def logp1mexp(x):
-    if x < np.log(0.5):
-        return np.log1p(-np.exp(x))
-    else:
-        return np.log(-np.expm1(x))
-
-
 @numba_funcify.register(Log1mexp)
 def numba_funcify_Log1mexp(op, node, **kwargs):
-    return numba_basic.global_numba_func(logp1mexp)
+    @numba_basic.numba_njit
+    def logp1mexp(x):
+        if x < np.log(0.5):
+            return np.log1p(-np.exp(x))
+        else:
+            return np.log(-np.expm1(x))
+
+    return logp1mexp


-@numba_basic.numba_njit
-def erf(x):
-    return math.erf(x)
-
-
 @numba_funcify.register(Erf)
 def numba_funcify_Erf(op, **kwargs):
-    return numba_basic.global_numba_func(erf)
+    @numba_basic.numba_njit
+    def erf(x):
+        return math.erf(x)
+
+    return erf


-@numba_basic.numba_njit
-def erfc(x):
-    return math.erfc(x)
-
-
 @numba_funcify.register(Erfc)
 def numba_funcify_Erfc(op, **kwargs):
-    return numba_basic.global_numba_func(erfc)
+    @numba_basic.numba_njit
+    def erfc(x):
+        return math.erfc(x)
+
+    return erfc


 @numba_funcify.register(Softplus)
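One note on the Log1mexp hunk: the branch at np.log(0.5) is the standard numerically stable way to evaluate log(1 - exp(x)) for x <= 0, picking log1p(-exp(x)) when exp(x) is well below 1 and log(-expm1(x)) when x is close to zero. A small plain-numpy check against the naive formula at moderate values (illustration only, not pytensor code):

import numpy as np


def log1mexp(x):
    # Same branching as the jitted helper in the diff above.
    if x < np.log(0.5):
        return np.log1p(-np.exp(x))
    else:
        return np.log(-np.expm1(x))


for x in (-3.0, -0.7, -0.1):
    naive = np.log(1.0 - np.exp(x))
    assert np.isclose(log1mexp(x), naive)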