|
| 1 | +# Copyright 2025 - present The PyMC Developers |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +from collections.abc import Sequence |
| 15 | + |
| 16 | +from pytensor import Variable |
| 17 | +from pytensor.graph import FunctionGraph, ancestors |
| 18 | + |
| 19 | +from build.lib.pymc.variational.minibatch_rv import MinibatchRandomVariable |
| 20 | +from pymc import Minibatch, Model |
| 21 | +from pymc.data import MinibatchOp |
| 22 | +from pymc.model.fgraph import ModelObservedRV, fgraph_from_model, model_from_fgraph |
| 23 | +from pymc.model.transform.basic import parse_vars |
| 24 | +from pymc.pytensorf import toposort_replace |
| 25 | + |
| 26 | + |
def minibatch_model(
    model: Model,
    *,
    batch_size: int,
    minibatch_vars: Sequence[str | Variable] | None = None,
) -> Model:
    """Create a minibatch version of the given Model.

    Replaces minibatch_vars data containers with Minibatch views and rescales the logp of dependent observed variables.

    .. warning:: This transformation acts on the leading dimension of the specified data variables and dependent observed RVs. If a dimension other than the first is linked to the minibatched data variables, the resulting model will be invalid.

    Parameters
    ----------
    model : Model
        The original model to transform.
    batch_size : int
        The minibatch size to use.
    minibatch_vars : Sequence of Variable or string, optional
        Data variables to convert to minibatch. If None, all data variables with a leading dimension of size None will be minibatched.

    Returns
    -------
    Model
        A new Model with the specified data variables replaced by Minibatch views and dependent observed RVs adjusted accordingly.

    Raises
    ------
    ValueError
        If there are no data variables to minibatch, if any of the specified variables cannot be minibatched (e.g., scalar variables or variables with static leading dimensions), or if dependent variables are Potentials / Unobserved RVs.

    Examples
    --------
    .. code-block:: python

        import numpy as np
        import pymc as pm
        from pymc.model.transform.minibatch import minibatch_model

        with pm.Model(coords={"feature": range(4)}) as m:
            obs_data = pm.Data("obs_data", np.random.normal(size=(100,)))
            X_data = pm.Data("X_data", np.random.normal(size=(100, 4)))
            beta = pm.Normal("beta", mu=np.pi, dims="feature")

            mu = X_data @ beta
            y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data)

        with minibatch_model(m, batch_size=10) as mb:
            pm.fit()
    """
    from pymc.variational.minibatch_rv import create_minibatch_rv

    if minibatch_vars is None:
        # Default: every non-scalar data variable whose leading dimension is not static
        original_minibatch_vars = [
            variable
            for variable in model.data_vars
            if (variable.type.ndim > 0) and (variable.type.shape[0] is None)
        ]
    else:
        original_minibatch_vars = parse_vars(model, minibatch_vars)
        for variable in original_minibatch_vars:
            if variable.type.ndim == 0:
                raise ValueError(
                    f"Cannot minibatch {variable.name} because it is a scalar variable."
                )
            if variable.type.shape[0] is not None:
                raise ValueError(
                    f"Cannot minibatch {variable.name} because its first dimension is static "
                    f"(size={variable.type.shape[0]})."
                )

    if not original_minibatch_vars:
        # Fail early with a clear message: Minibatch needs at least one variable and
        # the total size below is read from the first minibatched variable.
        raise ValueError(
            "No data variables to minibatch. Pass minibatch_vars explicitly or make sure the model "
            "has data variables with a leading dimension of size None."
        )

    # TODO: Validate that this graph is actually valid to minibatch. Example: linear regression with sigma fixed
    #  shape, but mu from data --> y cannot be minibatched because of sigma.

    fgraph, memo = fgraph_from_model(model, inlined_views=True)

    pre_minibatch_vars = [memo[var] for var in original_minibatch_vars]
    minibatched = Minibatch(*pre_minibatch_vars, batch_size=batch_size)
    if not isinstance(minibatched, (list, tuple)):
        # Minibatch hands back the variable itself (not a sequence) when given a single input;
        # normalize so it can be paired one-to-one with pre_minibatch_vars below.
        minibatched = (minibatched,)

    # Replace uses of the specified data variables with Minibatch variables
    # We need a two-step clone because FunctionGraph can only mutate one variable at a time
    # and when there are multiple vars to minibatch you end up replacing the same variable twice recursively
    # example: out = x + y
    # goal: replace (x, y) by (Minibatch(x, y).0, Minibatch(x, y).1)
    # replace x first we get: out = Minibatch(x, y).0 + y
    # then replace y we get: out = Minibatch(x, Minibatch(...).1).0 + Minibatch(x, y).1
    # The second replacement of y ends up creating a circular dependency
    pre_minibatch_var_to_dummy = tuple((var, var.type()) for var in pre_minibatch_vars)
    dummy_to_minibatch_var = tuple(
        (dummy, minibatch_var)
        for (_, dummy), minibatch_var in zip(pre_minibatch_var_to_dummy, minibatched)
    )

    # Furthermore, we only want to replace uses of the data variables (x, y), but not the data variables themselves,
    # So we use an intermediate FunctionGraph that doesn't contain the data variables as outputs
    other_model_vars = [out for out in fgraph.outputs if out not in pre_minibatch_vars]
    minibatch_fgraph = FunctionGraph(outputs=other_model_vars, clone=False)
    # Carry over the coords / dim lengths stored on the original fgraph
    minibatch_fgraph._coords = fgraph._coords  # type: ignore[attr-defined]
    minibatch_fgraph._dim_lengths = fgraph._dim_lengths  # type: ignore[attr-defined]
    toposort_replace(minibatch_fgraph, pre_minibatch_var_to_dummy)
    toposort_replace(minibatch_fgraph, dummy_to_minibatch_var)

    # Then replace all observed RVs that depend on the minibatch variables with MinibatchRVs
    dependent_replacements = {}
    # Leading size is taken from the input of the first data variable's fgraph node
    # (presumably the underlying data container - TODO confirm); Ellipsis covers the remaining dims.
    total_size = (pre_minibatch_vars[0].owner.inputs[0].shape[0], ...)
    vars_to_minibatch_set = set(pre_minibatch_vars)
    for model_var in minibatch_fgraph.outputs:
        if not (set(ancestors([model_var])) & vars_to_minibatch_set):
            continue
        if not isinstance(model_var.owner.op, ModelObservedRV):
            raise ValueError(
                "Minibatching only supports observed RVs depending on minibatched variables. "
                f"Found dependent unobserved variable: {model_var.name}."
            )
        # TODO: If vars_to_minibatch had a leading dim, we should check that the dependent RVs also has that same dim
        #  And conversely other variables do not have that dim
        observed_rv = model_var.owner.inputs[0]
        dependent_replacements[observed_rv] = create_minibatch_rv(
            observed_rv, total_size=total_size
        )

    toposort_replace(minibatch_fgraph, tuple(dependent_replacements.items()))

    # Finally reintroduce the original data variable outputs
    for pre_minibatch_var in pre_minibatch_vars:
        minibatch_fgraph.add_output(pre_minibatch_var)

    return model_from_fgraph(minibatch_fgraph, mutate_fgraph=True)
