Skip to content

Commit 8f4673b

Browse files
Working model_to_minibatch implementation
1 parent 9df25b9 commit 8f4673b

File tree

2 files changed

+101
-37
lines changed

2 files changed

+101
-37
lines changed

pymc/model/transform/basic.py

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from collections.abc import Sequence
1515

1616
from pytensor import Variable, clone_replace
17-
from pytensor.compile import SharedVariable
1817
from pytensor.graph import ancestors
1918
from pytensor.graph.fg import FunctionGraph
2019

@@ -23,10 +22,8 @@
2322
from pymc.model.fgraph import (
2423
ModelObservedRV,
2524
ModelVar,
26-
extract_dims,
2725
fgraph_from_model,
2826
model_from_fgraph,
29-
model_observed_rv,
3027
)
3128
from pymc.pytensorf import toposort_replace
3229

@@ -66,45 +63,96 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
6663
return [model[var] if isinstance(var, str) else var for var in vars_seq]
6764

6865

69-
def model_to_minibatch(model: Model, batch_size: int) -> Model:
66+
def model_to_minibatch(
67+
model: Model, *, batch_size: int, vars_to_minibatch: list[str] | None = None
68+
) -> Model:
7069
"""Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs."""
7170
from pymc.variational.minibatch_rv import create_minibatch_rv
7271

72+
if vars_to_minibatch is None:
73+
vars_to_minibatch = [
74+
variable
75+
for variable in model.data_vars
76+
if (variable.type.ndim > 0) and (variable.type.shape[0] is None)
77+
]
78+
else:
79+
vars_to_minibatch = parse_vars(model, vars_to_minibatch)
80+
for variable in vars_to_minibatch:
81+
if variable.type.ndim == 0:
82+
raise ValueError(
83+
f"Cannot minibatch {variable.name} because it is a scalar variable."
84+
)
85+
if variable.type.shape[0] is not None:
86+
raise ValueError(
87+
f"Cannot minibatch {variable.name} because its first dimension is static "
88+
f"(size={variable.type.shape[0]})."
89+
)
90+
91+
# TODO: Validate that this graph is actually valid to minibatch. Example: linear regression with sigma fixed
92+
# shape, but mu from data --> y cannot be minibatched because of sigma.
93+
7394
fgraph, memo = fgraph_from_model(model, inlined_views=True)
7495

75-
# obs_rvs, data_vars = model.rvs_to_values.items()
96+
cloned_vars_to_minibatch = [memo[var] for var in vars_to_minibatch]
97+
minibatch_vars = Minibatch(*cloned_vars_to_minibatch, batch_size=batch_size)
7698

77-
data_vars = [
78-
memo[datum].owner.inputs[0]
79-
for datum in (model.named_vars[datum_name] for datum_name in model.named_vars)
80-
if isinstance(datum, SharedVariable)
81-
]
99+
var_to_dummy = {
100+
var: var.type() # model_named(minibatch_var, *extract_dims(var))
101+
for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars)
102+
}
103+
dummy_to_minibatch = {
104+
var_to_dummy[var]: minibatch_var
105+
for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars)
106+
}
107+
total_size = (cloned_vars_to_minibatch[0].owner.inputs[0].shape[0], ...)
82108

83-
minibatch_vars = Minibatch(*data_vars, batch_size=batch_size)
84-
replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)}
85-
assert 0
86-
# Add total_size to all observed RVs
87-
total_size = data_vars[0].get_value().shape[0]
88-
for obs_var in model.observed_RVs:
89-
model_var = memo[obs_var]
90-
var = model_var.owner.inputs[0]
91-
var.name = model_var.name
92-
dims = extract_dims(model_var)
109+
# TODO: If vars_to_minibatch have a leading dim, we should check that the dependent RVs also have that same dim
110+
# (or just do this all in xtensor)
111+
vtm_set = set(vars_to_minibatch)
93112

94-
new_rv = create_minibatch_rv(var, total_size=total_size)
95-
new_rv.name = var.name
113+
# TODO: Handle potentials, free_RVs, etc
96114

97-
replacements[model_var] = model_observed_rv(new_rv, model.rvs_to_values[obs_var], *dims)
115+
# Create a temporary fgraph that does not include as outputs any of the variables that will be minibatched. This
116+
# ensures the results of this function match the outputs from a model constructed using the pm.Minibatch API.
117+
tmp_fgraph = FunctionGraph(
118+
outputs=[out for out in fgraph.outputs if out not in var_to_dummy.keys()], clone=False
119+
)
98120

