14 | 14 | from collections.abc import Sequence |
15 | 15 |
16 | 16 | from pytensor import Variable, clone_replace |
17 | | -from pytensor.compile import SharedVariable |
18 | 17 | from pytensor.graph import ancestors |
19 | 18 | from pytensor.graph.fg import FunctionGraph |
20 | 19 |
23 | 22 | from pymc.model.fgraph import ( |
24 | 23 | ModelObservedRV, |
25 | 24 | ModelVar, |
26 | | - extract_dims, |
27 | 25 | fgraph_from_model, |
28 | 26 | model_from_fgraph, |
29 | | - model_observed_rv, |
30 | 27 | ) |
31 | 28 | from pymc.pytensorf import toposort_replace |
32 | 29 |
@@ -66,45 +63,96 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l |
66 | 63 | return [model[var] if isinstance(var, str) else var for var in vars_seq] |
67 | 64 |
68 | 65 |
69 | | -def model_to_minibatch(model: Model, batch_size: int) -> Model: |
| 66 | +def model_to_minibatch( |
| 67 | + model: Model, *, batch_size: int, vars_to_minibatch: list[str] | None = None |
| 68 | +) -> Model: |
70 | 69 | """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs.""" |
71 | 70 | from pymc.variational.minibatch_rv import create_minibatch_rv |
72 | 71 |
| 72 | + if vars_to_minibatch is None: |
| 73 | + vars_to_minibatch = [ |
| 74 | + variable |
| 75 | + for variable in model.data_vars |
| 76 | + if (variable.type.ndim > 0) and (variable.type.shape[0] is None) |
| 77 | + ] |
| 78 | + else: |
| 79 | + vars_to_minibatch = parse_vars(model, vars_to_minibatch) |
| 80 | + for variable in vars_to_minibatch: |
| 81 | + if variable.type.ndim == 0: |
| 82 | + raise ValueError( |
| 83 | + f"Cannot minibatch {variable.name} because it is a scalar variable." |
| 84 | + ) |
| 85 | + if variable.type.shape[0] is not None: |
| 86 | + raise ValueError( |
| 87 | + f"Cannot minibatch {variable.name} because its first dimension is static " |
| 88 | + f"(size={variable.type.shape[0]})." |
| 89 | + ) |
| 90 | + |
| 91 | +    # TODO: Validate that the graph can actually be minibatched. Example: a linear regression where sigma has a |
| 92 | +    #  fixed shape but mu comes from data --> y cannot be minibatched because of sigma. |
| 93 | + |
73 | 94 | fgraph, memo = fgraph_from_model(model, inlined_views=True) |
74 | 95 |
75 | | - # obs_rvs, data_vars = model.rvs_to_values.items() |
| 96 | + cloned_vars_to_minibatch = [memo[var] for var in vars_to_minibatch] |
| 97 | + minibatch_vars = Minibatch(*cloned_vars_to_minibatch, batch_size=batch_size) |
76 | 98 |
77 | | - data_vars = [ |
78 | | - memo[datum].owner.inputs[0] |
79 | | - for datum in (model.named_vars[datum_name] for datum_name in model.named_vars) |
80 | | - if isinstance(datum, SharedVariable) |
81 | | - ] |
| 99 | + var_to_dummy = { |
| 100 | + var: var.type() # model_named(minibatch_var, *extract_dims(var)) |
| 101 | + for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars) |
| 102 | + } |
| 103 | + dummy_to_minibatch = { |
| 104 | + var_to_dummy[var]: minibatch_var |
| 105 | + for var, minibatch_var in zip(cloned_vars_to_minibatch, minibatch_vars) |
| 106 | + } |
| 107 | + total_size = (cloned_vars_to_minibatch[0].owner.inputs[0].shape[0], ...) |
82 | 108 |
83 | | - minibatch_vars = Minibatch(*data_vars, batch_size=batch_size) |
84 | | - replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)} |
85 | | - assert 0 |
86 | | - # Add total_size to all observed RVs |
87 | | - total_size = data_vars[0].get_value().shape[0] |
88 | | - for obs_var in model.observed_RVs: |
89 | | - model_var = memo[obs_var] |
90 | | - var = model_var.owner.inputs[0] |
91 | | - var.name = model_var.name |
92 | | - dims = extract_dims(model_var) |
| 109 | +    # TODO: If the vars_to_minibatch have a leading dim, we should check that the dependent RVs also have that same dim |
| 110 | + # (or just do this all in xtensor) |
| 111 | + vtm_set = set(vars_to_minibatch) |
93 | 112 |
94 | | - new_rv = create_minibatch_rv(var, total_size=total_size) |
95 | | - new_rv.name = var.name |
| 113 | + # TODO: Handle potentials, free_RVs, etc |
96 | 114 |
97 | | - replacements[model_var] = model_observed_rv(new_rv, model.rvs_to_values[obs_var], *dims) |
| 115 | + # Create a temporary fgraph that does not include as outputs any of the variables that will be minibatched. This |
| 116 | + # ensures the results of this function match the outputs from a model constructed using the pm.Minibatch API. |
| 117 | + tmp_fgraph = FunctionGraph( |
| 118 | + outputs=[out for out in fgraph.outputs if out not in var_to_dummy.keys()], clone=False |
| 119 | + ) |
98 | 120 |
99 | | - # old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths |
100 | | - toposort_replace(fgraph, tuple(replacements.items())) |
101 | | - # new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] |
| 121 | + # All variables that will be minibatched are first replaced by dummy variables, to avoid infinite recursion during |
| 122 | + # rewrites. The issue is that the Minibatch Op we will introduce depends on the original input variables (to get |
| 123 | +    #  the shapes). That's fine in the final output, but during the intermediate rewrites this creates a circular |
| 124 | + # dependency. |
| 125 | + dummy_replacements = tuple(var_to_dummy.items()) |
| 126 | + toposort_replace(tmp_fgraph, dummy_replacements) |
102 | 127 |
103 | | - # fgraph = FunctionGraph(outputs=new_outs, clone=False) |
104 | | - # fgraph._coords = old_coords # type: ignore[attr-defined] |
105 | | - # fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] |
| 128 | + # Now we can replace the dummy variables with the actual Minibatch variables. |
| 129 | + replacements = tuple(dummy_to_minibatch.items()) |
| 130 | + toposort_replace(tmp_fgraph, replacements) |
106 | 131 |
107 | | - return model_from_fgraph(fgraph, mutate_fgraph=True) |
| 132 | + # The last step is to replace all RVs that depend on the minibatched variables with MinibatchRVs that are aware |
| 133 | +    #  of the total_size. Importantly, because tmp_fgraph was built with clone=False, the toposort_replace calls |
| 134 | +    #  above mutate the shared graph in place, so model.rvs_to_values[original_rv] already depends on the Minibatch |
| 135 | +    #  variables -- only the outer RVs need to be replaced here. |
| 136 | + dependent_replacements = {} |
| 137 | + |
| 138 | + for original_rv in model.observed_RVs: |
| 139 | + original_value_var = model.rvs_to_values[original_rv] |
| 140 | + |
| 141 | + if not (set(ancestors([original_rv, original_value_var])) & vtm_set): |
| 142 | + continue |
| 143 | + |
| 144 | + rv = memo[original_rv].owner.inputs[0] |
| 145 | + dependent_replacements[rv] = create_minibatch_rv(rv, total_size=total_size) |
| 146 | + |
| 147 | + toposort_replace(fgraph, tuple(dependent_replacements.items())) |
| 148 | + |
| 149 | + # FIXME: The fgraph is being rebuilt here to clean up the clients. It is not clear why they are getting messed up |
| 150 | + # in the first place (pytensor bug, or something wrong in the above manipulations?) |
| 151 | + new_fgraph = FunctionGraph(outputs=fgraph.outputs) |
| 152 | + new_fgraph._coords = fgraph._coords # type: ignore[attr-defined] |
| 153 | + new_fgraph._dim_lengths = fgraph._dim_lengths # type: ignore[attr-defined] |
| 154 | + |
| 155 | + return model_from_fgraph(new_fgraph, mutate_fgraph=True) |
108 | 156 |
109 | 157 |
110 | 158 | def remove_minibatched_nodes(model: Model) -> Model: |
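For context, here is a minimal usage sketch of the model_to_minibatch transform added in this diff. The model, data, and fitting call are illustrative assumptions (not part of the diff), and the import path for the new function is hypothetical:

import numpy as np
import pymc as pm

# from pymc.model.transform import model_to_minibatch  # hypothetical import path

rng = np.random.default_rng(0)
X_np = rng.normal(size=(10_000, 3))
y_np = X_np @ np.array([1.0, -0.5, 2.0]) + rng.normal(size=10_000)

with pm.Model() as model:
    # Data containers with a dynamic (None) leading dimension are eligible for minibatching
    X = pm.Data("X", X_np)
    y = pm.Data("y", y_np)
    beta = pm.Normal("beta", shape=3)
    pm.Normal("obs", mu=X @ beta, sigma=1.0, observed=y)

# Rewrites the graph so X and y become Minibatch views and "obs" becomes a
# MinibatchRV whose logp is rescaled by total_size (the full length of the data).
mb_model = model_to_minibatch(model, batch_size=128)

with mb_model:
    approx = pm.fit(n=10_000)  # e.g. ADVI on the minibatched model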
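For reference, the comment in the diff notes that the rewritten model should match one constructed directly with the pm.Minibatch API. That manual construction, using the same illustrative X_np and y_np as above, looks roughly like:

import pymc as pm

with pm.Model() as manual_model:
    # Minibatch draws a common random slice of length batch_size from both arrays
    X_mb, y_mb = pm.Minibatch(X_np, y_np, batch_size=128)
    beta = pm.Normal("beta", shape=3)
    # total_size rescales the minibatch logp into an unbiased estimate of the full-data logp
    pm.Normal("obs", mu=X_mb @ beta, sigma=1.0, observed=y_mb, total_size=X_np.shape[0])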