@@ -62,14 +62,24 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
6262 return [model [var ] if isinstance (var , str ) else var for var in vars_seq ]
6363
6464
65- def remove_minibatched_nodes (model : Model ):
65+ def remove_minibatched_nodes (model : pm . Model ) -> pm . Model :
6666 """Remove all uses of pm.Minibatch in the Model."""
67+ fgraph , _ = fgraph_from_model (model )
6768
68- @node_rewriter ([MinibatchOp ])
69- def local_remove_minibatch (fgraph , node ):
70- return node .inputs
69+ replacements = {}
70+ for var in fgraph .apply_nodes :
71+ if isinstance (var .op , MinibatchOp ):
72+ for inp , out in zip (var .inputs , var .outputs ):
73+ replacements [out ] = inp
7174
72- remove_minibatch = out2in (local_remove_minibatch )
73- fgraph , _ = fgraph_from_model (model )
74- remove_minibatch .apply (fgraph )
75+ old_outs , old_coords , old_dim_lengths = fgraph .outputs , fgraph ._coords , fgraph ._dim_lengths
76+ # Using `rebuild_strict=False` means all coords, names, and dim information is lost
77+ # So we need to restore it from the old fgraph
78+ new_outs = pytensor .clone_replace (old_outs , replacements , rebuild_strict = False )
79+ for old_out , new_out in zip (old_outs , new_outs ):
80+ new_out .name = old_out .name
81+ fgraph = pytensor .graph .fg .FunctionGraph (outputs = new_outs , clone = False )
82+ fgraph ._coords = old_coords
83+ fgraph ._dim_lengths = old_dim_lengths
7584 return model_from_fgraph (fgraph , mutate_fgraph = True )
85+
0 commit comments