@@ -410,7 +410,7 @@ def transform_input(inputs):
410410 marginalized_rv .type , dependent_logps
411411 )
412412
413- rv_shape = constant_fold (tuple (marginalized_rv .shape ))
413+ rv_shape = constant_fold (tuple (marginalized_rv .shape ), raise_not_constant = False )
414414 rv_domain = get_domain_of_finite_discrete_rv (marginalized_rv )
415415 rv_domain_tensor = pt .moveaxis (
416416 pt .full (
@@ -579,6 +579,15 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
579579 return True
580580
581581
582+ from pytensor .graph .basic import graph_inputs
583+
584+
585+ def collect_shared_vars (outputs , blockers ):
586+ return [
587+ inp for inp in graph_inputs (outputs , blockers = blockers ) if isinstance (inp , SharedVariable )
588+ ]
589+
590+
582591def replace_finite_discrete_marginal_subgraph (fgraph , rv_to_marginalize , all_rvs ):
583592 # TODO: This should eventually be integrated in a more general routine that can
584593 # identify other types of supported marginalization, of which finite discrete
@@ -621,27 +630,21 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
621630 rvs_to_marginalize = [rv_to_marginalize , * dependent_rvs ]
622631
623632 outputs = rvs_to_marginalize
624- # Clone replace inner RV rng inputs so that we can be sure of the update order
625- # replace_inputs = {rng: rng.type() for rng in updates_rvs_to_marginalize.keys()}
626- # Clone replace outter RV inputs, so that their shared RNGs don't make it into
627- # the inner graph of the marginalized RVs
628- # FIXME: This shouldn't be needed!
629- replace_inputs = {}
630- replace_inputs .update ({input_rv : input_rv .type () for input_rv in input_rvs })
631- cloned_outputs = clone_replace (outputs , replace = replace_inputs )
633+ # We are strict about shared variables in SymbolicRandomVariables
634+ inputs = input_rvs + collect_shared_vars (rvs_to_marginalize , blockers = input_rvs )
632635
633636 if isinstance (rv_to_marginalize .owner .op , DiscreteMarkovChain ):
634637 marginalize_constructor = DiscreteMarginalMarkovChainRV
635638 else :
636639 marginalize_constructor = FiniteDiscreteMarginalRV
637640
638641 marginalization_op = marginalize_constructor (
639- inputs = list ( replace_inputs . values ()) ,
640- outputs = cloned_outputs ,
642+ inputs = inputs ,
643+ outputs = outputs ,
641644 ndim_supp = ndim_supp ,
642645 )
643646
644- marginalized_rvs = marginalization_op (* replace_inputs . keys () )
647+ marginalized_rvs = marginalization_op (* inputs )
645648 fgraph .replace_all (tuple (zip (rvs_to_marginalize , marginalized_rvs )))
646649 return rvs_to_marginalize , marginalized_rvs
647650
0 commit comments