Skip to content

Commit 9ab26a8

Browse files
committed
Cleanup implementation and test
1 parent 4d18e37 commit 9ab26a8

File tree

2 files changed

+81
-72
lines changed

2 files changed

+81
-72
lines changed

pymc/model/transform/basic.py

Lines changed: 53 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,20 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
6464

6565

6666
def model_to_minibatch(
67-
model: Model, *, batch_size: int, vars_to_minibatch: list[str] | None = None
67+
model: Model, *, batch_size: int, minibatch_vars: list[str] | None = None
6868
) -> Model:
6969
"""Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs."""
7070
from pymc.variational.minibatch_rv import create_minibatch_rv
7171

72-
if vars_to_minibatch is None:
73-
vars_to_minibatch = [
72+
if minibatch_vars is None:
73+
original_minibatch_vars = [
7474
variable
7575
for variable in model.data_vars
7676
if (variable.type.ndim > 0) and (variable.type.shape[0] is None)
7777
]
7878
else:
79-
vars_to_minibatch = parse_vars(model, vars_to_minibatch)
80-
for variable in vars_to_minibatch:
79+
original_minibatch_vars = parse_vars(model, minibatch_vars)
80+
for variable in original_minibatch_vars:
8181
if variable.type.ndim == 0:
8282
raise ValueError(
8383
f"Cannot minibatch {variable.name} because it is a scalar variable."
@@ -93,66 +93,58 @@ def model_to_minibatch(
9393

9494
fgraph, memo = fgraph_from_model(model, inlined_views=True)
9595

96-
cloned_vars_to_minibatch = [memo[var] for var in vars_to_minibatch]
97-
minibatch_vars = Minibatch(*cloned_vars_to_minibatch, batch_size=batch_size)
98-
99-
var_to_dummy = {
100-
var: var.type() # model_named(minibatch_var, *extract_dims(var))
101-
for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars)
102-
}
103-
dummy_to_minibatch = {
104-
var_to_dummy[var]: minibatch_var
105-
for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars)
106-
}
107-
total_size = (cloned_vars_to_minibatch[0].owner.inputs[0].shape[0], ...)
108-
109-
# TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim
110-
# (or just do this all in xtensor)
111-
vtm_set = set(vars_to_minibatch)
112-
113-
# TODO: Handle potentials, free_RVs, etc
114-
115-
# Create a temporary fgraph that does not include as outputs any of the variables that will be minibatched. This
116-
# ensures the results of this function match the outputs from a model constructed using the pm.Minibatch API.
117-
tmp_fgraph = FunctionGraph(
118-
outputs=[out for out in fgraph.outputs if out not in var_to_dummy.keys()], clone=False
96+
pre_minibatch_vars = [memo[var] for var in original_minibatch_vars]
97+
minibatch_vars = Minibatch(*pre_minibatch_vars, batch_size=batch_size)
98+
99+
# Replace uses of the specified data variables with Minibatch variables
100+
# We need a two-step clone because FunctionGraph can only mutate one variable at a time
101+
# and when there are multiple vars to minibatch you end up replacing the same variable twice recursively
102+
# example: out = x + y
103+
# goal: replace (x, y) by (Minibatch(x, y).0, Minibatch(x, y).1)
104+
# replace x first we get: out = Minibatch(x, y).0 + y
105+
# then replace y we get: out = Minibatch(x, Minibatch(...).1).0 + Minibatch(x, y).1
106+
# The second replacement of y ends up creating a circular dependency
107+
pre_minibatch_var_to_dummy = tuple((var, var.type()) for var in pre_minibatch_vars)
108+
dummy_to_minibatch_var = tuple(
109+
(dummy, minibatch_var)
110+
for (_, dummy), minibatch_var in zip(pre_minibatch_var_to_dummy, minibatch_vars)
119111
)
120112

121-
# All variables that will be minibatched are first replaced by dummy variables, to avoid infinite recursion during
122-
# rewrites. The issue is that the Minibatch Op we will introduce depends on the original input variables (to get
123-
the shapes). That's fine in the final output, but during the intermediate rewrites this creates a circular
124-
# dependency.
125-
dummy_replacements = tuple(var_to_dummy.items())
126-
toposort_replace(tmp_fgraph, dummy_replacements)
127-
128-
# Now we can replace the dummy variables with the actual Minibatch variables.
129-
replacements = tuple(dummy_to_minibatch.items())
130-
toposort_replace(tmp_fgraph, replacements)
131-
132-
# The last step is to replace all RVs that depend on the minibatched variables with MinibatchRVs that are aware
133-
# of the total_size. Importantly, all of the toposort_replace calls above modify fgraph in place, so the
134-
# model.rvs_to_values[original_rv] will already have been modified to depend on the Minibatch variables -- only
135-
# the outer RVs need to be replaced here.
136-
dependent_replacements = {}
113+
# Furthermore, we only want to replace uses of the data variables (x, y), but not the data variables themselves,
114+
# So we use an intermediate FunctionGraph that doesn't contain the data variables as outputs
115+
other_model_vars = [out for out in fgraph.outputs if out not in pre_minibatch_vars]
116+
minibatch_fgraph = FunctionGraph(outputs=other_model_vars, clone=False)
117+
minibatch_fgraph._coords = fgraph._coords # type: ignore[attr-defined]
118+
minibatch_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined]
119+
toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy)
120+
toposort_replace(minibatch_fgraph, dummy_to_minibatch_var)
137121

