|
42 | 42 |     walk, |
43 | 43 | ) |
44 | 44 | from pytensor.graph.fg import FunctionGraph |
| 45 | +from pytensor.scan.basic import scan |
45 | 46 | from pytensor.tensor.random.var import RandomGeneratorSharedVariable |
46 | 47 | from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable |
47 | 48 | from pytensor.tensor.variable import TensorConstant, TensorVariable |
|
57 | 58 | from pymc.distributions.shape_utils import change_dist_size |
58 | 59 | from pymc.model import Model, modelcontext |
59 | 60 | from pymc.progress_bar import CustomProgress, default_progress_theme |
60 | | -from pymc.pytensorf import compile, rvs_in_graph |
| 61 | +from pymc.pytensorf import ( |
| 62 | +    clone_while_sharing_some_variables, |
| 63 | +    collect_default_updates, |
| 64 | +    compile, |
| 65 | +    rvs_in_graph, |
| 66 | +) |
61 | 67 | from pymc.util import ( |
62 | 68 |     RandomState,
63 | 69 |     _get_seeds_per_chain,
|
68 | 74 | __all__ = ( |
69 | 75 | "compile_forward_sampling_function", |
70 | 76 | "draw", |
| 77 | + "loop_over_posterior", |
71 | 78 | "sample_posterior_predictive", |
72 | 79 | "sample_prior_predictive", |
73 | 80 | "vectorize_over_posterior", |
@@ -1083,3 +1090,122 @@ def vectorize_over_posterior( |
1083 | 1090 | f"The following random variables found in the extracted graph: {remaining_rvs}" |
1084 | 1091 | ) |
1085 | 1092 | return vectorized_outputs |
| 1093 | + |
| 1094 | + |
| 1095 | +def loop_over_posterior( |
| 1096 | +    outputs: list[Variable], |
| 1097 | +    posterior: xr.Dataset, |
| 1098 | +    input_rvs: list[Variable], |
| 1099 | +    input_tensors: Sequence[Variable] = (), |
| 1100 | +    allow_rvs_in_graph: bool = True, |
| 1101 | +    sample_dims: tuple[str, ...] = ("chain", "draw"), |
| 1102 | +) -> tuple[list[Variable], dict[Variable, Variable]]: |
| 1103 | + """Loop over posterior samples of subset of input rvs. |
| 1104 | + |
| 1105 | +    This function creates a new graph for the supplied outputs, where the requested |
| 1106 | +    subset of input rvs is replaced by posterior samples. The sample dimensions |
| 1107 | +    (`chain` and `draw` by default, or those given in `sample_dims`) are flattened |
| 1108 | +    into a single dimension that is looped over. The other input tensors are kept as is. |
| 1109 | + |
| 1110 | +    Parameters |
| 1111 | +    ---------- |
| 1112 | +    outputs : list[Variable] |
| 1113 | +        The list of variables to compute for each posterior sample. |
| 1114 | +    posterior : xr.Dataset |
| 1115 | +        The posterior samples to use as replacements for the `input_rvs`. |
| 1116 | +    input_rvs : list[Variable] |
| 1117 | +        The list of random variables to replace with their posterior samples. |
| 1118 | +    input_tensors : Sequence[Variable] |
| 1119 | +        The sequence of tensors to keep as is. |
| 1120 | +    allow_rvs_in_graph : bool |
| 1121 | +        Whether to allow random variables to be present in the graph. If False, |
| 1122 | +        an error will be raised if any random variables are found in the graph. If |
| 1123 | +        True, one draw of each remaining random variable will be taken for every |
| 1124 | +        sample from the posterior. |
| 1125 | +    sample_dims : tuple[str, ...] |
| 1126 | +        The dimensions of the posterior samples to flatten and loop over. |
| 1127 | + |
| 1128 | +    Returns |
| 1129 | +    ------- |
| 1130 | +    looped_outputs : list[Variable] |
| 1131 | +        The looped variables, reshaped to match the original shape of the outputs, |
| 1132 | +        with the `sample_dims` added to the left. |
| 1133 | +    updates : dict[Variable, Variable] |
| 1134 | +        Dictionary of updates needed to compile a pytensor function that produces |
| 1135 | +        the outputs. |
| 1136 | + |
| 1137 | +    Raises |
| 1138 | +    ------ |
| 1139 | +    RuntimeError |
| 1140 | +        If random variables are found in the graph and `allow_rvs_in_graph` is False. |
| 1141 | +    ValueError |
| 1142 | +        If the supplied output tensors do not depend on the requested input tensors. |
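| | + |
| | +    Examples |
| | +    -------- |
| | +    A minimal sketch of the intended usage; the model and the variable names |
| | +    below are illustrative:: |
| | + |
| | +        import numpy as np |
| | +        import pymc as pm |
| | + |
| | +        with pm.Model() as model: |
| | +            mu = pm.Normal("mu") |
| | +            obs = pm.Normal("obs", mu=mu, sigma=1.0, observed=np.zeros(3)) |
| | +            idata = pm.sample() |
| | + |
| | +        # Draw "obs" once for every posterior sample of "mu" |
| | +        [looped_obs], updates = loop_over_posterior( |
| | +            outputs=[obs], |
| | +            posterior=idata.posterior, |
| | +            input_rvs=[mu], |
| | +        ) |
| | +        # looped_obs has the sample dims prepended: (chain, draw, 3) |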
| 1143 | + """ |
| 1144 | +    if not (set(input_tensors) <= set(ancestors(outputs))): |
| 1145 | +        raise ValueError(  # pragma: no cover |
| 1146 | +            "The supplied output tensors do not depend on the following requested " |
| 1147 | +            f"input tensors: {set(input_tensors) - set(ancestors(outputs))}" |
| 1148 | +        ) |
| 1149 | +    outputs_ancestors = list(ancestors(outputs, blockers=input_rvs))  # materialize the generator |
| 1150 | +    rvs_from_posterior: list[TensorVariable] = [ |
| 1151 | +        cast(TensorVariable, rv) for rv in outputs_ancestors if rv in set(input_rvs) |
| 1152 | +    ] |
| 1153 | +    independent_rvs = [ |
| 1154 | +        rv |
| 1155 | +        for rv in rvs_in_graph(outputs) |
| 1156 | +        if rv in outputs_ancestors and rv not in rvs_from_posterior |
| 1157 | +    ] |
| 1158 | + |
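| | +    # The step function receives one posterior draw per input rv (the sequences), |
| | +    # followed by the non-sequence inputs: the kept input_tensors and independent rvs. |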
| 1159 | +    def step(*args): |
| 1160 | +        input_values = args[: len(args) - len(input_tensors) - len(independent_rvs)] |
| 1161 | +        non_sequences = args[len(args) - len(input_tensors) - len(independent_rvs) :] |
| 1162 | + |
| 1163 | +        # Compute the output sample values for the given input sample values |
| 1164 | +        replace = dict(zip(rvs_from_posterior, input_values, strict=True)) |
| 1165 | +        samples = clone_while_sharing_some_variables( |
| 1166 | +            outputs, replace=replace, kept_variables=non_sequences |
| 1167 | +        ) |
| 1168 | + |
| 1169 | +        # Collect updates if there are RV Ops in the graph |
| 1170 | +        updates = collect_default_updates(outputs=samples, inputs=input_values) |
| 1171 | +        return tuple(samples), updates |
| 1174 | + |
| 1175 | +    sequences = [] |
| 1176 | +    batch_shape = tuple(len(posterior.coords[dim]) for dim in sample_dims) |
| 1177 | +    nsamples = np.prod(batch_shape) |
| 1178 | +    for rv in rvs_from_posterior: |
| 1179 | +        values = posterior[rv.name].transpose(*sample_dims, ...).data  # sample dims first |
| 1180 | +        sequences.append( |
| 1181 | +            pt.constant( |
| 1182 | +                np.reshape(values, (nsamples, *values.shape[len(sample_dims) :])), |
| 1183 | +                name=rv.name, |
| 1184 | +                dtype=rv.dtype, |
| 1185 | +            ) |
| 1186 | +        ) |
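| | +    # Loop over the flattened posterior draws; input_tensors and the independent |
| | +    # rvs are passed as non-sequences, so they are not iterated over. |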
| 1187 | +    scan_out, updates = scan( |
| 1188 | +        fn=step, |
| 1189 | +        sequences=sequences, |
| 1190 | +        non_sequences=[*input_tensors, *independent_rvs], |
| 1191 | +        n_steps=nsamples, |
| 1192 | +    ) |
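| | +    # scan returns a bare variable rather than a list when there is a single output |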
| 1193 | +    if len(outputs) == 1: |
| 1194 | +        scan_out = [scan_out]  # pragma: no cover |
| 1195 | + |
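| | +    # Unflatten the leading scan dimension back into the original sample_dims, |
| | +    # preferring static core shapes where they are known |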
| 1196 | +    looped: list[Variable] = [] |
| 1197 | +    for out in scan_out: |
| 1198 | +        core_shape = tuple( |
| 1199 | +            [ |
| 1200 | +                static if static is not None else dynamic |
| 1201 | +                for static, dynamic in zip(out.type.shape[1:], out.shape[1:]) |
| 1202 | +            ] |
| 1203 | +        ) |
| 1204 | +        looped.append(pt.reshape(out, (*batch_shape, *core_shape))) |
| 1205 | +    if not allow_rvs_in_graph: |
| 1206 | +        remaining_rvs = rvs_in_graph(looped) |
| 1207 | +        if remaining_rvs: |
| 1208 | +            raise RuntimeError( |
| 1209 | +                f"The following random variables were found in the extracted graph: {remaining_rvs}" |
| 1210 | +            ) |
| 1211 | +    return looped, updates |