File tree Expand file tree Collapse file tree 3 files changed +67
-1
lines changed
Expand file tree Collapse file tree 3 files changed +67
-1
lines changed Original file line number Diff line number Diff line change @@ -448,16 +448,46 @@ def opfromgraph(*inputs):
448448 return opfromgraph
449449
450450
451+ def numba_funcify_debug (op , node , ** kwargs ):
452+ numba_fun = numba_funcify (op , node = node , ** kwargs )
453+
454+ if node is None :
455+ return numba_fun
456+
457+ args = ", " .join ([f"i{ i } " for i in range (len (node .inputs ))])
458+ str_op = str (op )
459+
460+ f_source = dedent (
461+ f"""
462+ def foo({ args } ):
463+ print("\\ nOp: ", "{ str_op } ")
464+ print(" inputs: ", { args } )
465+ outs = numba_fun({ args } )
466+ print(" outputs: ", outs)
467+ return outs
468+ """
469+ )
470+
471+ f = compile_function_src (
472+ f_source ,
473+ "foo" ,
474+ {** globals (), ** {"numba_fun" : numba_fun }},
475+ )
476+
477+ return numba_njit (f )
478+
479+
451480@numba_funcify .register (FunctionGraph )
452481def numba_funcify_FunctionGraph (
453482 fgraph ,
454483 node = None ,
455484 fgraph_name = "numba_funcified_fgraph" ,
485+ op_conversion_fn = numba_funcify ,
456486 ** kwargs ,
457487):
458488 return fgraph_to_python (
459489 fgraph ,
460- numba_funcify ,
490+ op_conversion_fn = op_conversion_fn ,
461491 type_conversion_fn = numba_typify ,
462492 fgraph_name = fgraph_name ,
463493 ** kwargs ,
Original file line number Diff line number Diff line change 44class NumbaLinker (JITLinker ):
55 """A `Linker` that JIT-compiles NumPy-based operations using Numba."""
66
7+ def __init__ (self , * args , debug : bool = False , ** kwargs ):
8+ super ().__init__ (* args , ** kwargs )
9+ self .debug = debug
10+
711 def fgraph_convert (self , fgraph , ** kwargs ):
812 from pytensor .link .numba .dispatch import numba_funcify
913
14+ if self .debug :
15+ from pytensor .link .numba .dispatch .basic import numba_funcify_debug
16+
17+ kwargs .setdefault ("op_conversion_fn" , numba_funcify_debug )
1018 return numba_funcify (fgraph , ** kwargs )
1119
1220 def jit_compile (self , fn ):
Original file line number Diff line number Diff line change 1+ from textwrap import dedent
2+
3+ from pytensor import function
4+ from pytensor .compile .mode import Mode
5+ from pytensor .link .numba import NumbaLinker
6+ from pytensor .tensor import vector
7+
8+
9+ def test_debug_mode (capsys ):
10+ x = vector ("x" )
11+ y = (x + 1 ).sum ()
12+
13+ debug_mode = Mode (linker = NumbaLinker (debug = True ))
14+ fn = function ([x ], y , mode = debug_mode )
15+
16+ fn ([0 , 1 ])
17+ captured = capsys .readouterr ()
18+ assert captured .out == dedent (
19+ """
20+ Op: Add
21+ inputs: [1.] [0. 1.]
22+ outputs: [1. 2.]
23+
24+ Op: Sum{axes=None}
25+ inputs: [1. 2.]
26+ outputs: 3.0
27+ """
28+ )
You can’t perform that action at this time.
0 commit comments