Skip to content

Commit 7f278b1

Browse files
committed
Apply review comments
1 parent cadd612 commit 7f278b1

File tree

15 files changed

+58
-47
lines changed

15 files changed

+58
-47
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ on:
77
pull_request:
88
branches:
99
- main
10-
- stes/upgrade-docs-rebased
1110

1211
jobs:
1312
build:
@@ -54,12 +53,12 @@ jobs:
5453
run: |
5554
python -m pip install --upgrade pip setuptools wheel
5655
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
57-
pip install '.[dev,datasets,integrations,xcebra]'
56+
pip install '.[dev,datasets,integrations]'
5857
5958
- name: Check sklearn legacy version
6059
if: matrix.sklearn-version == 'legacy'
6160
run: |
62-
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations,xcebra]'
61+
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'
6362
6463
- name: Run the formatter
6564
run: |

.github/workflows/docs.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ on:
99
- main
1010
- public
1111
- dev
12-
- stes/upgrade-docs-rebased
1312
paths:
1413
- '**.py'
1514
- '**.ipynb'

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ FROM cebra-base
4343
ENV WHEEL=cebra-0.6.0a1-py3-none-any.whl
4444
WORKDIR /build
4545
COPY --from=wheel /build/dist/${WHEEL} .
46-
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets,xcebra]'
46+
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]'
4747
RUN rm -rf /build
4848

4949
# add the repository

cebra/data/single_session.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,10 @@ def __post_init__(self):
359359
# e.g. integrating the FAISS dataloader back in.
360360
super().__post_init__()
361361

362-
# BEHAVIOR DISTRIBUTION
362+
self._init_behavior_distribution()
363+
self._init_time_distribution()
363364

365+
def _init_behavior_distribution(self):
364366
if self.conditional == "time":
365367
self.behavior_distribution = cebra.distributions.TimeContrastive(
366368
time_offset=self.time_offset,
@@ -385,7 +387,8 @@ def __post_init__(self):
385387
device=self.device,
386388
)
387389

