File tree Expand file tree Collapse file tree 3 files changed +72
-1
lines changed
Expand file tree Collapse file tree 3 files changed +72
-1
lines changed Original file line number Diff line number Diff line change @@ -448,16 +448,46 @@ def opfromgraph(*inputs):
448448 return opfromgraph
449449
450450
451+ def numba_funcify_debug (op , node , ** kwargs ):
452+ numba_fun = numba_funcify (op , node = node , ** kwargs )
453+
454+ if node is None :
455+ return numba_fun
456+
457+ args = ", " .join ([f"i{ i } " for i in range (len (node .inputs ))])
458+ str_op = str (op )
459+
460+ f_source = dedent (
461+ f"""
462+ def foo({ args } ):
463+ print("\\ nOp: ", "{ str_op } ")
464+ print(" inputs: ", { args } )
465+ outs = numba_fun({ args } )
466+ print(" outputs: ", outs)
467+ return outs
468+ """
469+ )
470+
471+ f = compile_function_src (
472+ f_source ,
473+ "foo" ,
474+ {** globals (), ** {"numba_fun" : numba_fun }},
475+ )
476+
477+ return numba_njit (f )
478+
479+
451480@numba_funcify .register (FunctionGraph )
452481def numba_funcify_FunctionGraph (
453482 fgraph ,
454483 node = None ,
455484 fgraph_name = "numba_funcified_fgraph" ,
485+ op_conversion_fn = numba_funcify ,
456486 ** kwargs ,
457487):
458488 return fgraph_to_python (
459489 fgraph ,
460- numba_funcify ,
490+ op_conversion_fn = op_conversion_fn ,
461491 type_conversion_fn = numba_typify ,
462492 fgraph_name = fgraph_name ,
463493 ** kwargs ,
Original file line number Diff line number Diff line change 44class NumbaLinker (JITLinker ):
55 """A `Linker` that JIT-compiles NumPy-based operations using Numba."""
66
7+ def __init__ (self , * args , debug : bool = False , ** kwargs ):
8+ super ().__init__ (* args , ** kwargs )
9+ self .debug = debug
10+
711 def fgraph_convert (self , fgraph , ** kwargs ):
812 from pytensor .link .numba .dispatch import numba_funcify
913
14+ if self .debug :
15+ from pytensor .link .numba .dispatch .basic import numba_funcify_debug
16+
17+ kwargs .setdefault ("op_conversion_fn" , numba_funcify_debug )
1018 return numba_funcify (fgraph , ** kwargs )
1119
1220 def jit_compile (self , fn ):
Original file line number Diff line number Diff line change 1+ from textwrap import dedent
2+
3+ import pytest
4+
5+ from pytensor import function
6+ from pytensor .compile .mode import Mode
7+ from pytensor .link .numba import NumbaLinker
8+ from pytensor .tensor import vector
9+
10+
11+ pytest .importorskip ("numba" )
12+
13+
14+ def test_debug_mode (capsys ):
15+ x = vector ("x" )
16+ y = (x + 1 ).sum ()
17+
18+ debug_mode = Mode (linker = NumbaLinker (debug = True ))
19+ fn = function ([x ], y , mode = debug_mode )
20+
21+ assert fn ([0 , 1 ]) == 3.0
22+ captured = capsys .readouterr ()
23+ assert captured .out == dedent (
24+ """
25+ Op: Add
26+ inputs: [1.] [0. 1.]
27+ outputs: [1. 2.]
28+
29+ Op: Sum{axes=None}
30+ inputs: [1. 2.]
31+ outputs: 3.0
32+ """
33+ )
You can’t perform that action at this time.
0 commit comments