Skip to content

Commit ee725ee

Browse files
committed
Allow accessing wrapped function attributes in PointFunc
1 parent 0f1bfa9 commit ee725ee

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pymc/pytensorf.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,12 @@ def __init__(self, f):
600600
def __call__(self, state):
601601
return self.f(**state)
602602

603+
def __getattr__(self, item):
604+
"""Allow access to the original function attributes."""
605+
if item == "f":
606+
return self.f
607+
return getattr(self.f, item)
608+
603609

604610
class CallableTensor:
605611
"""Turns a symbolic variable with one input into a function that returns symbolic arguments with the one variable replaced with the input."""

tests/test_pytensorf.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from pymc.exceptions import NotConstantValueError
3737
from pymc.logprob.utils import ParameterValueError
3838
from pymc.pytensorf import (
39+
PointFunc,
3940
collect_default_updates,
4041
compile,
4142
constant_fold,
@@ -780,3 +781,17 @@ def test_hessian_sign_change_warning(func):
780781
res_neg = func(f, vars=[x])
781782
res = func(f, vars=[x], negate_output=False)
782783
assert equal_computations([res_neg], [-res])
784+
785+
786+
def test_point_func():
787+
x, y = pt.vectors("x", "y")
788+
outs = x * 2 + y**2
789+
f = compile([x, y], outs)
790+
791+
point_f = PointFunc(f)
792+
np.testing.assert_allclose(point_f({"y": [3], "x": [2]}), [4 + 9])
793+
794+
# Check we can access other methods of the wrapped pytensor function
795+
dprint_res = point_f.dprint(file="str")
796+
expected_dprint_res = point_f.f.dprint(file="str")
797+
assert dprint_res == expected_dprint_res

0 commit comments

Comments
 (0)