2020import pytensor
2121import pytensor .tensor as pt
2222
23+ from pytensor import graph_replace
2324from pytensor .compile .ops import TypeCastingOp
24- from pytensor .graph .basic import Apply , Variable , ancestors
25+ from pytensor .graph .basic import Apply , Variable , ancestors , walk
2526from pytensor .graph .fg import FunctionGraph
2627from pytensor .graph .rewriting .db import RewriteDatabaseQuery , SequenceDB
2728from pytensor .tensor .variable import TensorVariable
@@ -201,6 +202,17 @@ def make_node(self, var):
201202 return Apply (self , [var ], [var .type ()])
202203
203204
205+ def non_support_point_ancestors (value ):
206+ def expand (r : Variable ):
207+ node = r .owner
208+ if node is not None and not isinstance (node .op , InitialPoint ):
209+ # Stop graph traversal at InitialPoint ops
210+ return node .inputs
211+ return None
212+
213+ yield from walk ([value ], expand , bfs = False )
214+
215+
204216initial_point_op = InitialPoint ()
205217
206218
@@ -253,13 +265,13 @@ def make_initial_point_expression(
253265 if initial_point_rewriter :
254266 initial_point_rewriter .rewrite (initial_point_fgraph )
255267
256- free_rvs_clone = initial_point_fgraph .outputs
268+ ip_variables = initial_point_fgraph .outputs .copy ()
269+ free_rvs_clone = [ip .owner .inputs [0 ] for ip in ip_variables ]
270+ n_rvs = len (free_rvs_clone )
257271
258272 initial_values = []
259273 initial_values_transformed = []
260- for original_variable , ip_variable in zip (free_rvs , free_rvs_clone ):
261- # Extract the variable from the initial_point operation
262- [variable ] = ip_variable .owner .inputs
274+ for original_variable , variable in zip (free_rvs , free_rvs_clone ):
263275 strategy = initval_strategies .get (original_variable )
264276
265277 if strategy is None :
@@ -269,6 +281,20 @@ def make_initial_point_expression(
269281 if strategy == "support_point" :
270282 try :
271283 value = support_point (variable )
284+
285+ # If a support point expression depends on other free_RVs that are not
286+ # wrapped in InitialPoint, we need to replace them with their wrapped versions
287+ # This can only happen for multi-output distributions, where the initial point
288+ # of some outputs depends on the initial point of other outputs from the same node.
289+ other_free_rvs = set (free_rvs_clone ) - {variable }
290+ support_point_replacements = {
291+ ancestor : ip_variables [free_rvs_clone .index (ancestor )]
292+ for ancestor in non_support_point_ancestors (value )
293+ if ancestor in other_free_rvs
294+ }
295+ if support_point_replacements :
296+ value = graph_replace (value , support_point_replacements )
297+
272298 except NotImplementedError :
273299 warnings .warn (
274300 f"Support point not defined for variable { variable } of type "
@@ -314,14 +340,19 @@ def make_initial_point_expression(
314340
315341 initial_values .append (value )
316342
343+ for initial_value in initial_values :
344+ # Adding the initial value to the fgraph outputs seems to help
345+ # with interdependenncies. WHY????
346+ initial_point_fgraph .add_output (initial_value )
347+
317348 # We now replace all rvs by the respective initial_point expressions
318349 # in the constrained (untransformed) space. We do this in reverse topological
319350 # order, so that later nodes do not reintroduce expressions with earlier
320351 # rvs that would need to once again be replaced by their initial_points
321- toposort_replace (initial_point_fgraph , tuple (zip (free_rvs_clone , initial_values )), reverse = True )
352+ toposort_replace (initial_point_fgraph , tuple (zip (ip_variables , initial_values )), reverse = True )
322353
323354 if not return_transformed :
324- return initial_point_fgraph .outputs
355+ return initial_point_fgraph .outputs [: n_rvs ]
325356
326357 # Because the unconstrained (transformed) expressions are a subgraph of the
327358 # constrained initial point they were also automatically updated inplace
0 commit comments