99-
# old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths
100-
toposort_replace(fgraph, tuple(replacements.items()))
101-
# new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type]
121+
# All variables that will be minibatched are first replaced by dummy variables, to avoid infinite recursion during
122+
# rewrites. The issue is that the Minibatch Op we will introduce depends on the original input variables (to get
123+
# the shapes). That's fine in the final output, but during the intermediate rewrites this creates a circular
124+
# dependency.
125+
dummy_replacements = tuple(var_to_dummy.items())
126+
toposort_replace(tmp_fgraph, dummy_replacements)
102127

103-
# fgraph = FunctionGraph(outputs=new_outs, clone=False)
104-
# fgraph._coords = old_coords # type: ignore[attr-defined]
105-
# fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined]
128+
# Now we can replace the dummy variables with the actual Minibatch variables.
129+
replacements = tuple(dummy_to_minibatch.items())
130+
toposort_replace(tmp_fgraph, replacements)
106131

107-
return model_from_fgraph(fgraph, mutate_fgraph=True)
132+
# The last step is to replace all RVs that depend on the minibatched variables with MinibatchRVs that are aware
133+
# of the total_size. Importantly, all of the toposort_replace calls above modify fgraph in place, so the
134+
# model.rvs_to_values[original_rv] will already have been modified to depend on the Minibatch variables -- only
135+
# the outer RVs need to be replaced here.
136+
dependent_replacements = {}
137+
138+
for original_rv in model.observed_RVs:
139+
original_value_var = model.rvs_to_values[original_rv]
140+
141+
if not (set(ancestors([original_rv, original_value_var])) & vtm_set):
142+
continue
143+
144+
rv = memo[original_rv].owner.inputs[0]
145+
dependent_replacements[rv] = create_minibatch_rv(rv, total_size=total_size)
146+
147+
toposort_replace(fgraph, tuple(dependent_replacements.items()))
148+
149+
# FIXME: The fgraph is being rebuilt here to clean up the clients. It is not clear why they are getting messed up
150+
# in the first place (pytensor bug, or something wrong in the above manipulations?)
151+
new_fgraph = FunctionGraph(outputs=fgraph.outputs)
152+
new_fgraph._coords = fgraph._coords # type: ignore[attr-defined]
153+
new_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined]
154+
155+
return model_from_fgraph(new_fgraph, mutate_fgraph=True)
108156

109157

110158
def remove_minibatched_nodes(model: Model) -> Model:

tests/model/transform/test_basic.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
prune_vars_detached_from_observed,
2121
remove_minibatched_nodes,
2222
)
23+
from pymc.testing import assert_equivalent_models
2324

2425

2526
def test_prune_vars_detached_from_observed():
@@ -42,22 +43,37 @@ def test_model_to_minibatch():
4243
data_size = 100
4344
n_features = 4
4445

45-
obs_data = np.zeros((data_size,))
46-
X_data = np.random.normal(size=(data_size, n_features))
46+
obs_data_np = np.zeros((data_size,))
47+
X_data_np = np.random.normal(size=(data_size, n_features))
4748

4849
with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m1:
49-
obs_data = pm.Data("obs_data", obs_data, dims=["data_dim"])
50-
X_data = pm.Data("X_data", X_data, dims=["data_dim", "feature"])
50+
obs_data = pm.Data("obs_data", obs_data_np, dims=["data_dim"])
51+
X_data = pm.Data("X_data", X_data_np, dims=["data_dim", "feature"])
5152
beta = pm.Normal("beta", dims="feature")
5253

5354
mu = X_data @ beta
5455

5556
y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim")
5657

57-
m2 = model_to_minibatch(m1, batch_size=10)
58-
m2["y"].dprint()
58+
with pm.Model(
59+
coords={"feature": range(n_features), "data_dim": range(data_size)}
60+
) as reference_model:
61+
obs_data = pm.Data("obs_data", obs_data_np, dims=["data_dim"])
62+
X_data = pm.Data("X_data", X_data_np, dims=["data_dim", "feature"])
63+
minibatch_obs_data, minibatch_X_data = pm.Minibatch(obs_data, X_data, batch_size=10)
64+
beta = pm.Normal("beta", dims="feature")
65+
mu = minibatch_X_data @ beta
66+
y = pm.Normal(
67+
"y",
68+
mu=mu,
69+
sigma=1,
70+
observed=minibatch_obs_data,
71+
dims="data_dim",
72+
total_size=(obs_data.shape[0], ...),
73+
)
5974

60-
assert 0
75+
m2 = model_to_minibatch(m1, batch_size=10)
76+
assert_equivalent_models(m2, reference_model)
6177

6278

6379
def test_remove_minibatches():

0 commit comments

Comments
 (0)