File tree Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -565,6 +565,12 @@ def __init__(self, f):
565
565
def __call__ (self , state ):
566
566
return self .f (** state )
567
567
568
+ def __getattr__ (self , item ):
569
+ """Allow access to the original function attributes."""
570
+ if item == "f" :
571
+ return self .f
572
+ return getattr (self .f , item )
573
+
568
574
569
575
class CallableTensor :
570
576
"""Turns a symbolic variable with one input into a function that returns symbolic arguments with the one variable replaced with the input."""
Original file line number Diff line number Diff line change 35
35
from pymc .exceptions import NotConstantValueError
36
36
from pymc .logprob .utils import ParameterValueError
37
37
from pymc .pytensorf import (
38
+ PointFunc ,
38
39
collect_default_updates ,
39
40
compile ,
40
41
constant_fold ,
@@ -743,3 +744,17 @@ def test_hessian_sign_change_warning(func):
743
744
res_neg = func (f , vars = [x ])
744
745
res = func (f , vars = [x ], negate_output = False )
745
746
assert equal_computations ([res_neg ], [- res ])
747
+
748
+
749
+ def test_point_func ():
750
+ x , y = pt .vectors ("x" , "y" )
751
+ outs = x * 2 + y ** 2
752
+ f = compile ([x , y ], outs )
753
+
754
+ point_f = PointFunc (f )
755
+ np .testing .assert_allclose (point_f ({"y" : [3 ], "x" : [2 ]}), [4 + 9 ])
756
+
757
+ # Check we can access other methods of the wrapped pytensor function
758
+ dprint_res = point_f .dprint (file = "str" )
759
+ expected_dprint_res = point_f .f .dprint (file = "str" )
760
+ assert dprint_res == expected_dprint_res
You can’t perform that action at this time.
0 commit comments