File tree Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -565,6 +565,12 @@ def __init__(self, f):
565565 def __call__ (self , state ):
566566 return self .f (** state )
567567
568+ def __getattr__ (self , item ):
569+ """Allow access to the original function attributes."""
570+ if item == "f" :
571+ return self .f
572+ return getattr (self .f , item )
573+
568574
569575class CallableTensor :
570576 """Turns a symbolic variable with one input into a function that returns symbolic arguments with the one variable replaced with the input."""
Original file line number Diff line number Diff line change 3535from pymc .exceptions import NotConstantValueError
3636from pymc .logprob .utils import ParameterValueError
3737from pymc .pytensorf import (
38+ PointFunc ,
3839 collect_default_updates ,
3940 compile ,
4041 constant_fold ,
@@ -743,3 +744,17 @@ def test_hessian_sign_change_warning(func):
743744 res_neg = func (f , vars = [x ])
744745 res = func (f , vars = [x ], negate_output = False )
745746 assert equal_computations ([res_neg ], [- res ])
747+
748+
749+ def test_point_func ():
750+ x , y = pt .vectors ("x" , "y" )
751+ outs = x * 2 + y ** 2
752+ f = compile ([x , y ], outs )
753+
754+ point_f = PointFunc (f )
755+ np .testing .assert_allclose (point_f ({"y" : [3 ], "x" : [2 ]}), [4 + 9 ])
756+
757+ # Check we can access other methods of the wrapped pytensor function
758+ dprint_res = point_f .dprint (file = "str" )
759+ expected_dprint_res = point_f .f .dprint (file = "str" )
760+ assert dprint_res == expected_dprint_res
You can’t perform that action at this time.
0 commit comments