388-
# TIME DISTRIBUTION
390+
def _init_time_distribution(self):
391+
389392
if self.time_distribution == "time":
390393
self.time_distribution = cebra.distributions.TimeContrastive(
391394
time_offset=self.time_offset,
@@ -403,9 +406,10 @@ def __post_init__(self):
403406
self.time_distribution = cebra.distributions.DeltaNormalDistribution(
404407
self.dataset.continuous_index, self.delta, device=self.device)
405408

406-
elif self.time_distribution == "delta_vmf":
407-
self.time_distribution = cebra.distributions.DeltaVMFDistribution(
408-
self.dataset.continuous_index, self.delta, device=self.device)
409+
# TODO(stes): Add this distribution from internal xCEBRA codebase at a later point
410+
#elif self.time_distribution == "delta_vmf":
411+
# self.time_distribution = cebra.distributions.DeltaVMFDistribution(
412+
# self.dataset.continuous_index, self.delta, device=self.device)
409413
else:
410414
raise ValueError
411415

cebra/models/multiobjective.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@
2929
import cebra.models.model as cebra_models_base
3030

3131

32-
def create_multiobjective_model(module, **kwargs) -> "MultiobjectiveModel":
32+
def create_multiobjective_model(module,
33+
**kwargs) -> "SubspaceMultiobjectiveModel":
3334
assert isinstance(module, cebra_models_base.Model)
3435
if isinstance(module, cebra.models.ConvolutionalModelMixin):
35-
return MultiobjectiveConvolutionalModel(module=module, **kwargs)
36+
return SubspaceMultiobjectiveConvolutionalModel(module=module, **kwargs)
3637
else:
37-
return MultiobjectiveModel(module=module, **kwargs)
38+
return SubspaceMultiobjectiveModel(module=module, **kwargs)
3839

3940

4041
def check_slices_for_gaps(slice_list):
@@ -106,7 +107,7 @@ def forward(self, inp):
106107
return inp / torch.norm(inp, dim=1, keepdim=True)
107108

108109

109-
class LegacyMultiobjectiveModel(nn.Module):
110+
class MultiobjectiveModel(nn.Module):
110111
"""Wrapper around contrastive learning models to all training with multiple objectives
111112
112113
Multi-objective training splits the last layer's feature representation into multiple
@@ -128,6 +129,13 @@ class LegacyMultiobjectiveModel(nn.Module):
128129
129130
TODO:
130131
- Update nn.Module type annotation for ``module`` to cebra.models.Model
132+
133+
Note:
134+
This model will be deprecated in a future version. Please use the functionality in
135+
:py:mod:`cebra.models.multiobjective` instead, which provides more versatile
136+
multi-objective training capabilities. Instantiation of this model will raise a
137+
deprecation warning. The new model is :py:class:`cebra.models.multiobjective.SubspaceMultiobjectiveModel`
138+
which allows for unlimited subspaces and better configuration of the feature ranges.
131139
"""
132140

133141
class Mode:
@@ -240,7 +248,7 @@ def forward(self, inputs):
240248
return tuple(outputs)
241249

242250

243-
class MultiobjectiveModel(nn.Module):
251+
class SubspaceMultiobjectiveModel(nn.Module):
244252
"""Wrapper around contrastive learning models to all training with multiple objectives
245253
246254
Multi-objective training splits the last layer's feature representation into multiple
@@ -354,7 +362,6 @@ def forward(self, inputs):
354362
return output
355363

356364

357-
class MultiobjectiveConvolutionalModel(MultiobjectiveModel,
358-
cebra_models_base.ConvolutionalModelMixin
359-
):
365+
class SubspaceMultiobjectiveConvolutionalModel(
366+
SubspaceMultiobjectiveModel, cebra_models_base.ConvolutionalModelMixin):
360367
pass

cebra/registry.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,10 +287,6 @@ def get_options(pattern: str = None,
287287
raise RuntimeError(
288288
f"Registry could not be successfully registered: {module}.")
289289

290-
# NOTE(stes): Used in xCEBRA initially. If you see this note past 0.6.0, please remove it
291-
# as the functionality is no longer needed.
292-
#return register, parametrize, init, get_options
293-
294290

295291
def add_docstring(module: Union[types.ModuleType, str]):
296292
"""Apply additional information about configuration options to registry modules.

cebra/solver/base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import abc
3434
import os
35+
import warnings
3536
from typing import Callable, Dict, List, Literal, Optional
3637

3738
import literate_dataclasses as dataclasses
@@ -367,11 +368,19 @@ class MultiobjectiveSolver(Solver):
367368
for time contrastive learning.
368369
renormalize_features: If ``True``, normalize the behavior and time
369370
contrastive features individually before computing similarity scores.
371+
ignore_deprecation_warning: If ``True``, suppress the deprecation warning.
372+
373+
Note:
374+
This solver will be deprecated in a future version. Please use the functionality in
375+
:py:mod:`cebra.solver.multiobjective` instead, which provides more versatile
376+
multi-objective training capabilities. Instantiation of this solver will raise a
377+
deprecation warning.
370378
"""
371379

372380
num_behavior_features: int = 3
373381
renormalize_features: bool = False
374382
output_mode: Literal["overlapping", "separate"] = "overlapping"
383+
ignore_deprecation_warning: bool = False
375384

376385
@property
377386
def num_time_features(self):
@@ -383,8 +392,15 @@ def num_total_features(self):
383392

384393
def __post_init__(self):
385394
super().__post_init__()
395+
if not self.ignore_deprecation_warning:
396+
warnings.warn(
397+
"MultiobjectiveSolver is deprecated since CEBRA 0.6.0 and will be removed in a future version. "
398+
"Use the new functionality in cebra.solver.multiobjective instead, which is more versatile. "
399+
"If you see this warning when using the scikit-learn interface, no action is required.",
400+
DeprecationWarning,
401+
stacklevel=2)
386402
self._check_dimensions()
387-
self.model = cebra.models.LegacyMultiobjectiveModel(
403+
self.model = cebra.models.MultiobjectiveModel(
388404
self.model,
389405
dimensions=(self.num_behavior_features, self.model.num_output),
390406
renormalize=self.renormalize_features,

cebra/solver/single_session.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
"""Single session solvers embed a single pair of time series."""
2323

2424
import copy
25-
from typing import Dict
2625

2726
import literate_dataclasses as dataclasses
2827
import torch
@@ -131,16 +130,6 @@ def _inference(self, batch):
131130
class SingleSessionHybridSolver(abc_.MultiobjectiveSolver):
132131
"""Single session training, contrasting neural data against behavior."""
133132

134-
log: Dict = dataclasses.field(default_factory=lambda: ({
135-
"behavior_pos": [],
136-
"behavior_neg": [],
137-
"behavior_total": [],
138-
"time_pos": [],
139-
"time_neg": [],
140-
"time_total": [],
141-
"temperature": []
142-
}))
143-
144133
_variant_name = "single-session-hybrid"
145134

146135
def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch:

docs/source/api.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ these components in other contexts and research code bases.
3838
api/pytorch/distributions
3939
api/pytorch/models
4040
api/pytorch/helpers
41-
api/xcebra/multiobjective
42-
api/xcebra/regularized
43-
api/xcebra/attribution
41+
api/pytorch/multiobjective
42+
api/pytorch/regularized
43+
api/pytorch/attribution
4444

4545
.. toctree::
4646
:hidden:

0 commit comments

Comments
 (0)