Skip to content

Commit 0fbc7d9

Browse files
committed
broken WIP move minibatch transform and rename/rework
1 parent a350467 commit 0fbc7d9

File tree

4 files changed

+304
-196
lines changed

4 files changed

+304
-196
lines changed

pymc/model/transform/basic.py

Lines changed: 1 addition & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,16 @@
1313
# limitations under the License.
1414
from collections.abc import Sequence
1515

16-
from pytensor import Variable, clone_replace
16+
from pytensor import Variable
1717
from pytensor.graph import ancestors
18-
from pytensor.graph.fg import FunctionGraph
1918

20-
from pymc.data import Minibatch, MinibatchOp
2119
from pymc.model.core import Model
2220
from pymc.model.fgraph import (
2321
ModelObservedRV,
2422
ModelVar,
2523
fgraph_from_model,
2624
model_from_fgraph,
2725
)
28-
from pymc.pytensorf import toposort_replace
2926

3027
ModelVariable = Variable | str
3128

@@ -61,109 +58,3 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
6158
else:
6259
vars_seq = (vars,)
6360
return [model[var] if isinstance(var, str) else var for var in vars_seq]
64-
65-
66-
def model_to_minibatch(
67-
model: Model, *, batch_size: int, minibatch_vars: list[str] | None = None
68-
) -> Model:
69-
"""Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs."""
70-
from pymc.variational.minibatch_rv import create_minibatch_rv
71-
72-
if minibatch_vars is None:
73-
original_minibatch_vars = [
74-
variable
75-
for variable in model.data_vars
76-
if (variable.type.ndim > 0) and (variable.type.shape[0] is None)
77-
]
78-
else:
79-
original_minibatch_vars = parse_vars(model, minibatch_vars)
80-
for variable in original_minibatch_vars:
81-
if variable.type.ndim == 0:
82-
raise ValueError(
83-
f"Cannot minibatch {variable.name} because it is a scalar variable."
84-
)
85-
if variable.type.shape[0] is not None:
86-
raise ValueError(
87-
f"Cannot minibatch {variable.name} because its first dimension is static "
88-
f"(size={variable.type.shape[0]})."
89-
)
90-
91-
# TODO: Validate that this graph is actually valid to minibatch. Example: linear regression with sigma fixed
92-
# shape, but mu from data --> y cannot be minibatched because of sigma.
93-
94-
fgraph, memo = fgraph_from_model(model, inlined_views=True)
95-
96-
pre_minibatch_vars = [memo[var] for var in original_minibatch_vars]
97-
minibatch_vars = Minibatch(*pre_minibatch_vars, batch_size=batch_size)
98-
99-
# Replace uses of the specified data variables with Minibatch variables
100-
# We need a two-step clone because FunctionGraph can only mutate one variable at a time
101-
# and when there are multiple vars to minibatch you end up replacing the same variable twice recursively
102-
# exampre: out = x + y
103-
# goal: replace (x, y) by (Minibatch(x, y).0, Minibatch(x, y).1)]
104-
# replace x first we get: out = Minibatch(x, y).0 + y
105-
# then replace y we get: out = Minibatch(x, Minibatch(...).1).0 + Minibatch(x, y).1
106-
# The second replacement of y ends up creating a circular dependency
107-
pre_minibatch_var_to_dummy = tuple((var, var.type()) for var in pre_minibatch_vars)
108-
dummy_to_minibatch_var = tuple(
109-
(dummy, minibatch_var)
110-
for (_, dummy), minibatch_var in zip(pre_minibatch_var_to_dummy, minibatch_vars)
111-
)
112-
113-
# Furthermore, we only want to replace uses of the data variables (x, y), but not the data variables themselves,
114-
# So we use an intermediate FunctionGraph that doesn't contain the data variables as outputs
115-
other_model_vars = [out for out in fgraph.outputs if out not in pre_minibatch_vars]
116-
minibatch_fgraph = FunctionGraph(outputs=other_model_vars, clone=False)
117-
minibatch_fgraph._coords = fgraph._coords # type: ignore[attr-defined]
118-
minibatch_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined]
119-
toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy)
120-
toposort_replace(minibatch_fgraph, dummy_to_minibatch_var)
121-
122-
# Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs
123-
dependent_replacements = {}
124-
total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...)
125-
vars_to_minibatch_set = set(pre_minibatch_vars)
126-
for model_var in minibatch_fgraph.outputs:
127-
if not (set(ancestors([model_var])) & vars_to_minibatch_set):
128-
continue
129-
if not isinstance(model_var.owner.op, ModelObservedRV):
130-
raise ValueError(
131-
"Minibatching only supports observed RVs depending on minibatched variables. "
132-
f"Found dependent unobserved variable: {model_var.name}."
133-
)
134-
# TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim
135-
# And conversely other variables do not have that dim
136-
observed_rv = model_var.owner.inputs[0]
137-
dependent_replacements[observed_rv] = create_minibatch_rv(
138-
observed_rv, total_size=total_size
139-
)
140-
141-
toposort_replace(minibatch_fgraph, tuple(dependent_replacements.items()))
142-
143-
# Finally reintroduce the original data variable outputs
144-
for pre_minibatch_var in pre_minibatch_vars:
145-
minibatch_fgraph.add_output(pre_minibatch_var)
146-
147-
return model_from_fgraph(minibatch_fgraph, mutate_fgraph=True)
148-
149-
150-
def remove_minibatched_nodes(model: Model) -> Model:
151-
"""Remove all uses of pm.Minibatch in the Model."""
152-
fgraph, _ = fgraph_from_model(model)
153-
154-
replacements = {}
155-
for var in fgraph.apply_nodes:
156-
if isinstance(var.op, MinibatchOp):
157-
for inp, out in zip(var.inputs, var.outputs):
158-
replacements[out] = inp
159-
160-
old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # type: ignore[attr-defined]
161-
# Using `rebuild_strict=False` means all coords, names, and dim information is lost
162-
# So we need to restore it from the old fgraph
163-
new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type]
164-
for old_out, new_out in zip(old_outs, new_outs):
165-
new_out.name = old_out.name
166-
fgraph = FunctionGraph(outputs=new_outs, clone=False)
167-
fgraph._coords = old_coords # type: ignore[attr-defined]
168-
fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined]
169-
return model_from_fgraph(fgraph, mutate_fgraph=True)

