Skip to content

Commit 5537753

Browse files
authored
Extend Delta distribution to multiple dimensions (#76)
* adding vmf distribution * introduce DeltaNormal and DeltaVMF distributions * small changes * add dependencies + fix things * fix docs * remove vmf distribution * fix docstrings * fix docstring * make delta/delta_normal backwards compatible * add test * extend docs * add one more assert to test * add failing test * fix docstrings * change spherical to isotropic * Adapt tests for delta distribution * fix indent * Update test_distributions.py * Increase allowed tolerance in num. test
1 parent 7d48954 commit 5537753

File tree

4 files changed

+116
-12
lines changed

4 files changed

+116
-12
lines changed

cebra/data/single_session.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import abc
1919
import collections
20+
import warnings
2021
from typing import List
2122

2223
import literate_dataclasses as dataclasses
@@ -164,9 +165,9 @@ class ContinuousDataLoader(cebra_data.Loader):
164165
* auxiliary variables, using the empirical distribution of how behavior varies across
165166
``time_offset`` timesteps (``time_delta``). Sampling for this setting is implemented
166167
in :py:class:`cebra.distributions.continuous.TimedeltaDistribution`.
167-
* alternatively, the distribution can be selected to be a Gaussian distribution parametrized
168-
by a fixed ``delta`` around the reference sample, using the implementation in
169-
:py:class:`cebra.distributions.continuous.DeltaDistribution`.
168+
* alternatively, the distribution can be selected to be a Gaussian distribution
169+
parametrized by a fixed ``delta`` around the reference sample, using the implementation in
170+
:py:class:`cebra.distributions.continuous.DeltaNormalDistribution`.
170171
171172
Args:
172173
See dataclass fields.
@@ -208,8 +209,14 @@ def _init_distribution(self):
208209
self.dataset.continuous_index,
209210
self.time_offset,
210211
device=self.device)
211-
elif self.conditional == "delta":
212-
self.distribution = cebra.distributions.DeltaDistribution(
212+
213+
elif self.conditional in ("delta", "delta_normal"):
214+
if self.conditional == "delta":
215+
warnings.warn(
216+
'"delta" distribution will be deprecated in an upcoming release. Please use "delta_normal" instead.',
217+
DeprecationWarning)
218+
219+
self.distribution = cebra.distributions.DeltaNormalDistribution(
213220
self.dataset.continuous_index,
214221
self.delta,
215222
device=self.device)

cebra/distributions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
"Discrete",
5252
"DiscreteUniform",
5353
"DiscreteEmpirical",
54+
"DeltaNormalDistribution",
5455
"MultivariateDiscrete",
5556
"MultisessionSampler",
5657
]

cebra/distributions/continuous.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -242,14 +242,16 @@ def sample_conditional(self, reference_idx: torch.Tensor) -> torch.Tensor:
242242
return self.index.search(query)
243243

244244

245-
class DeltaDistribution(abc_.JointDistribution, abc_.HasGenerator):
245+
class DeltaNormalDistribution(abc_.JointDistribution, abc_.HasGenerator):
246246
"""Define a conditional distribution based on behavioral changes over time.
247247
248-
Takes a continuous index, and uses sample from Gaussian distribution to sample positive
248+
Takes a continuous index and draws samples from a Gaussian distribution to select positive pairs.
249+
Note that if the continuous index is multidimensional, the Gaussian distribution will have
250+
an isotropic covariance matrix, i.e. Σ = sigma^2 * I.
249251
250252
Args:
251-
continuous: The multidimensional, continuous index
252-
delta: Standard deviation of Gaussian distribution to sample positive pair
253+
continuous: The multidimensional, continuous index.
254+
delta: Standard deviation of Gaussian distribution to sample positive pair.
253255
254256
"""
255257

@@ -277,12 +279,14 @@ def sample_conditional(self, reference_idx: torch.Tensor) -> torch.Tensor:
277279
"Pass a 1D array of indices of reference samples.")
278280

279281
# TODO(stes): Set seed
282+
mean = self.data[reference_idx]
280283
query = torch.distributions.Normal(
281-
self.data[reference_idx].squeeze(),
282-
torch.ones_like(reference_idx, device=self.device) * self.std,
284+
loc=mean,
285+
scale=torch.ones_like(mean, device=self.device) * self.std,
283286
).sample()
284287

285-
return self.index.search(query.unsqueeze(-1))
288+
query = query.unsqueeze(-1) if query.dim() == 1 else query
289+
return self.index.search(query)
286290

287291

288292
class CEBRADistribution(abc_.JointDistribution):

tests/test_distributions.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
1111
#
1212
import functools
13+
from typing import Literal, Optional
1314

1415
import numpy as np
1516
import pytest
1617
import torch
1718

1819
import cebra.datasets as cebra_datasets
1920
import cebra.distributions as cebra_distr
21+
import cebra.distributions.base as cebra_distr_base
2022

2123

