Commit 1ca1c94

Replace rvs_to_total_sizes mapping by MinibatchRandomVariables
1 parent 33d641d commit 1ca1c94

File tree

15 files changed (+320 −323 lines)


.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
@@ -240,6 +240,7 @@ jobs:
           - |
             tests/sampling/test_parallel.py
             tests/test_data.py
+            tests/variational/test_minibatch_rv.py
             tests/test_model.py
 
           - |

pymc/logprob/joint_logprob.py

Lines changed: 4 additions & 83 deletions
@@ -39,7 +39,6 @@
 from collections import deque
 from typing import Dict, List, Optional, Sequence, Union
 
-import numpy as np
 import pytensor
 import pytensor.tensor as pt
 
@@ -55,7 +54,6 @@
 from pymc.logprob.rewriting import construct_ir_fgraph
 from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
 from pymc.logprob.utils import rvs_to_value_vars
-from pymc.pytensorf import floatX
 
 
 def logp(rv: TensorVariable, value) -> TensorVariable:
@@ -248,77 +246,6 @@ def factorized_joint_logprob(
     return logprob_vars
 
 
-TOTAL_SIZE = Union[int, Sequence[int], None]
-
-
-def _get_scaling(total_size: TOTAL_SIZE, shape, ndim: int) -> TensorVariable:
-    """
-    Gets scaling constant for logp.
-
-    Parameters
-    ----------
-    total_size: Optional[int|List[int]]
-        size of a fully observed data without minibatching,
-        `None` means data is fully observed
-    shape: shape
-        shape of an observed data
-    ndim: int
-        ndim hint
-
-    Returns
-    -------
-    scalar
-    """
-    if total_size is None:
-        coef = 1.0
-    elif isinstance(total_size, int):
-        if ndim >= 1:
-            denom = shape[0]
-        else:
-            denom = 1
-        coef = floatX(total_size) / floatX(denom)
-    elif isinstance(total_size, (list, tuple)):
-        if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
-            raise TypeError(
-                "Unrecognized `total_size` type, expected "
-                "int or list of ints, got %r" % total_size
-            )
-        if Ellipsis in total_size:
-            sep = total_size.index(Ellipsis)
-            begin = total_size[:sep]
-            end = total_size[sep + 1 :]
-            if Ellipsis in end:
-                raise ValueError(
-                    "Double Ellipsis in `total_size` is restricted, got %r" % total_size
-                )
-        else:
-            begin = total_size
-            end = []
-        if (len(begin) + len(end)) > ndim:
-            raise ValueError(
-                "Length of `total_size` is too big, "
-                "number of scalings is bigger that ndim, got %r" % total_size
-            )
-        elif (len(begin) + len(end)) == 0:
-            coef = 1.0
-        if len(end) > 0:
-            shp_end = shape[-len(end) :]
-        else:
-            shp_end = np.asarray([])
-        shp_begin = shape[: len(begin)]
-        begin_coef = [
-            floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None
-        ]
-        end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None]
-        coefs = begin_coef + end_coef
-        coef = pt.prod(coefs)
-    else:
-        raise TypeError(
-            "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
-        )
-    return pt.as_tensor(coef, dtype=pytensor.config.floatX)
-
-
 def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
     # Raise if there are unexpected RandomVariables in the logp graph
     # Only SimulatorRVs MinibatchIndexRVs are allowed
@@ -348,7 +275,6 @@ def joint_logp(
     rvs_to_values: Dict[TensorVariable, TensorVariable],
     rvs_to_transforms: Dict[TensorVariable, RVTransform],
     jacobian: bool = True,
-    rvs_to_total_sizes: Dict[TensorVariable, TOTAL_SIZE],
     **kwargs,
 ) -> List[TensorVariable]:
     """Thin wrapper around pymc.logprob.factorized_joint_logprob, extended with Model
@@ -371,18 +297,13 @@ def joint_logp(
         **kwargs,
     )
 
-    # The function returns the logp for every single value term we provided to it. This
-    # includes the extra values we plugged in above, so we filter those we actually
-    # wanted in the same order they were given in.
+    # The function returns the logp for every single value term we provided to it.
+    # This includes the extra values we plugged in above, so we filter those we
+    # actually wanted in the same order they were given in.
     logp_terms = {}
     for rv in rvs:
         value_var = rvs_to_values[rv]
-        logp_term = temp_logp_terms[value_var]
-        total_size = rvs_to_total_sizes.get(rv, None)
-        if total_size is not None:
-            scaling = _get_scaling(total_size, value_var.shape, value_var.ndim)
-            logp_term *= scaling
-        logp_terms[value_var] = logp_term
+        logp_terms[value_var] = temp_logp_terms[value_var]
 
     _check_no_rvs(list(logp_terms.values()))
     return list(logp_terms.values())
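
The removed `_get_scaling` helper and its successor `get_scaling` (added in pymc/variational/minibatch_rv.py below) compute the same quantity: a factor of total_size / batch_size that makes a minibatch logp an unbiased estimate of the full-data logp. A minimal numpy sketch of that idea, with illustrative names not taken from the diff:

```python
import numpy as np

rng = np.random.default_rng(123)
full_data = rng.normal(size=1000)        # fully observed data: total_size = 1000
batch = rng.choice(full_data, size=100)  # minibatch of size 100

def normal_logp(x, mu=0.0, sigma=1.0):
    # elementwise log-density of N(mu, sigma)
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

scaling = full_data.size / batch.size    # 1000 / 100 = 10.0
scaled_logp = normal_logp(batch).sum() * scaling
# In expectation over random batches, scaled_logp matches
# normal_logp(full_data).sum(): the scaling keeps the estimate unbiased.
```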

pymc/model.py

Lines changed: 19 additions & 12 deletions
@@ -564,7 +564,6 @@ def __init__(
             self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
             self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
             self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
-            self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes)
             self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
             self.free_RVs = treelist(parent=self.parent.free_RVs)
             self.observed_RVs = treelist(parent=self.parent.observed_RVs)
@@ -578,7 +577,6 @@ def __init__(
             self.values_to_rvs = treedict()
             self.rvs_to_values = treedict()
             self.rvs_to_transforms = treedict()
-            self.rvs_to_total_sizes = treedict()
             self.rvs_to_initial_values = treedict()
             self.free_RVs = treelist()
             self.observed_RVs = treelist()
@@ -762,7 +760,6 @@ def logp(
                 rvs=rvs,
                 rvs_to_values=self.rvs_to_values,
                 rvs_to_transforms=self.rvs_to_transforms,
-                rvs_to_total_sizes=self.rvs_to_total_sizes,
                 jacobian=jacobian,
             )
             assert isinstance(rv_logps, list)
@@ -1314,8 +1311,6 @@ def register_rv(
         name = self.name_for(name)
         rv_var.name = name
         _add_future_warning_tag(rv_var)
-        rv_var.tag.total_size = total_size
-        self.rvs_to_total_sizes[rv_var] = total_size
 
         # Associate previously unknown dimension names with
         # the length of the corresponding RV dimension.
@@ -1327,6 +1322,8 @@ def register_rv(
                 self.add_coord(dname, values=None, length=rv_var.shape[d])
 
         if observed is None:
+            if total_size is not None:
+                raise ValueError("total_size can only be passed to observed RVs")
             self.free_RVs.append(rv_var)
             self.create_value_var(rv_var, transform)
             self.add_named_variable(rv_var, dims)
@@ -1351,12 +1348,17 @@ def register_rv(
 
             # `rv_var` is potentially changed by `make_obs_var`,
             # for example into a new graph for imputation of missing data.
-            rv_var = self.make_obs_var(rv_var, observed, dims, transform)
+            rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)
 
         return rv_var
 
     def make_obs_var(
-        self, rv_var: TensorVariable, data: np.ndarray, dims, transform: Optional[Any]
+        self,
+        rv_var: TensorVariable,
+        data: np.ndarray,
+        dims,
+        transform: Union[Any, None],
+        total_size: Union[int, None],
     ) -> TensorVariable:
         """Create a `TensorVariable` for an observed random variable.
 
@@ -1392,18 +1394,16 @@ def make_obs_var(
 
         mask = getattr(data, "mask", None)
         if mask is not None:
-            if mask.all():
-                # If there are no observed values, this variable isn't really
-                # observed.
-                return rv_var
-
             impute_message = (
                 f"Data in {rv_var} contains missing values and"
                 " will be automatically imputed from the"
                 " sampling distribution."
             )
             warnings.warn(impute_message, ImputationWarning)
 
+            if total_size is not None:
+                raise ValueError("total_size is not compatible with imputed variables")
+
             if not isinstance(rv_var.owner.op, RandomVariable):
                 raise NotImplementedError(
                     "Automatic inputation is only supported for univariate RandomVariables."
@@ -1471,6 +1471,13 @@ def make_obs_var(
             data = sparse.basic.as_sparse(data, name=name)
         else:
             data = at.as_tensor_variable(data, name=name)
+
+        if total_size:
+            from pymc.variational.minibatch_rv import create_minibatch_rv
+
+            rv_var = create_minibatch_rv(rv_var, total_size)
+            rv_var.name = name
+
         rv_var.tag.observations = data
         self.create_value_var(rv_var, transform=None, value_var=data)
         self.add_named_variable(rv_var, dims)
pymc/util.py

Lines changed: 0 additions & 1 deletion
@@ -489,7 +489,6 @@ def __getattribute__(self, name):
         for deprecated_names, alternative in (
             (("value_var", "observations"), "model.rvs_to_values[rv]"),
             (("transform",), "model.rvs_to_transforms[rv]"),
-            (("total_size",), "model.rvs_to_total_sizes[rv]"),
         ):
             if name in deprecated_names:
                 try:
pymc/variational/minibatch_rv.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+#   Copyright 2023 The PyMC Developers
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+from typing import Any, Sequence, Union, cast
+
+import pytensor.tensor as pt
+
+from pytensor import Variable, config
+from pytensor.graph import Apply, Op
+from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable
+
+from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
+from pymc.logprob.abstract import logprob as logprob_logprob
+from pymc.logprob.utils import ignore_logprob
+
+
+class MinibatchRandomVariable(Op):
+    """RV whose logprob should be rescaled to match total_size"""
+
+    __props__ = ()
+    view_map = {0: [0]}
+
+    def make_node(self, rv, *total_size):
+        rv = as_tensor_variable(rv)
+        total_size = [
+            as_tensor_variable(t, dtype="int64", ndim=0) if t is not None else NoneConst
+            for t in total_size
+        ]
+        assert len(total_size) == rv.ndim
+        out = rv.type()
+        return Apply(self, [rv, *total_size], [out])
+
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = inputs[0]
+
+
+minibatch_rv = MinibatchRandomVariable()
+
+
+EllipsisType = Any  # EllipsisType is not present in Python 3.8 yet
+
+
+def create_minibatch_rv(
+    rv: TensorVariable,
+    total_size: Union[int, None, Sequence[Union[int, EllipsisType, None]]],
+) -> TensorVariable:
+    """Create variable whose logp is rescaled by total_size."""
+    if isinstance(total_size, int):
+        if rv.ndim <= 1:
+            total_size = [total_size]
+        else:
+            missing_ndims = rv.ndim - 1
+            total_size = [total_size] + [None] * missing_ndims
+    elif isinstance(total_size, (list, tuple)):
+        total_size = list(total_size)
+        if Ellipsis in total_size:
+            # Replace Ellipsis by None
+            if total_size.count(Ellipsis) > 1:
+                raise ValueError("Only one Ellipsis can be present in total_size")
+            sep = total_size.index(Ellipsis)
+            begin = total_size[:sep]
+            end = total_size[sep + 1 :]
+            missing_ndims = max((rv.ndim - len(begin) - len(end), 0))
+            total_size = begin + [None] * missing_ndims + end
+        if len(total_size) > rv.ndim:
+            raise ValueError(f"Length of total_size {total_size} is larger than RV ndim {rv.ndim}")
+    else:
+        raise TypeError(f"Invalid type for total_size: {total_size}")
+
+    rv = ignore_logprob(rv)
+
+    return cast(TensorVariable, minibatch_rv(rv, *total_size))
+
+
+def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> TensorVariable:
+    """Gets scaling constant for logp."""
+
+    # mypy doesn't understand we can convert a shape TensorVariable into a tuple
+    shape = tuple(shape)  # type: ignore
+
+    # Scalar RV
+    if len(shape) == 0:  # type: ignore
+        coef = total_size[0] if not NoneConst.equals(total_size[0]) else 1.0
+    else:
+        coefs = [t / shape[i] for i, t in enumerate(total_size) if not NoneConst.equals(t)]
+        coef = pt.prod(coefs)
+
+    return pt.cast(coef, dtype=config.floatX)
+
+
+MeasurableVariable.register(MinibatchRandomVariable)
+
+
+@_get_measurable_outputs.register(MinibatchRandomVariable)
+def _get_measurable_outputs_minibatch_random_variable(op, node):
+    return [node.outputs[0]]
+
+
+@_logprob.register(MinibatchRandomVariable)
+def minibatch_rv_logprob(op, values, *inputs, **kwargs):
+    [value] = values
+    rv, *total_size = inputs
+    return logprob_logprob(rv, value, **kwargs) * get_scaling(total_size, value.shape)
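
A quick check of the new logprob dispatch, in the spirit of the tests/variational/test_minibatch_rv.py file added by this commit (these exact assertions are an assumption, not copied from the test file):

```python
import numpy as np
import pymc as pm

from pymc.variational.minibatch_rv import create_minibatch_rv

x = pm.Normal.dist(mu=0.0, sigma=1.0, size=(100,))  # stands in for a minibatch of 100
mx = create_minibatch_rv(x, total_size=1000)        # full dataset has 1000 rows

value = np.zeros(100)
# minibatch_rv_logprob multiplies the wrapped RV's logp by get_scaling(...)
scaled = pm.logp(mx, value).eval()
unscaled = pm.logp(x, value).eval()
np.testing.assert_allclose(scaled, unscaled * 10.0)  # 1000 / 100
```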

pymc/variational/opvi.py

Lines changed: 6 additions & 7 deletions
@@ -66,7 +66,6 @@
 from pymc.backends.ndarray import NDArray
 from pymc.blocking import DictToArrayBijection
 from pymc.initial_point import make_initial_point_fn
-from pymc.logprob.joint_logprob import _get_scaling
 from pymc.model import modelcontext
 from pymc.pytensorf import (
     SeedSequenceSeed,
@@ -82,6 +81,7 @@
     _get_seeds_per_chain,
     locally_cachedmethod,
 )
+from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling
 from pymc.variational.updates import adagrad_window
 from pymc.vartypes import discrete_types
 
@@ -1069,9 +1069,11 @@ def symbolic_normalizing_constant(self):
         t = self.to_flat_input(
             at.max(
                 [
-                    _get_scaling(self.model.rvs_to_total_sizes.get(v, None), v.shape, v.ndim)
+                    get_scaling(v.owner.inputs[1:], v.shape)
                     for v in self.group
+                    if isinstance(v.owner.op, MinibatchRandomVariable)
                 ]
+                + [1.0]  # To avoid empty max
             )
         )
         t = self.symbolic_single_sample(t)
@@ -1237,12 +1239,9 @@ def symbolic_normalizing_constant(self):
         t = at.max(
             self.collect("symbolic_normalizing_constant")
             + [
-                _get_scaling(
-                    self.model.rvs_to_total_sizes.get(obs, None),
-                    obs.shape,
-                    obs.ndim,
-                )
+                get_scaling(obs.owner.inputs[1:], obs.shape)
                 for obs in self.model.observed_RVs
+                if isinstance(obs.owner.op, MinibatchRandomVariable)
             ]
         )
         t = at.switch(self._scale_cost_to_minibatch, t, at.constant(1, dtype=t.dtype))
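
The `v.owner.inputs[1:]` idiom works because `MinibatchRandomVariable.make_node` builds `Apply(self, [rv, *total_size], [out])`: input 0 is the wrapped RV and the remaining inputs are the per-dimension `total_size` terms (or `NoneConst`) that `get_scaling` expects. A minimal sketch, assuming the illustrative `model` from the make_obs_var example above:

```python
from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling

for obs in model.observed_RVs:
    if isinstance(obs.owner.op, MinibatchRandomVariable):
        # inputs[1:] are the symbolic total_size entries
        scaling = get_scaling(obs.owner.inputs[1:], obs.shape)
        print(scaling.eval())  # 10.0 for total_size=1000, batch_size=100
```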
