Skip to content

Commit cd2e1a3

Browse files
authored
Allow accessing wrapped function attributes in PointFunc (#7823)
1 parent f6bfdfd commit cd2e1a3

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pymc/pytensorf.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,12 @@ def __init__(self, f):
565565
def __call__(self, state):
566566
return self.f(**state)
567567

568+
def __getattr__(self, item):
569+
"""Allow access to the original function attributes."""
570+
if item == "f":
571+
return self.f
572+
return getattr(self.f, item)
573+
568574

569575
class CallableTensor:
570576
"""Turns a symbolic variable with one input into a function that returns symbolic arguments with the one variable replaced with the input."""

tests/test_pytensorf.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from pymc.exceptions import NotConstantValueError
3636
from pymc.logprob.utils import ParameterValueError
3737
from pymc.pytensorf import (
38+
PointFunc,
3839
collect_default_updates,
3940
compile,
4041
constant_fold,
@@ -743,3 +744,17 @@ def test_hessian_sign_change_warning(func):
743744
res_neg = func(f, vars=[x])
744745
res = func(f, vars=[x], negate_output=False)
745746
assert equal_computations([res_neg], [-res])
747+
748+
749+
def test_point_func():
750+
x, y = pt.vectors("x", "y")
751+
outs = x * 2 + y**2
752+
f = compile([x, y], outs)
753+
754+
point_f = PointFunc(f)
755+
np.testing.assert_allclose(point_f({"y": [3], "x": [2]}), [4 + 9])
756+
757+
# Check we can access other methods of the wrapped pytensor function
758+
dprint_res = point_f.dprint(file="str")
759+
expected_dprint_res = point_f.f.dprint(file="str")
760+
assert dprint_res == expected_dprint_res

0 commit comments

Comments
 (0)