Skip to content

Commit 1d9fa84

Browse files
aseyboldtricardoV94
authored andcommitted
fix(numba): Add warnings for objectmode
1 parent 426e035 commit 1d9fa84

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -560,11 +560,18 @@ def {fn_name}({", ".join(input_names)}):
560560
@numba_funcify.register(Subtensor)
561561
@numba_funcify.register(AdvancedSubtensor1)
562562
def numba_funcify_Subtensor(op, node, **kwargs):
563-
subtensor_def_src = create_index_func(
564-
node, objmode=isinstance(op, AdvancedSubtensor)
565-
)
563+
objmode = isinstance(op, AdvancedSubtensor)
564+
if objmode:
565+
warnings.warn(
566+
("Numba will use object mode to allow run " "AdvancedSubtensor."),
567+
UserWarning,
568+
)
569+
570+
subtensor_def_src = create_index_func(node, objmode=objmode)
566571

567-
global_env = {"np": np, "objmode": numba.objmode}
572+
global_env = {"np": np}
573+
if objmode:
574+
global_env["objmode"] = numba.objmode
568575

569576
subtensor_fn = compile_function_src(
570577
subtensor_def_src, "subtensor", {**globals(), **global_env}
@@ -575,11 +582,18 @@ def numba_funcify_Subtensor(op, node, **kwargs):
575582

576583
@numba_funcify.register(IncSubtensor)
577584
def numba_funcify_IncSubtensor(op, node, **kwargs):
578-
incsubtensor_def_src = create_index_func(
579-
node, objmode=isinstance(op, AdvancedIncSubtensor)
580-
)
585+
objmode = isinstance(op, AdvancedIncSubtensor)
586+
if objmode:
587+
warnings.warn(
588+
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
589+
UserWarning,
590+
)
591+
592+
incsubtensor_def_src = create_index_func(node, objmode=objmode)
581593

582-
global_env = {"np": np, "objmode": numba.objmode}
594+
global_env = {"np": np}
595+
if objmode:
596+
global_env["objmode"] = numba.objmode
583597

584598
incsubtensor_fn = compile_function_src(
585599
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}

0 commit comments

Comments
 (0)