pymc/model/transform/minibatch.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections.abc import Sequence
15+
16+
from pytensor import Variable
17+
from pytensor.graph import FunctionGraph, ancestors
18+
19+
from pymc.variational.minibatch_rv import MinibatchRandomVariable
20+
from pymc import Minibatch, Model
21+
from pymc.data import MinibatchOp
22+
from pymc.model.fgraph import ModelObservedRV, fgraph_from_model, model_from_fgraph
23+
from pymc.model.transform.basic import parse_vars
24+
from pymc.pytensorf import toposort_replace
25+
26+
27+
def minibatch_model(
28+
model: Model,
29+
*,
30+
batch_size: int,
31+
minibatch_vars: Sequence[str | Variable] | None = None,
32+
) -> Model:
33+
"""Create a minibatch version of the given Model.
34+
35+
Replaces minibatch_vars data containers with Minibatch views and rescales the logp of dependent observed variables.
36+
37+
.. warning:: This transformation acts on the leading dimension of the specified data variables and dependent observed RVs. If a dimension other than the first is linked to the minibatched data variables, the resulting model will be invalid.
38+
39+
Parameters
40+
----------
41+
model : Model
42+
The original model to transform.
43+
batch_size : int
44+
The minibatch size to use.
45+
minibatch_vars : Sequence of Variable or string, optional
46+
Data variables to convert to minibatch. If None, all data variables with a leading dimension of size None will be minibatched.
47+
48+
Returns
49+
-------
50+
Model
51+
A new Model with the specified data variables replaced by Minibatch views and dependent observed RVs adjusted accordingly.
52+
53+
Raises
54+
------
55+
ValueError
56+
If any of the specified variables cannot be minibatched (e.g., scalar variables or variables with static leading dimensions), or if dependent variables are Potentials / Unobserved RVs.
57+
58+
Examples
59+
--------
60+
.. code-block:: python
61+
62+
import numpy as np
63+
import pymc as pm
64+
from pymc.model.transform.minibatch import minibatch_model
65+
66+
with pm.Model() as m:
67+
obs_data = pm.Data("obs_data", np.random.normal(size=(100,)))
68+
X_data = pm.Data("X_data", np.random.normal(size=(100, 4)))
69+
beta = pm.Normal("beta", mu=np.pi, shape=(4,))
70+
71+
mu = X_data @ beta
72+
y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data)
73+
74+
with minibatch_model(m, batch_size=10) as mb:
75+
pm.fit()
76+
"""
77+
from pymc.variational.minibatch_rv import create_minibatch_rv
78+
79+
if minibatch_vars is None:
80+
original_minibatch_vars = [
81+
variable
82+
for variable in model.data_vars
83+
if (variable.type.ndim > 0) and (variable.type.shape[0] is None)
84+
]
85+
else:
86+
original_minibatch_vars = parse_vars(model, minibatch_vars)
87+
for variable in original_minibatch_vars:
88+
if variable.type.ndim == 0:
89+
raise ValueError(
90+
f"Cannot minibatch {variable.name} because it is a scalar variable."
91+
)
92+
if variable.type.shape[0] is not None:
93+
raise ValueError(
94+
f"Cannot minibatch {variable.name} because its first dimension is static "
95+
f"(size={variable.type.shape[0]})."
96+
)
97+
98+
# TODO: Validate that this graph is actually valid to minibatch. Example: linear regression with sigma fixed
99+
# shape, but mu from data --> y cannot be minibatched because of sigma.
100+
101+
fgraph, memo = fgraph_from_model(model, inlined_views=True)
102+
103+
pre_minibatch_vars = [memo[var] for var in original_minibatch_vars]
104+
minibatch_vars = Minibatch(*pre_minibatch_vars, batch_size=batch_size)
105+
106+
# Replace uses of the specified data variables with Minibatch variables
107+
# We need a two-step clone because FunctionGraph can only mutate one variable at a time
108+
# and when there are multiple vars to minibatch you end up replacing the same variable twice recursively
109+
# example: out = x + y
110+
# goal: replace (x, y) by (Minibatch(x, y).0, Minibatch(x, y).1)]
111+
# replace x first we get: out = Minibatch(x, y).0 + y
112+
# then replace y we get: out = Minibatch(x, Minibatch(...).1).0 + Minibatch(x, y).1
113+
# The second replacement of y ends up creating a circular dependency
114+
pre_minibatch_var_to_dummy = tuple((var, var.type()) for var in pre_minibatch_vars)
115+
dummy_to_minibatch_var = tuple(
116+
(dummy, minibatch_var)
117+
for (_, dummy), minibatch_var in zip(pre_minibatch_var_to_dummy, minibatch_vars)
118+
)
119+
120+
# Furthermore, we only want to replace uses of the data variables (x, y), but not the data variables themselves,
121+
# So we use an intermediate FunctionGraph that doesn't contain the data variables as outputs
122+
other_model_vars = [out for out in fgraph.outputs if out not in pre_minibatch_vars]
123+
minibatch_fgraph = FunctionGraph(outputs=other_model_vars, clone=False)
124+
minibatch_fgraph._coords = fgraph._coords # type: ignore[attr-defined]
125+
minibatch_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined]
126+
toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy)
127+
toposort_replace(minibatch_fgraph, dummy_to_minibatch_var)
128+
129+
# Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs
130+
dependent_replacements = {}
131+
total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...)
132+
vars_to_minibatch_set = set(pre_minibatch_vars)
133+
for model_var in minibatch_fgraph.outputs:
134+
if not (set(ancestors([model_var])) & vars_to_minibatch_set):
135+
continue
136+
if not isinstance(model_var.owner.op, ModelObservedRV):
137+
raise ValueError(
138+
"Minibatching only supports observed RVs depending on minibatched variables. "
139+
f"Found dependent unobserved variable: {model_var.name}."
140+
)
141+
# TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim
142+
# And conversely other variables do not have that dim
143+
observed_rv = model_var.owner.inputs[0]
144+
dependent_replacements[observed_rv] = create_minibatch_rv(
145+
observed_rv, total_size=total_size
146+
)
147+
148+
toposort_replace(minibatch_fgraph, tuple(dependent_replacements.items()))
149+
150+
# Finally reintroduce the original data variable outputs
151+
for pre_minibatch_var in pre_minibatch_vars:
152+
minibatch_fgraph.add_output(pre_minibatch_var)
153+
154+
return model_from_fgraph(minibatch_fgraph, mutate_fgraph=True)
155+
156+
157+
def remove_minibatch(model: Model) -> Model:
158+
"""Remove all uses of Minibatch data and random variables from the Model.
159+
160+
Parameters
161+
----------
162+
model : Model
163+
The original model to transform.
164+
165+
Returns
166+
-------
167+
Model
168+
A new Model with all Minibatch data variables and MinibatchRVs replaced by their original counterparts.
169+
170+
Examples
171+
--------
172+
.. code-block:: python
173+
174+
import pymc as pm
175+
from pymc.model.transform.minibatch import remove_minibatch
176+

177+
with pm.Model() as mb:
178+
X_data = pm.Data("X_data", [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
179+
obs_data = pm.Data("obs_data", [1, 2, 3, 4, 5])
180+
minibatch_X_data, minibatch_obs_data = pm.Minibatch(X_data, obs_data, batch_size=3)
181+

182+
beta = pm.Normal("beta", shape=(2,))
183+
mu = minibatch_X_data @ beta
184+
y = pm.Normal("y", mu=mu, sigma=1, observed=minibatch_obs_data, total_size=(5,))
185+

186+
with remove_minibatch(mb) as m:
187+
idata = pm.sample_prior_predictive()
188+
assert idata.prior["y"].shape[-1] == 5 # Original data size restored
189+
190+
"""
191+
fgraph, _ = fgraph_from_model(model)
192+
193+
replacements = []
194+
for var in fgraph.apply_nodes:
195+
if isinstance(var.op, MinibatchOp):
196+
replacements.extend(zip(var.inputs, var.outputs))
197+
elif isinstance(var.op, MinibatchRandomVariable):
198+
replacements.append((var.outputs[0], var.inputs[0]))
199+
200+
toposort_replace(fgraph, replacements)
201+
return model_from_fgraph(fgraph, mutate_fgraph=True)

0 commit comments

Comments
 (0)