2224
def assert_is_tensor(T, device=None):
@@ -284,3 +286,93 @@ def test_multi_session_time_contrastive(time_offset):
284286
# NOTE(celia): test the private function ``_inverse_idx()``, with idx arrays flat
285287
assert (idx.flatten()[rev_idx.flatten()].all() == np.arange(
286288
len(rev_idx.flatten())).all())
289+
290+
291+
class OldDeltaDistribution(cebra_distr_base.JointDistribution,
                           cebra_distr_base.HasGenerator):
    """Reference copy of the previous Delta distribution.

    Kept in the test module so that the new ``DeltaNormalDistribution`` can
    be compared against the old behavior. This version only works for a 1d
    behavior variable.

    Args:
        continuous: The continuous index used to build the search index
            and the prior.
        delta: Standard deviation of the Gaussian distribution used to
            sample the query around each reference sample.
        device: Device forwarded to the generator and the prior.
        seed: Seed forwarded to the generator and the prior, and also used
            to seed the global torch RNG.
    """

    def __init__(self,
                 continuous: torch.Tensor,
                 delta: float = 0.1,
                 device: Literal["cpu", "cuda"] = "cpu",
                 seed: Optional[int] = 1812):
        cebra_distr_base.HasGenerator.__init__(self, device=device, seed=seed)
        # The global torch RNG is seeded in addition to the generator above.
        torch.manual_seed(seed)
        self.data = continuous
        self.std = delta
        self.index = cebra_distr.ContinuousIndex(self.data)
        self.prior = cebra_distr.Prior(self.data, device=device, seed=seed)

    def sample_prior(self, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` indices by delegating to the wrapped prior."""
        prior = self.prior
        return prior.sample_prior(num_samples)

    def sample_conditional(self, reference_idx: torch.Tensor) -> torch.Tensor:
        """Return indices from the conditional distribution.

        Args:
            reference_idx: 1D tensor with the indices of the reference
                samples.

        Returns:
            The result of searching the continuous index with one Gaussian
            sample drawn around each reference sample.

        Raises:
            ValueError: If ``reference_idx`` is not one-dimensional.
        """
        if reference_idx.dim() != 1:
            raise ValueError(
                f"Reference indices have wrong shape: {reference_idx.shape}. "
                "Pass a 1D array of indices of reference samples.")

        # TODO(stes): Set seed
        loc = self.data[reference_idx].squeeze()
        scale = torch.ones_like(reference_idx, device=self.device) * self.std
        noisy_reference = torch.distributions.Normal(loc, scale).sample()

        # The index is queried with a trailing feature dimension of size 1.
        return self.index.search(noisy_reference.unsqueeze(-1))
330+
331+
332+
def test_old_vs_new_delta_normal_with_1Dindex():
    """Check that old and new delta distributions agree on a 1D index.

    Both distributions are built from the first column of the continuous
    index and sampled after resetting the global torch seed to the same
    value, so the returned positive indices are expected to be identical.
    """
    _, continuous = prepare()
    assert continuous.dim() == 2
    num_samples = len(continuous)
    reference_idx = torch.randint(0, num_samples, (num_samples,))

    first_column = continuous[:, 0]

    # Construction order (new, then old) is kept: the old distribution
    # reseeds the global RNG in its constructor.
    new_distribution = cebra_distr.DeltaNormalDistribution(
        continuous=first_column.unsqueeze(-1), delta=0.1)
    old_distribution = OldDeltaDistribution(
        continuous=first_column.unsqueeze(-1), delta=0.1)

    # Sample old first, then new, each right after reseeding.
    positives = []
    for distribution in (old_distribution, new_distribution):
        torch.manual_seed(1812)
        positives.append(distribution.sample_conditional(reference_idx))
    old_positives, new_positives = positives

    assert not torch.equal(old_positives, reference_idx)
    assert not torch.equal(new_positives, reference_idx)
    assert torch.equal(old_positives, new_positives)
352+
353+
354+
@pytest.mark.parametrize(
    "delta,numerical_check",
    [
        (0.01, True),
        (0.025, True),
        (1., False),
        (5., False),
    ],
)
def test_new_delta_normal_with_multidimensional_index(delta, numerical_check):
    """Sample positives from a 3-dimensional index with the new distribution.

    The shape checks always run. When ``numerical_check`` is set, the
    standard deviation of the offsets between positive and reference samples
    is compared to ``delta`` with a relative tolerance of 0.1; otherwise the
    test is skipped after the shape checks.
    """
    continuous = torch.rand(100_000, 3).to("cpu")
    num_samples = 1000
    distribution = cebra_distr.DeltaNormalDistribution(delta=delta,
                                                       continuous=continuous)
    reference_idx = distribution.sample_prior(num_samples)
    positive_idx = distribution.sample_conditional(reference_idx)

    assert positive_idx.dim() == 1
    assert len(positive_idx) == num_samples
    assert not torch.equal(positive_idx, reference_idx)

    if not numerical_check:
        # TODO(stes): Add a warning message to the delta distribution.
        pytest.skip(
            "multivariate delta distribution can not accurately sample with the "
            "given parameters. TODO: Add a warning message for these cases.")

    offsets = continuous[positive_idx] - continuous[reference_idx]
    # TODO(stes): Improve test, use lower error margin here
    assert torch.isclose(offsets.std(), torch.tensor(delta), rtol=0.1)

0 commit comments

Comments
 (0)