# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Sequence, Union, cast

import pytensor.tensor as pt

from pytensor import Variable, config
from pytensor.graph import Apply, Op
from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable

from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
from pymc.logprob.abstract import logprob as logprob_logprob
from pymc.logprob.utils import ignore_logprob


class MinibatchRandomVariable(Op):
    """RV whose logprob should be rescaled to match total_size."""

    __props__ = ()
    # The output is a view of the input RV; `perform` makes no copy
    view_map = {0: [0]}

    def make_node(self, rv, *total_size):
        rv = as_tensor_variable(rv)
        # Expect one total_size entry per RV dimension: a scalar int64 for
        # dimensions that should be rescaled, NoneConst for those that shouldn't
        total_size = [
            as_tensor_variable(t, dtype="int64", ndim=0) if t is not None else NoneConst
            for t in total_size
        ]
        assert len(total_size) == rv.ndim
        out = rv.type()
        return Apply(self, [rv, *total_size], [out])

    def perform(self, node, inputs, output_storage):
        # Identity at runtime; the rescaling happens only in the logprob graph
        output_storage[0][0] = inputs[0]


minibatch_rv = MinibatchRandomVariable()
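
# A hedged sketch of the raw Op interface (the RV below is illustrative, not
# part of this module): one total_size entry is expected per RV dimension,
# with None marking dimensions that should not be rescaled.
#
#     import pymc as pm
#
#     rv = pm.Normal.dist(size=(10, 5))
#     scaled = minibatch_rv(rv, 1000, None)  # dim 0 stands for 1000 rows total
#
# In practice, prefer `create_minibatch_rv` below: it also normalizes
# total_size and strips the original logprob via `ignore_logprob`.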


# types.EllipsisType only exists on Python >= 3.10, so use Any as a stand-in
EllipsisType = Any


def create_minibatch_rv(
    rv: TensorVariable,
    total_size: Union[int, None, Sequence[Union[int, EllipsisType, None]]],
) -> TensorVariable:
    """Create a variable whose logp is rescaled by total_size."""
    if isinstance(total_size, int):
        if rv.ndim <= 1:
            total_size = [total_size]
        else:
            # A bare integer rescales only the first (batch) dimension
            missing_ndims = rv.ndim - 1
            total_size = [total_size] + [None] * missing_ndims
    elif isinstance(total_size, (list, tuple)):
        total_size = list(total_size)
        if Ellipsis in total_size:
            # Replace Ellipsis with as many Nones as needed to match rv.ndim
            if total_size.count(Ellipsis) > 1:
                raise ValueError("Only one Ellipsis can be present in total_size")
            sep = total_size.index(Ellipsis)
            begin = total_size[:sep]
            end = total_size[sep + 1 :]
            missing_ndims = max((rv.ndim - len(begin) - len(end), 0))
            total_size = begin + [None] * missing_ndims + end
        if len(total_size) > rv.ndim:
            raise ValueError(f"Length of total_size {total_size} is larger than RV ndim {rv.ndim}")
    else:
        raise TypeError(f"Invalid type for total_size: {total_size}")

    # Strip the RV's own logprob so that only the rescaled
    # MinibatchRandomVariable contributes to the model logp
    rv = ignore_logprob(rv)

    return cast(TensorVariable, minibatch_rv(rv, *total_size))
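
# A short sketch of how total_size is normalized (hedged; the RV and numbers
# are made up for illustration):
#
#     rv = pm.Normal.dist(size=(10, 5))        # ndim == 2
#     create_minibatch_rv(rv, 1000)            # -> total_size [1000, None]
#     create_minibatch_rv(rv, [1000, None])    # same thing, explicit
#     create_minibatch_rv(rv, [1000, ...])     # Ellipsis pads with None
#
# All three rescale only the first dimension, from the 10 observed rows up to
# a nominal 1000.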


def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> TensorVariable:
    """Get the scaling constant for the logp of a minibatched RV."""

    # mypy doesn't understand we can convert a shape TensorVariable into a tuple
    shape = tuple(shape)  # type: ignore

    # Scalar RV
    if len(shape) == 0:  # type: ignore
        coef = total_size[0] if not NoneConst.equals(total_size[0]) else 1.0
    else:
        # Multiply the ratios total_size[i] / shape[i] over the rescaled dims
        coefs = [t / shape[i] for i, t in enumerate(total_size) if not NoneConst.equals(t)]
        coef = pt.prod(coefs)

    return pt.cast(coef, dtype=config.floatX)
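
# Worked example (a sketch; the numbers are illustrative): with
# total_size == [1000, None] and a value of shape (10, 5), only dim 0 is
# rescaled, giving a coefficient of 1000 / 10 == 100. The minibatch logp is
# then multiplied by 100 to approximate the full-data logp.
#
#     coef = get_scaling([pt.as_tensor(1000), NoneConst], pt.as_tensor((10, 5)))
#     coef.eval()  # == 100.0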


MeasurableVariable.register(MinibatchRandomVariable)


@_get_measurable_outputs.register(MinibatchRandomVariable)
def _get_measurable_outputs_minibatch_random_variable(op, node):
    # The single output (the rescaled RV) is the only measurable variable;
    # the total_size inputs carry no probability content
    return [node.outputs[0]]


@_logprob.register(MinibatchRandomVariable)
def minibatch_rv_logprob(op, values, *inputs, **kwargs):
    [value] = values
    rv, *total_size = inputs
    # Logp of the underlying RV, rescaled by prod(total_size[i] / shape[i])
    return logprob_logprob(rv, value, **kwargs) * get_scaling(total_size, value.shape)
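
# How this plugs into the public API (a hedged sketch using standard PyMC
# constructs; the data and sizes are illustrative):
#
#     import numpy as np
#     import pymc as pm
#
#     data = np.random.normal(size=(1000,))
#     batch = pm.Minibatch(data, batch_size=10)
#     with pm.Model():
#         mu = pm.Normal("mu")
#         # total_size tells PyMC the 10-row batch stands in for 1000 rows,
#         # so the observed logp is rescaled by 100 via the machinery above
#         pm.Normal("obs", mu=mu, sigma=1.0, observed=batch, total_size=1000)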