Skip to content

Commit 4bbb5c3

Browse files
committed
.WIP Fix interdependent replacements
1 parent 17f503f commit 4bbb5c3

File tree

1 file changed

+38
-7
lines changed

1 file changed

+38
-7
lines changed

pymc/initial_point.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
import pytensor
2121
import pytensor.tensor as pt
2222

23+
from pytensor import graph_replace
2324
from pytensor.compile.ops import TypeCastingOp
24-
from pytensor.graph.basic import Apply, Variable, ancestors
25+
from pytensor.graph.basic import Apply, Variable, ancestors, walk
2526
from pytensor.graph.fg import FunctionGraph
2627
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
2728
from pytensor.tensor.variable import TensorVariable
@@ -201,6 +202,17 @@ def make_node(self, var):
201202
return Apply(self, [var], [var.type()])
202203

203204

205+
def non_support_point_ancestors(value):
206+
def expand(r: Variable):
207+
node = r.owner
208+
if node is not None and not isinstance(node.op, InitialPoint):
209+
# Stop graph traversal at InitialPoint ops
210+
return node.inputs
211+
return None
212+
213+
yield from walk([value], expand, bfs=False)
214+
215+
204216
initial_point_op = InitialPoint()
205217

206218

@@ -253,13 +265,13 @@ def make_initial_point_expression(
253265
if initial_point_rewriter:
254266
initial_point_rewriter.rewrite(initial_point_fgraph)
255267

256-
free_rvs_clone = initial_point_fgraph.outputs
268+
ip_variables = initial_point_fgraph.outputs.copy()
269+
free_rvs_clone = [ip.owner.inputs[0] for ip in ip_variables]
270+
n_rvs = len(free_rvs_clone)
257271

258272
initial_values = []
259273
initial_values_transformed = []
260-
for original_variable, ip_variable in zip(free_rvs, free_rvs_clone):
261-
# Extract the variable from the initial_point operation
262-
[variable] = ip_variable.owner.inputs
274+
for original_variable, variable in zip(free_rvs, free_rvs_clone):
263275
strategy = initval_strategies.get(original_variable)
264276

265277
if strategy is None:
@@ -269,6 +281,20 @@ def make_initial_point_expression(
269281
if strategy == "support_point":
270282
try:
271283
value = support_point(variable)
284+
285+
# If a support point expression depends on other free_RVs that are not
286+
# wrapped in InitialPoint, we need to replace them with their wrapped versions
287+
# This can only happen for multi-output distributions, where the initial point
288+
# of some outputs depends on the initial point of other outputs from the same node.
289+
other_free_rvs = set(free_rvs_clone) - {variable}
290+
support_point_replacements = {
291+
ancestor: ip_variables[free_rvs_clone.index(ancestor)]
292+
for ancestor in non_support_point_ancestors(value)
293+
if ancestor in other_free_rvs
294+
}
295+
if support_point_replacements:
296+
value = graph_replace(value, support_point_replacements)
297+
272298
except NotImplementedError:
273299
warnings.warn(
274300
f"Support point not defined for variable {variable} of type "
@@ -314,14 +340,19 @@ def make_initial_point_expression(
314340

315341
initial_values.append(value)
316342

343+
for initial_value in initial_values:
344+
# Adding the initial value to the fgraph outputs seems to help
345+
# with interdependenncies. WHY????
346+
initial_point_fgraph.add_output(initial_value)
347+
317348
# We now replace all rvs by the respective initial_point expressions
318349
# in the constrained (untransformed) space. We do this in reverse topological
319350
# order, so that later nodes do not reintroduce expressions with earlier
320351
# rvs that would need to once again be replaced by their initial_points
321-
toposort_replace(initial_point_fgraph, tuple(zip(free_rvs_clone, initial_values)), reverse=True)
352+
toposort_replace(initial_point_fgraph, tuple(zip(ip_variables, initial_values)), reverse=True)
322353

323354
if not return_transformed:
324-
return initial_point_fgraph.outputs
355+
return initial_point_fgraph.outputs[:n_rvs]
325356

326357
# Because the unconstrained (transformed) expressions are a subgraph of the
327358
# constrained initial point they were also automatically updated inplace

0 commit comments

Comments
 (0)