diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index c66a237f06..3ec05480ba 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -448,16 +448,46 @@ def opfromgraph(*inputs): return opfromgraph +def numba_funcify_debug(op, node, **kwargs): + numba_fun = numba_funcify(op, node=node, **kwargs) + + if node is None: + return numba_fun + + args = ", ".join([f"i{i}" for i in range(len(node.inputs))]) + str_op = str(op) + + f_source = dedent( + f""" + def foo({args}): + print("\\nOp: ", "{str_op}") + print(" inputs: ", {args}) + outs = numba_fun({args}) + print(" outputs: ", outs) + return outs + """ + ) + + f = compile_function_src( + f_source, + "foo", + {**globals(), **{"numba_fun": numba_fun}}, + ) + + return numba_njit(f) + + @numba_funcify.register(FunctionGraph) def numba_funcify_FunctionGraph( fgraph, node=None, fgraph_name="numba_funcified_fgraph", + op_conversion_fn=numba_funcify, **kwargs, ): return fgraph_to_python( fgraph, - numba_funcify, + op_conversion_fn=op_conversion_fn, type_conversion_fn=numba_typify, fgraph_name=fgraph_name, **kwargs, diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 553c5ef217..87ac7c5172 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -4,9 +4,17 @@ class NumbaLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" + def __init__(self, *args, debug: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.debug = debug + def fgraph_convert(self, fgraph, **kwargs): from pytensor.link.numba.dispatch import numba_funcify + if self.debug: + from pytensor.link.numba.dispatch.basic import numba_funcify_debug + + kwargs.setdefault("op_conversion_fn", numba_funcify_debug) return numba_funcify(fgraph, **kwargs) def jit_compile(self, fn): diff --git a/tests/link/numba/test_linker.py b/tests/link/numba/test_linker.py new file mode 100644 index 0000000000..cca1161c15 --- /dev/null +++ b/tests/link/numba/test_linker.py @@ -0,0 +1,33 @@ +from textwrap import dedent + +import pytest + +from pytensor import function +from pytensor.compile.mode import Mode +from pytensor.link.numba import NumbaLinker +from pytensor.tensor import vector + + +pytest.importorskip("numba") + + +def test_debug_mode(capsys): + x = vector("x") + y = (x + 1).sum() + + debug_mode = Mode(linker=NumbaLinker(debug=True)) + fn = function([x], y, mode=debug_mode) + + assert fn([0, 1]) == 3.0 + captured = capsys.readouterr() + assert captured.out == dedent( + """ + Op: Add + inputs: [1.] [0. 1.] + outputs: [1. 2.] + + Op: Sum{axes=None} + inputs: [1. 2.] + outputs: 3.0 + """ + )