Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import pytensor.tensor as pt
import scipy.sparse as sps

from pytensor.compile import DeepCopyOp, Function, get_mode
from pytensor.compile import DeepCopyOp, Function, ProfileStats, get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -1657,7 +1657,15 @@ def compile_fn(
return PointFunc(fn)
return fn

def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs):
def profile(
self,
outs,
*,
n=1000,
point=None,
profile=True,
**compile_fn_kwargs,
) -> ProfileStats:
"""Compile and profile a PyTensor function which returns ``outs`` and takes values of model vars as a dict as an argument.

Parameters
Expand All @@ -1668,16 +1676,22 @@ def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs):
point : Point
Point to pass to the function
profile : True or ProfileStats
args, kwargs
Compilation args
compile_fn_kwargs
Compilation kwargs for :func:`pymc.model.core.Model.compile_fn`

Returns
-------
ProfileStats
pytensor.compile.profiling.ProfileStats
Use .summary() to print stats.
"""
kwargs.setdefault("on_unused_input", "ignore")
f = self.compile_fn(outs, inputs=self.value_vars, point_fn=False, profile=profile, **kwargs)
compile_fn_kwargs.setdefault("on_unused_input", "ignore")
f = self.compile_fn(
outs,
inputs=self.value_vars,
point_fn=False,
profile=profile,
**compile_fn_kwargs,
)
if point is None:
point = self.initial_point()

Expand Down