|
| 1 | +from itertools import zip_longest |
| 2 | + |
1 | 3 | from pytensor.compile import optdb
|
2 | 4 | from pytensor.configdefaults import config
|
3 | 5 | from pytensor.graph.op import compute_test_value
|
4 | 6 | from pytensor.graph.rewriting.basic import in2out, node_rewriter
|
| 7 | +from pytensor.tensor import NoneConst |
5 | 8 | from pytensor.tensor.basic import constant, get_vector_length
|
6 | 9 | from pytensor.tensor.elemwise import DimShuffle
|
7 | 10 | from pytensor.tensor.extra_ops import broadcast_to
|
|
17 | 20 | get_idx_list,
|
18 | 21 | indexed_result_shape,
|
19 | 22 | )
|
| 23 | +from pytensor.tensor.type_other import SliceType |
20 | 24 |
|
21 | 25 |
|
22 | 26 | def is_rv_used_in_graph(base_rv, node, fgraph):
|
@@ -196,141 +200,104 @@ def local_dimshuffle_rv_lift(fgraph, node):
|
196 | 200 | def local_subtensor_rv_lift(fgraph, node):
|
197 | 201 | """Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
|
198 | 202 |
|
199 |
| - In a fashion similar to ``local_dimshuffle_rv_lift``, the indexed dimensions |
200 |
| - need to be separated into distinct replication-space and (independent) |
201 |
| - parameter-space ``*Subtensor``s. |
202 |
| -
|
203 |
| - The replication-space ``*Subtensor`` can be used to determine a |
204 |
| - sub/super-set of the replication-space and, thus, a "smaller"/"larger" |
205 |
| - ``size`` tuple. The parameter-space ``*Subtensor`` is simply lifted and |
206 |
| - applied to the distribution parameters. |
207 |
| -
|
208 |
| - Consider the following example graph: |
209 |
| - ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``. The |
210 |
| - ``*Subtensor`` ``Op`` requests indices ``idx1``, ``idx2``, and ``idx3``, |
211 |
| - which correspond to all three ``size`` dimensions. Now, depending on the |
212 |
| - broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` ``Op`` |
213 |
| - could be reducing the ``size`` parameter and/or sub-setting the independent |
214 |
| - ``mu`` and ``std`` parameters. Only once the dimensions are properly |
215 |
| - separated into the two replication/parameter subspaces can we determine how |
216 |
| - the ``*Subtensor`` indices are distributed. |
217 |
| - For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]`` |
218 |
| - could become |
219 |
| - ``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))`` |
220 |
| - if ``mu.shape == std.shape == ()`` |
221 |
| -
|
222 |
| - ``normal`` is a rather simple case, because it's univariate. Multivariate |
223 |
| - cases require a mapping between the parameter space and the image of the |
224 |
| - random variable. This may not always be possible, but for many common |
225 |
| - distributions it is. For example, the dimensions of the multivariate |
226 |
| - normal's image can be mapped directly to each dimension of its parameters. |
227 |
| - We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]`` |
228 |
| - into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``. |
| 203 | + For example, ``normal(mu, std)[0] == normal(mu[0], std[0])``. |
229 | 204 |
|
| 205 | + This rewrite also applies to multivariate distributions as long |
| 206 | + as indexing does not happen within core dimensions, such as in |
| 207 | + ``mvnormal(mu, cov, size=(2,))[0, 0]``. |
230 | 208 | """
|
231 | 209 |
|
232 | 210 | st_op = node.op
|
233 | 211 |
|
234 | 212 | if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
|
235 | 213 | return False
|
236 | 214 |
|
237 |
| - base_rv = node.inputs[0] |
| 215 | + rv = node.inputs[0] |
| 216 | + rv_node = rv.owner |
238 | 217 |
|
239 |
| - rv_node = base_rv.owner |
240 | 218 | if not (rv_node and isinstance(rv_node.op, RandomVariable)):
|
241 | 219 | return False
|
242 | 220 |
|
243 |
| - # If no one else is using the underlying `RandomVariable`, then we can |
244 |
| - # do this; otherwise, the graph would be internally inconsistent. |
245 |
| - if is_rv_used_in_graph(base_rv, node, fgraph): |
246 |
| - return False |
247 |
| - |
248 | 221 | rv_op = rv_node.op
|
249 | 222 | rng, size, dtype, *dist_params = rv_node.inputs
|
250 | 223 |
|
251 |
| - # TODO: Remove this once the multi-dimensional changes described below are |
252 |
| - # in place. |
253 |
| - if rv_op.ndim_supp > 0: |
254 |
| - return False |
255 |
| - |
256 |
| - rv_op = base_rv.owner.op |
257 |
| - rng, size, dtype, *dist_params = base_rv.owner.inputs |
258 |
| - |
| 224 | + # Parse indices |
259 | 225 | idx_list = getattr(st_op, "idx_list", None)
|
260 | 226 | if idx_list:
|
261 | 227 | cdata = get_idx_list(node.inputs, idx_list)
|
262 | 228 | else:
|
263 | 229 | cdata = node.inputs[1:]
|
264 |
| - |
265 | 230 | st_indices, st_is_bool = zip(
|
266 | 231 | *tuple(
|
267 | 232 | (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
|
268 | 233 | )
|
269 | 234 | )
|
270 | 235 |
|
271 |
| - # We need to separate dimensions into replications and independents |
272 |
| - num_ind_dims = None |
273 |
| - if len(dist_params) == 1: |
274 |
| - num_ind_dims = dist_params[0].ndim |
275 |
| - else: |
276 |
| - # When there is more than one distribution parameter, assume that all |
277 |
| - # of them will broadcast to the maximum number of dimensions |
278 |
| - num_ind_dims = max(d.ndim for d in dist_params) |
279 |
| - |
280 |
| - reps_ind_split_idx = base_rv.ndim - (num_ind_dims + rv_op.ndim_supp) |
281 |
| - |
282 |
| - if len(st_indices) > reps_ind_split_idx: |
283 |
| - # These are the indices that need to be applied to the parameters |
284 |
| - ind_indices = tuple(st_indices[reps_ind_split_idx:]) |
285 |
| - |
286 |
| - # We need to broadcast the parameters before applying the `*Subtensor*` |
287 |
| - # with these indices, because the indices could be referencing broadcast |
288 |
| - # dimensions that don't exist (yet) |
289 |
| - bcast_dist_params = broadcast_params(dist_params, rv_op.ndims_params) |
290 |
| - |
291 |
| - # TODO: For multidimensional distributions, we need a map that tells us |
292 |
| - # which dimensions of the parameters need to be indexed. |
293 |
| - # |
294 |
| - # For example, `multivariate_normal` would have the following: |
295 |
| - # `RandomVariable.param_to_image_dims = ((0,), (0, 1))` |
296 |
| - # |
297 |
| - # I.e. the first parameter's (i.e. mean's) first dimension maps directly to |
298 |
| - # the dimension of the RV's image, and its second parameter's |
299 |
| - # (i.e. covariance's) first and second dimensions map directly to the |
300 |
| - # dimension of the RV's image. |
301 |
| - |
302 |
| - args_lifted = tuple(p[ind_indices] for p in bcast_dist_params) |
303 |
| - else: |
304 |
| - # In this case, no indexing is applied to the parameters; only the |
305 |
| - # `size` parameter is affected. |
306 |
| - args_lifted = dist_params |
| 236 | + # Check that indexing does not act on support dims |
| 237 | + batched_ndims = rv.ndim - rv_op.ndim_supp |
| 238 | + if len(st_indices) > batched_ndims: |
| 239 | + # If the last indexes are just dummy `slice(None)` we discard them |
| 240 | + st_is_bool = st_is_bool[:batched_ndims] |
| 241 | + st_indices, supp_indices = ( |
| 242 | + st_indices[:batched_ndims], |
| 243 | + st_indices[batched_ndims:], |
| 244 | + ) |
| 245 | + for index in supp_indices: |
| 246 | + if not ( |
| 247 | + isinstance(index.type, SliceType) |
| 248 | + and all(NoneConst.equals(i) for i in index.owner.inputs) |
| 249 | + ): |
| 250 | + return False |
| 251 | + |
| 252 | + # If no one else is using the underlying `RandomVariable`, then we can |
| 253 | + # do this; otherwise, the graph would be internally inconsistent. |
| 254 | + if is_rv_used_in_graph(rv, node, fgraph): |
| 255 | + return False |
307 | 256 |
|
| 257 | + # Update the size to reflect the indexed dimensions |
308 | 258 | # TODO: Could use `ShapeFeature` info. We would need to be sure that
|
309 | 259 | # `node` isn't in the results, though.
|
310 | 260 | # if hasattr(fgraph, "shape_feature"):
|
311 | 261 | # output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
|
312 | 262 | # else:
|
313 |
| - output_shape = indexed_result_shape(base_rv.shape, st_indices) |
314 |
| - |
315 |
| - size_lifted = ( |
316 |
| - output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp] |
| 263 | + output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices) |
| 264 | + new_size_ignoring_boolean = ( |
| 265 | + output_shape_ignoring_bool |
| 266 | + if rv_op.ndim_supp == 0 |
| 267 | + else output_shape_ignoring_bool[: -rv_op.ndim_supp] |
317 | 268 | )
|
318 | 269 |
|
319 |
| - # Boolean indices can actually change the `size` value (compared to just |
320 |
| - # *which* dimensions of `size` are used). |
| 270 | + # Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used). |
| 271 | + # The `indexed_result_shape` helper does not consider this |
321 | 272 | if any(st_is_bool):
|
322 |
| - size_lifted = tuple( |
| 273 | + new_size = tuple( |
323 | 274 | at_sum(idx) if is_bool else s
|
324 |
| - for s, is_bool, idx in zip( |
325 |
| - size_lifted, st_is_bool, st_indices[: (reps_ind_split_idx + 1)] |
| 275 | + for s, is_bool, idx in zip_longest( |
| 276 | + new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False |
326 | 277 | )
|
327 | 278 | )
|
| 279 | + else: |
| 280 | + new_size = new_size_ignoring_boolean |
328 | 281 |
|
329 |
| - new_node = rv_op.make_node(rng, size_lifted, dtype, *args_lifted) |
330 |
| - _, new_rv = new_node.outputs |
| 282 | + # Update the parameters to reflect the indexed dimensions |
| 283 | + new_dist_params = [] |
| 284 | + for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params): |
| 285 | + # Apply indexing on the batched dimensions of the parameter |
| 286 | + batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp) |
| 287 | + batched_param = shape_padleft(param, batched_param_dims_missing) |
| 288 | + batched_st_indices = [] |
| 289 | + for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape): |
| 290 | + # If we have a degenerate dimension indexing it should always do the job |
| 291 | + if batched_param_shape == 1: |
| 292 | + batched_st_indices.append(0) |
| 293 | + else: |
| 294 | + batched_st_indices.append(st_index) |
| 295 | + new_dist_params.append(batched_param[tuple(batched_st_indices)]) |
| 296 | + |
| 297 | + # Create new RV |
| 298 | + new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) |
| 299 | + new_rv = new_node.default_output() |
331 | 300 |
|
332 |
| - # Calling `Op.make_node` directly circumvents test value computations, so |
333 |
| - # we need to compute the test values manually |
334 | 301 | if config.compute_test_value != "off":
|
335 | 302 | compute_test_value(new_node)
|
336 | 303 |
|
|
0 commit comments