@@ -64,20 +64,20 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
 
 
 def model_to_minibatch(
-    model: Model, *, batch_size: int, vars_to_minibatch: list[str] | None = None
+    model: Model, *, batch_size: int, minibatch_vars: list[str] | None = None
 ) -> Model:
     """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs."""
     from pymc.variational.minibatch_rv import create_minibatch_rv
 
-    if vars_to_minibatch is None:
-        vars_to_minibatch = [
+    if minibatch_vars is None:
+        original_minibatch_vars = [
             variable
             for variable in model.data_vars
             if (variable.type.ndim > 0) and (variable.type.shape[0] is None)
         ]
     else:
-        vars_to_minibatch = parse_vars(model, vars_to_minibatch)
-    for variable in vars_to_minibatch:
+        original_minibatch_vars = parse_vars(model, minibatch_vars)
+    for variable in original_minibatch_vars:
         if variable.type.ndim == 0:
             raise ValueError(
                 f"Cannot minibatch {variable.name} because it is a scalar variable."
@@ -93,66 +93,58 @@ def model_to_minibatch(
 
     fgraph, memo = fgraph_from_model(model, inlined_views=True)
 
-    cloned_vars_to_minibatch = [memo[var] for var in vars_to_minibatch]
-    minibatch_vars = Minibatch(*cloned_vars_to_minibatch, batch_size=batch_size)
-
-    var_to_dummy = {
-        var: var.type()  # model_named(minibatch_var, *extract_dims(var))
-        for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars)
-    }
-    dummy_to_minibatch = {
-        var_to_dummy[var]: minibatch_var
-        for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars)
-    }
-    total_size = (cloned_vars_to_minibatch[0].owner.inputs[0].shape[0], ...)
-
-    # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also have that same dim
-    #  (or just do this all in xtensor)
-    vtm_set = set(vars_to_minibatch)
-
-    # TODO: Handle potentials, free_RVs, etc
-
-    # Create a temporary fgraph that does not include as outputs any of the variables that will be minibatched. This
-    # ensures the results of this function match the outputs from a model constructed using the pm.Minibatch API.
-    tmp_fgraph = FunctionGraph(
-        outputs=[out for out in fgraph.outputs if out not in var_to_dummy.keys()], clone=False
+    pre_minibatch_vars = [memo[var] for var in original_minibatch_vars]
+    minibatch_vars = Minibatch(*pre_minibatch_vars, batch_size=batch_size)
+
+    # Replace uses of the specified data variables with Minibatch variables.
+    # We need a two-step clone because FunctionGraph can only mutate one variable at a time,
+    # and with multiple vars to minibatch we would end up replacing the same variable twice, recursively.
+    # Example: out = x + y
+    #   Goal: replace (x, y) by (Minibatch(x, y).0, Minibatch(x, y).1)
+    #   Replacing x first gives: out = Minibatch(x, y).0 + y
+    #   Then replacing y gives: out = Minibatch(x, Minibatch(...).1).0 + Minibatch(x, y).1
+    # The second replacement of y ends up creating a circular dependency.
+    pre_minibatch_var_to_dummy = tuple((var, var.type()) for var in pre_minibatch_vars)
+    dummy_to_minibatch_var = tuple(
+        (dummy, minibatch_var)
+        for (_, dummy), minibatch_var in zip(pre_minibatch_var_to_dummy, minibatch_vars)
     )
 
-    # All variables that will be minibatched are first replaced by dummy variables, to avoid infinite recursion during
-    # rewrites. The issue is that the Minibatch Op we will introduce depends on the original input variables (to get
-    # the shapes). That's fine in the final output, but during the intermediate rewrites this creates a circular
-    # dependency.
-    dummy_replacements = tuple(var_to_dummy.items())
-    toposort_replace(tmp_fgraph, dummy_replacements)
-
-    # Now we can replace the dummy variables with the actual Minibatch variables.
-    replacements = tuple(dummy_to_minibatch.items())
-    toposort_replace(tmp_fgraph, replacements)
-
-    # The last step is to replace all RVs that depend on the minibatched variables with MinibatchRVs that are aware
-    # of the total_size. Importantly, all of the toposort_replace calls above modify fgraph in place, so the
-    # model.rvs_to_values[original_rv] will already have been modified to depend on the Minibatch variables -- only
-    # the outer RVs need to be replaced here.
-    dependent_replacements = {}
+    # Furthermore, we only want to replace uses of the data variables (x, y), not the data variables themselves,
+    # so we use an intermediate FunctionGraph that doesn't contain the data variables as outputs.
+    other_model_vars = [out for out in fgraph.outputs if out not in pre_minibatch_vars]
+    minibatch_fgraph = FunctionGraph(outputs=other_model_vars, clone=False)
+    minibatch_fgraph._coords = fgraph._coords  # type: ignore[attr-defined]
+    minibatch_fgraph._dim_lengths = fgraph._dim_lengths  # type: ignore[attr-defined]
+    toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy)
+    toposort_replace(minibatch_fgraph, dummy_to_minibatch_var)
 
-    for original_rv in model.observed_RVs:
-        original_value_var = model.rvs_to_values[original_rv]
-
-        if not (set(ancestors([original_rv, original_value_var])) & vtm_set):
+    # Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs
+    dependent_replacements = {}
+    total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...)
+    vars_to_minibatch_set = set(pre_minibatch_vars)
+    for model_var in minibatch_fgraph.outputs:
+        if not (set(ancestors([model_var])) & vars_to_minibatch_set):
             continue
-
-        rv = memo[original_rv].owner.inputs[0]
-        dependent_replacements[rv] = create_minibatch_rv(rv, total_size=total_size)
-
-    toposort_replace(fgraph, tuple(dependent_replacements.items()))
-
-    # FIXME: The fgraph is being rebuilt here to clean up the clients. It is not clear why they are getting messed up
-    # in the first place (pytensor bug, or something wrong in the above manipulations?)
-    new_fgraph = FunctionGraph(outputs=fgraph.outputs)
-    new_fgraph._coords = fgraph._coords  # type: ignore[attr-defined]
-    new_fgraph._dim_lengths = fgraph._dim_lengths  # type: ignore[attr-defined]
-
-    return model_from_fgraph(new_fgraph, mutate_fgraph=True)
+        if not isinstance(model_var.owner.op, ModelObservedRV):
+            raise ValueError(
+                "Minibatching only supports observed RVs depending on minibatched variables. "
+                f"Found dependent unobserved variable: {model_var.name}."
+            )
+        # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also have that same dim,
+        #  and conversely that other variables do not have that dim
+        observed_rv = model_var.owner.inputs[0]
+        dependent_replacements[observed_rv] = create_minibatch_rv(
+            observed_rv, total_size=total_size
+        )
+
+    toposort_replace(minibatch_fgraph, tuple(dependent_replacements.items()))
+
+    # Finally, reintroduce the original data variable outputs
+    for pre_minibatch_var in pre_minibatch_vars:
+        minibatch_fgraph.add_output(pre_minibatch_var)
+
+    return model_from_fgraph(minibatch_fgraph, mutate_fgraph=True)
 
 
 def remove_minibatched_nodes(model: Model) -> Model:
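
The two-step dummy replacement added above can be reproduced in isolation. Below is a minimal sketch using plain pytensor: `pt.stack` stands in for the `Minibatch` Op (like `Minibatch`, it is a single node consuming both `x` and `y`), and direct `FunctionGraph.replace` calls stand in for `toposort_replace`. All names in the snippet are illustrative, not part of this PR.

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph

x = pt.vector("x")
y = pt.vector("y")
out = x + y
fgraph = FunctionGraph(outputs=[out], clone=False)

# One node that consumes both x and y, mimicking Minibatch(x, y).
# Replacing x by mb_x directly, then y by mb_y, would rewire the y input
# of this very node to something that depends on it: a circular graph.
joint = pt.stack([x, y])
mb_x, mb_y = joint[0, :2], joint[1, :2]

# Step 1: swap the originals for owner-less dummies of the same type.
dummy_x, dummy_y = x.type(), y.type()
fgraph.replace(x, dummy_x, import_missing=True)
fgraph.replace(y, dummy_y, import_missing=True)

# Step 2: swap the dummies for the joint views. No cycle is possible now,
# because the views reference x and y, never the dummies.
fgraph.replace(dummy_x, mb_x, import_missing=True)
fgraph.replace(dummy_y, mb_y, import_missing=True)

pytensor.dprint(fgraph.outputs[0])  # out is now joint[0, :2] + joint[1, :2]
```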
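
For broader context, here is a hypothetical end-to-end use of the new function. The availability of `model_to_minibatch` at import time is assumed (it lives in the module this diff edits), and the data and sizes are made up:

```python
import numpy as np
import pymc as pm

rng = np.random.default_rng(0)
N = 10_000
X_full = rng.normal(size=N)
y_full = 2.0 * X_full + rng.normal(size=N)

with pm.Model() as model:
    X = pm.Data("X", X_full)  # leading dim is None, so it is minibatchable
    y = pm.Data("y", y_full)
    beta = pm.Normal("beta")
    pm.Normal("obs", mu=beta * X, sigma=1.0, observed=y)

# Data containers become pm.Minibatch slices of batch_size rows, and "obs"
# becomes a MinibatchRV with total_size=(N, ...) so its logp is rescaled.
mb_model = model_to_minibatch(model, batch_size=128)

with mb_model:
    approx = pm.fit(n=10_000)  # e.g. ADVI, where minibatching pays off
```

Per the new signature, specific containers can also be selected explicitly with `minibatch_vars=["X", "y"]` instead of relying on the default detection of data variables with an unknown leading dimension.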