File tree Expand file tree Collapse file tree 2 files changed +21
-0
lines changed
Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -600,6 +600,12 @@ def __init__(self, f):
600600 def __call__ (self , state ):
601601 return self .f (** state )
602602
603+ def __getattr__ (self , item ):
604+ """Allow access to the original function attributes."""
605+ if item == "f" :
606+ return self .f
607+ return getattr (self .f , item )
608+
603609
604610class CallableTensor :
605611 """Turns a symbolic variable with one input into a function that returns symbolic arguments with the one variable replaced with the input."""
Original file line number Diff line number Diff line change 3636from pymc .exceptions import NotConstantValueError
3737from pymc .logprob .utils import ParameterValueError
3838from pymc .pytensorf import (
39+ PointFunc ,
3940 collect_default_updates ,
4041 compile ,
4142 constant_fold ,
@@ -780,3 +781,17 @@ def test_hessian_sign_change_warning(func):
780781 res_neg = func (f , vars = [x ])
781782 res = func (f , vars = [x ], negate_output = False )
782783 assert equal_computations ([res_neg ], [- res ])
784+
785+
786+ def test_point_func ():
787+ x , y = pt .vectors ("x" , "y" )
788+ outs = x * 2 + y ** 2
789+ f = compile ([x , y ], outs )
790+
791+ point_f = PointFunc (f )
792+ np .testing .assert_allclose (point_f ({"y" : [3 ], "x" : [2 ]}), [4 + 9 ])
793+
794+ # Check we can access other methods of the wrapped pytensor function
795+ dprint_res = point_f .dprint (file = "str" )
796+ expected_dprint_res = point_f .f .dprint (file = "str" )
797+ assert dprint_res == expected_dprint_res
You can’t perform that action at this time.
0 commit comments