Skip to content

Commit 1cc9863

Browse files
committed
Use more readable ignore_logprob helper in logprob submodule
1 parent 0bda194 commit 1cc9863

File tree

6 files changed

+17
-39
lines changed

6 files changed

+17
-39
lines changed

pymc/logprob/censoring.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,9 @@
5252
MeasurableVariable,
5353
_logcdf,
5454
_logprob,
55-
assign_custom_measurable_outputs,
5655
)
5756
from pymc.logprob.rewriting import measurable_ir_rewrites_db
58-
from pymc.logprob.utils import CheckParameterValue
57+
from pymc.logprob.utils import CheckParameterValue, ignore_logprob
5958

6059

6160
class MeasurableClip(MeasurableElemwise):
@@ -95,7 +94,7 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[List[Me
9594
upper_bound = upper_bound if (upper_bound is not base_var) else pt.constant(np.inf)
9695

9796
# Make base_var unmeasurable
98-
unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)
97+
unmeasurable_base_var = ignore_logprob(base_var)
9998
clipped_rv_node = measurable_clip.make_node(unmeasurable_base_var, lower_bound, upper_bound)
10099
clipped_rv = clipped_rv_node.outputs[0]
101100

@@ -198,7 +197,7 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[Lis
198197
return None
199198

200199
# Make base_var unmeasurable
201-
unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)
200+
unmeasurable_base_var = ignore_logprob(base_var)
202201

203202
rounded_op = MeasurableRound(node.op.scalar_op)
204203
rounded_rv = rounded_op.make_node(unmeasurable_base_var).default_output()

pymc/logprob/cumsum.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,9 @@
4141
from pytensor.graph.rewriting.basic import node_rewriter
4242
from pytensor.tensor.extra_ops import CumOp
4343

44-
from pymc.logprob.abstract import (
45-
MeasurableVariable,
46-
_logprob,
47-
assign_custom_measurable_outputs,
48-
logprob,
49-
)
44+
from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
5045
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
46+
from pymc.logprob.utils import ignore_logprob
5147

5248

5349
class MeasurableCumsum(CumOp):
@@ -112,7 +108,7 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]:
112108

113109
new_op = MeasurableCumsum(axis=node.op.axis or 0, mode="add")
114110
# Make base_var unmeasurable
115-
unmeasurable_base_rv = assign_custom_measurable_outputs(base_rv.owner)
111+
unmeasurable_base_rv = ignore_logprob(base_rv)
116112
new_rv = new_op.make_node(unmeasurable_base_rv).default_output()
117113
new_rv.name = rv.name
118114

pymc/logprob/mixture.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,18 +69,14 @@
6969
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
7070
from pytensor.tensor.var import TensorVariable
7171

72-
from pymc.logprob.abstract import (
73-
MeasurableVariable,
74-
_logprob,
75-
assign_custom_measurable_outputs,
76-
logprob,
77-
)
72+
from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
7873
from pymc.logprob.rewriting import (
7974
local_lift_DiracDelta,
8075
logprob_rewrites_db,
8176
subtensor_ops,
8277
)
8378
from pymc.logprob.tensor import naive_bcast_rv_lift
79+
from pymc.logprob.utils import ignore_logprob
8480

8581

8682
def is_newaxis(x):
@@ -328,9 +324,7 @@ def mixture_replace(fgraph, node):
328324
# We create custom types for the mixture components and assign them
329325
# null `get_measurable_outputs` dispatches so that they aren't
330326
# erroneously encountered in places like `factorized_joint_logprob`.
331-
new_node = assign_custom_measurable_outputs(component_rv.owner)
332-
out_idx = component_rv.owner.outputs.index(component_rv)
333-
new_comp_rv = new_node.outputs[out_idx]
327+
new_comp_rv = ignore_logprob(component_rv)
334328
new_mixture_rvs.append(new_comp_rv)
335329

336330
# Replace this sub-graph with a `MixtureRV`
@@ -379,9 +373,7 @@ def switch_mixture_replace(fgraph, node):
379373
and component_rv not in rv_map_feature.rv_values
380374
):
381375
return None
382-
new_node = assign_custom_measurable_outputs(component_rv.owner)
383-
out_idx = component_rv.owner.outputs.index(component_rv)
384-
new_comp_rv = new_node.outputs[out_idx]
376+
new_comp_rv = ignore_logprob(component_rv)
385377
mixture_rvs.append(new_comp_rv)
386378

387379
mix_op = MixtureRV(

pymc/logprob/tensor.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,9 @@
5050
local_rv_size_lift,
5151
)
5252

53-
from pymc.logprob.abstract import (
54-
MeasurableVariable,
55-
_logprob,
56-
assign_custom_measurable_outputs,
57-
logprob,
58-
)
53+
from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
5954
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
55+
from pymc.logprob.utils import ignore_logprob
6056

6157

6258
@node_rewriter([BroadcastTo])
@@ -233,12 +229,7 @@ def find_measurable_stacks(
233229
return None # pragma: no cover
234230

235231
# Make base_vars unmeasurable
236-
base_to_unmeasurable_vars = {
237-
base_var: assign_custom_measurable_outputs(base_var.owner).outputs[
238-
base_var.owner.outputs.index(base_var)
239-
]
240-
for base_var in base_vars
241-
}
232+
base_to_unmeasurable_vars = {base_var: ignore_logprob(base_var) for base_var in base_vars}
242233

243234
def replacement_fn(var, replacements):
244235
if var in base_to_unmeasurable_vars:
@@ -339,7 +330,7 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf
339330
return None # pragma: no cover
340331

341332
# Make base_vars unmeasurable
342-
base_var = assign_custom_measurable_outputs(base_var.owner)
333+
base_var = ignore_logprob(base_var)
343334

344335
measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)(
345336
base_var

pymc/logprob/transforms.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,10 @@
7878
MeasurableVariable,
7979
_get_measurable_outputs,
8080
_logprob,
81-
assign_custom_measurable_outputs,
8281
logprob,
8382
)
8483
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
85-
from pymc.logprob.utils import walk_model
84+
from pymc.logprob.utils import ignore_logprob, walk_model
8685

8786

8887
class TransformedVariable(Op):
@@ -549,7 +548,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
549548

550549
# Make base_measure outputs unmeasurable
551550
# This seems to be the only thing preventing nested rewrites from being erased
552-
measurable_input = assign_custom_measurable_outputs(measurable_input.owner)
551+
measurable_input = ignore_logprob(measurable_input)
553552

554553
scalar_op = node.op.scalar_op
555554
measurable_input_idx = 0

pymc/logprob/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ def ignore_logprob(rv: TensorVariable) -> TensorVariable:
275275
op_type = type(node.op)
276276
if op_type.__name__.startswith(prefix):
277277
return rv
278+
# By default `assign_custom_measurable_outputs` makes all outputs unmeasurable
278279
new_node = assign_custom_measurable_outputs(node, type_prefix=prefix)
279280
return new_node.outputs[node.outputs.index(rv)]
280281

0 commit comments

Comments
 (0)