138-
for original_rv in model.observed_RVs:
139-
original_value_var = model.rvs_to_values[original_rv]
140-
141-
if not (set(ancestors([original_rv, original_value_var])) & vtm_set):
122+
# Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs
123+
dependent_replacements = {}
124+
total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...)
125+
vars_to_minibatch_set = set(pre_minibatch_vars)
126+
for model_var in minibatch_fgraph.outputs:
127+
if not (set(ancestors([model_var])) & vars_to_minibatch_set):
142128
continue
143-
144-
rv = memo[original_rv].owner.inputs[0]
145-
dependent_replacements[rv] = create_minibatch_rv(rv, total_size=total_size)
146-
147-
toposort_replace(fgraph, tuple(dependent_replacements.items()))
148-
149-
# FIXME: The fgraph is being rebuilt here to clean up the clients. It is not clear why they are getting messed up
150-
# in the first place (pytensor bug, or something wrong in the above manipulations?)
151-
new_fgraph = FunctionGraph(outputs=fgraph.outputs)
152-
new_fgraph._coords = fgraph._coords # type: ignore[attr-defined]
153-
new_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined]
154-
155-
return model_from_fgraph(new_fgraph, mutate_fgraph=True)
129+
if not isinstance(model_var.owner.op, ModelObservedRV):
130+
raise ValueError(
131+
"Minibatching only supports observed RVs depending on minibatched variables. "
132+
f"Found dependent unobserved variable: {model_var.name}."
133+
)
134+
# TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim
135+
# And conversely other variables do not have that dim
136+
observed_rv = model_var.owner.inputs[0]
137+
dependent_replacements[observed_rv] = create_minibatch_rv(
138+
observed_rv, total_size=total_size
139+
)
140+
141+
toposort_replace(minibatch_fgraph, tuple(dependent_replacements.items()))
142+
143+
# Finally reintroduce the original data variable outputs
144+
for pre_minibatch_var in pre_minibatch_vars:
145+
minibatch_fgraph.add_output(pre_minibatch_var)
146+
147+
return model_from_fgraph(minibatch_fgraph, mutate_fgraph=True)
156148

157149

158150
def remove_minibatched_nodes(model: Model) -> Model:

tests/model/transform/test_basic.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
prune_vars_detached_from_observed,
2121
remove_minibatched_nodes,
2222
)
23-
from pymc.testing import assert_equivalent_models
2423

2524

2625
def test_prune_vars_detached_from_observed():
@@ -43,25 +42,22 @@ def test_model_to_minibatch():
4342
data_size = 100
4443
n_features = 4
4544

46-
obs_data_np = np.zeros((data_size,))
45+
obs_data_np = np.random.normal(size=(data_size,))
4746
X_data_np = np.random.normal(size=(data_size, n_features))
4847

49-
with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m1:
48+
with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m:
5049
obs_data = pm.Data("obs_data", obs_data_np, dims=["data_dim"])
5150
X_data = pm.Data("X_data", X_data_np, dims=["data_dim", "feature"])
52-
beta = pm.Normal("beta", dims="feature")
51+
beta = pm.Normal("beta", mu=np.pi, dims="feature")
5352

5453
mu = X_data @ beta
55-
5654
y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim")
5755

58-
with pm.Model(
59-
coords={"feature": range(n_features), "data_dim": range(data_size)}
60-
) as reference_model:
56+
with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as ref_m:
6157
obs_data = pm.Data("obs_data", obs_data_np, dims=["data_dim"])
6258
X_data = pm.Data("X_data", X_data_np, dims=["data_dim", "feature"])
6359
minibatch_obs_data, minibatch_X_data = pm.Minibatch(obs_data, X_data, batch_size=10)
64-
beta = pm.Normal("beta", dims="feature")
60+
beta = pm.Normal("beta", mu=np.pi, dims="feature")
6561
mu = minibatch_X_data @ beta
6662
y = pm.Normal(
6763
"y",
@@ -72,8 +68,29 @@ def test_model_to_minibatch():
7268
total_size=(obs_data.shape[0], ...),
7369
)
7470

75-
m2 = model_to_minibatch(m1, batch_size=10)
76-
assert_equivalent_models(m2, reference_model)
71+
mb = model_to_minibatch(m, batch_size=10)
72+
mb_logp_fn = mb.compile_logp(random_seed=42)
73+
ref_mb_logp_fn = ref_m.compile_logp(random_seed=42)
74+
ip = mb.initial_point()
75+
76+
mb_res1 = mb_logp_fn(ip)
77+
ref_mb_res1 = ref_mb_logp_fn(ip)
78+
np.testing.assert_allclose(mb_res1, ref_mb_res1)
79+
mb_res2 = mb_logp_fn(ip)
80+
# Minibatch should give different results on each call
81+
assert mb_res1 != mb_res2
82+
ref_mb_res2 = ref_mb_logp_fn(ip)
83+
np.testing.assert_allclose(mb_res2, ref_mb_res2)
84+
85+
m_again = remove_minibatched_nodes(mb)
86+
m_again_logp_fn = m_again.compile_logp(random_seed=42)
87+
m_logp_fn = m_again.compile_logp(random_seed=42)
88+
ip = m_again.initial_point()
89+
m_again_res = m_again_logp_fn(ip)
90+
m_res = m_logp_fn(ip)
91+
np.testing.assert_allclose(m_again_res, m_res)
92+
# Check that repeated calls give the same result (no more minibatching)
93+
np.testing.assert_allclose(m_again_res, m_again_logp_fn(ip))
7794

7895

7996
def test_remove_minibatches():

0 commit comments

Comments
 (0)