| 155 | + |
| 156 | + |
def remove_minibatch(model: Model) -> Model:
    """Remove all uses of Minibatch data and random variables from the Model.

    Parameters
    ----------
    model : Model
        The original model to transform.

    Returns
    -------
    Model
        A new Model with all Minibatch data variables and MinibatchRVs replaced by their original counterparts.

    Examples
    --------
    .. code-block:: python

        import pymc as pm
        from pymc.model.transform.minibatch import remove_minibatch

        with pm.Model() as mb:
            X_data = pm.Data("X_data", [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
            obs_data = pm.Data("obs_data", [1, 2, 3, 4, 5])
            minibatch_X_data, minibatch_obs_data = pm.Minibatch(X_data, obs_data, batch_size=3)

            beta = pm.Normal("beta", shape=(2,))
            mu = minibatch_X_data @ beta
            y = pm.Normal("y", mu=mu, sigma=1, observed=minibatch_obs_data, total_size=(5,))

        with remove_minibatch(mb) as m:
            idata = pm.sample_prior_predictive()
            assert idata.prior["y"].shape[-1] == 5  # Original data size restored

    """
    # Resolve the class from the pymc package itself so the isinstance check below matches
    # the MinibatchRVs actually present in the model (a copy imported from a build directory
    # is a distinct class object and would never match). Imported locally, presumably for the
    # same import-cycle reason minibatch_model imports create_minibatch_rv locally.
    from pymc.variational.minibatch_rv import MinibatchRandomVariable

    fgraph, _ = fgraph_from_model(model)

    # Replacement pairs are (old, new): every minibatched output is swapped back for the
    # variable it was derived from.
    replacements = []
    for node in fgraph.apply_nodes:
        if isinstance(node.op, MinibatchOp):
            # Outputs are paired positionally with the inputs they were sliced from
            replacements.extend(zip(node.outputs, node.inputs))
        elif isinstance(node.op, MinibatchRandomVariable):
            # First input is the wrapped RV; the remaining inputs are not needed here
            replacements.append((node.outputs[0], node.inputs[0]))

    toposort_replace(fgraph, replacements)
    return model_from_fgraph(fgraph, mutate_fgraph=True)
0 commit comments