@@ -212,8 +212,10 @@ def logp(self, vars=None, **kwargs):
212
212
return m ._logp (vars = vars , ** kwargs )
213
213
214
214
def clone (self ):
215
- m = MarginalModel ()
216
- vars = self .basic_RVs + self .potentials + self .deterministics + self .marginalized_rvs
215
+ m = MarginalModel (coords = self .coords )
216
+ model_vars = self .basic_RVs + self .potentials + self .deterministics + self .marginalized_rvs
217
+ data_vars = [var for name , var in self .named_vars .items () if var not in model_vars ]
218
+ vars = model_vars + data_vars
217
219
cloned_vars = clone_replace (vars )
218
220
vars_to_clone = {var : cloned_var for var , cloned_var in zip (vars , cloned_vars )}
219
221
m .vars_to_clone = vars_to_clone
@@ -598,7 +600,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
598
600
# can ultimately be generated that is proportional to the support domain and not
599
601
# to the variables dimensions
600
602
# We don't need to worry about this if the RV is scalar.
601
- if np .prod (constant_fold (tuple (rv_to_marginalize .shape ))) > 1 :
603
+ if np .prod (constant_fold (tuple (rv_to_marginalize .shape ), raise_not_constant = False )) != 1 :
602
604
if not is_elemwise_subgraph (rv_to_marginalize , dependent_rvs_input_rvs , dependent_rvs ):
603
605
raise NotImplementedError (
604
606
"The subgraph between a marginalized RV and its dependents includes non Elemwise operations. "
@@ -682,7 +684,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
682
684
# batched dimensions of the marginalized RV
683
685
684
686
# PyMC does not allow RVs in the logp graph, even if we are just using the shape
685
- marginalized_rv_shape = constant_fold (tuple (marginalized_rv .shape ))
687
+ marginalized_rv_shape = constant_fold (tuple (marginalized_rv .shape ), raise_not_constant = False )
686
688
marginalized_rv_domain = get_domain_of_finite_discrete_rv (marginalized_rv )
687
689
marginalized_rv_domain_tensor = pt .moveaxis (
688
690
pt .full (
0 commit comments