From ee725ee507c2c622dd1d5f62794af77c5c872c6e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 19 Jun 2025 12:04:10 +0200 Subject: [PATCH] Allow accessing wrapped function attributes in PointFunc --- pymc/pytensorf.py | 6 ++++++ tests/test_pytensorf.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 78eb3f7bbc..b4b5281e0e 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -600,6 +600,12 @@ def __init__(self, f): def __call__(self, state): return self.f(**state) + def __getattr__(self, item): + """Allow access to the original function attributes.""" + if item == "f": + return self.f + return getattr(self.f, item) + class CallableTensor: """Turns a symbolic variable with one input into a function that returns symbolic arguments with the one variable replaced with the input.""" diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index 34360397a3..b0aee9282d 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -36,6 +36,7 @@ from pymc.exceptions import NotConstantValueError from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import ( + PointFunc, collect_default_updates, compile, constant_fold, @@ -780,3 +781,17 @@ def test_hessian_sign_change_warning(func): res_neg = func(f, vars=[x]) res = func(f, vars=[x], negate_output=False) assert equal_computations([res_neg], [-res]) + + +def test_point_func(): + x, y = pt.vectors("x", "y") + outs = x * 2 + y**2 + f = compile([x, y], outs) + + point_f = PointFunc(f) + np.testing.assert_allclose(point_f({"y": [3], "x": [2]}), [4 + 9]) + + # Check we can access other methods of the wrapped pytensor function + dprint_res = point_f.dprint(file="str") + expected_dprint_res = point_f.f.dprint(file="str") + assert dprint_res == expected_dprint_res