
Commit 1d10668

Add legacy multiobjective model for backward compat
1 parent 1d69957 commit 1d10668

File tree: 3 files changed, +190 −4 lines

cebra/models/multiobjective.py

Lines changed: 135 additions & 1 deletion
@@ -20,7 +20,7 @@
 # limitations under the License.
 #
 import itertools
-from typing import List
+from typing import List, Tuple

 import torch
 from torch import nn
@@ -106,6 +106,140 @@ def forward(self, inp):
         return inp / torch.norm(inp, dim=1, keepdim=True)


+class LegacyMultiobjectiveModel(nn.Module):
+    """Wrapper around contrastive learning models to allow training with multiple objectives.
+
+    Multi-objective training splits the last layer's feature representation into multiple
+    chunks, which are then used for individual training objectives.
+
+    Args:
+        module: The module to wrap.
+        dimensions: A tuple of dimension values to extract from the model's feature embedding.
+        renormalize: If True, the individual feature slices will be re-normalized before
+            getting returned---this option only makes sense in conjunction with a loss based
+            on the cosine distance or dot product.
+        output_mode: A mode as defined in ``MultiobjectiveModel.Mode``. Overlapping means that
+            when ``dimensions`` is set to ``(x0, x1, ...)``, features will be extracted from
+            ``0:x0, 0:x1, ...``. When the mode is set to separate, features are extracted from
+            ``x0:x1, x1:x2, ...``.
+        append_last_dimension: If True, allows omitting the last entry of ``dimensions``
+            (which should be equal to the output dimension of the given model); it is then
+            appended automatically. Defaults to False.
+
+    TODO:
+        - Update nn.Module type annotation for ``module`` to cebra.models.Model
+    """
+
+    class Mode:
+        """Mode for slicing and potentially normalizing the output embedding.
+
+        The options are:
+
+        - ``OVERLAPPING``: When ``dimensions`` is set to ``(x0, x1, ...)``, features will be
+          extracted from ``0:x0, 0:x1, ...``.
+        - ``SEPARATE``: Features are extracted from ``x0:x1, x1:x2, ...``.
+        """
+
+        OVERLAPPING = "overlapping"
+        SEPARATE = "separate"
+        _ALL = {OVERLAPPING, SEPARATE}
+
+        def is_valid(self, mode):
+            """Check if a given string representation is valid.
+
+            Args:
+                mode: String representation of the mode.
+
+            Returns:
+                ``True`` for a valid representation, ``False`` otherwise.
+            """
+            return mode in self._ALL
+
+    def __init__(
+        self,
+        module: nn.Module,
+        dimensions: Tuple[int],
+        renormalize: bool = False,
+        output_mode: str = "overlapping",
+        append_last_dimension: bool = False,
+    ):
+        super().__init__()
+
+        if not isinstance(module, cebra.models.Model):
+            raise ValueError("Can only wrap models that are subclassing the "
+                             "cebra.models.Model abstract base class. "
+                             f"Got a model of type {type(module)}.")
+
+        self.module = module
+        self.renormalize = renormalize
+        self.output_mode = output_mode
+
+        self._norm = _Norm()
+        self._compute_slices(dimensions, append_last_dimension)
+
+    @property
+    def get_offset(self):
+        """See :py:meth:`cebra.models.model.Model.get_offset`."""
+        return self.module.get_offset
+
+    @property
+    def num_output(self):
+        """See :py:attr:`cebra.models.model.Model.num_output`."""
+        return self.module.num_output
+
+    def _compute_slices(self, dimensions, append_last_dimension):
+
+        def _valid_dimensions(dimensions):
+            return max(dimensions) == self.num_output
+
+        if append_last_dimension:
+            if _valid_dimensions(dimensions):
+                raise ValueError(
+                    f"append_last_dimension should only be used if extra values are "
+                    f"available. Last requested dimensionality is already {dimensions[-1]}."
+                )
+            dimensions += (self.num_output,)
+        if not _valid_dimensions(dimensions):
+            raise ValueError(
+                f"Max of given dimensions needs to match the number of outputs "
+                f"in the encoder network. Got {dimensions} and expected a "
+                f"maximum value of {self.num_output}.")
+
+        if self.output_mode == self.Mode.OVERLAPPING:
+            self.feature_ranges = tuple(
+                slice(0, dimension) for dimension in dimensions)
+        elif self.output_mode == self.Mode.SEPARATE:
+            from_dimension = (0,) + dimensions
+            self.feature_ranges = tuple(
+                slice(i, j) for i, j in zip(from_dimension, dimensions))
+        else:
+            raise ValueError(
+                f"Unknown mode: '{self.output_mode}', use one of {self.Mode._ALL}."
+            )
+
+    def forward(self, inputs):
+        """Compute multiple embeddings for a single signal input.
+
+        Args:
+            inputs: The input tensor.
+
+        Returns:
+            A tuple of tensors sliced according to ``self.feature_ranges``. If
+            ``renormalize`` is set to True, each of the tensors is normalized
+            across the first (feature) dimension.
+
+        TODO:
+            - Cover this function with unit tests
+        """
+        output = self.module(inputs)
+        outputs = (
+            output[:, slice_features] for slice_features in self.feature_ranges)
+        if self.renormalize:
+            outputs = (self._norm(output) for output in outputs)
+        return tuple(outputs)
+
+
 class MultiobjectiveModel(nn.Module):
     """Wrapper around contrastive learning models to allow training with multiple objectives

cebra/solver/base.py

Lines changed: 4 additions & 3 deletions
@@ -120,7 +120,7 @@ def load_state_dict(self, state_dict: dict, strict: bool = True):
         to partially load the state for all given keys.
         """

-        def _contains(key):
+        def _contains(key, strict=strict):
             if key in state_dict:
                 return True
             elif strict:
@@ -146,7 +146,8 @@ def _get(key):
         self.decode_history = _get("decode")
         if _contains("log"):
             self.log = _get("log")
-        if _contains("metadata"):
+        # NOTE(stes): Added in CEBRA 0.6.0
+        if _contains("metadata", strict=False):
             self.metadata = _get("metadata")

     @property
@@ -405,7 +406,7 @@ def num_total_features(self):
     def __post_init__(self):
         super().__post_init__()
         self._check_dimensions()
-        self.model = cebra.models.MultiobjectiveModel(
+        self.model = cebra.models.LegacyMultiobjectiveModel(
             self.model,
             dimensions=(self.num_behavior_features, self.model.num_output),
             renormalize=self.renormalize_features,
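
The strict=False lookup above is what keeps old checkpoints loadable: the "metadata" key was only added in CEBRA 0.6.0, so its absence in older state dicts should not raise even when loading strictly. A standalone sketch of the pattern (illustrative names, not the commit's code):

    def _contains(state_dict, key, strict=True):
        # Return True if the key is present; in strict mode, a missing key is an error.
        if key in state_dict:
            return True
        if strict:
            raise KeyError(f"Key {key} missing in state dict.")
        return False

    old_checkpoint = {"log": [], "decode": []}  # saved before 0.6.0, no "metadata" key
    assert not _contains(old_checkpoint, "metadata", strict=False)  # tolerated, no error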

tests/test_models.py

Lines changed: 51 additions & 0 deletions
@@ -23,6 +23,7 @@

 import pytest
 import torch
+from torch import nn

 import cebra.models
 import cebra.models.model
@@ -87,6 +88,56 @@ def test_offset_models(model_name, batch_size, input_length):
     assert len(outputs) == batch_size


+def test_multiobjective():
+
+    # NOTE(stes): This test is deprecated and will be removed in a future version.
+    # As of CEBRA 0.6.0, the multiobjective models are tested separately in
+    # test_multiobjective.py.
+
+    class TestModel(cebra.models.Model):
+
+        def __init__(self):
+            super().__init__(num_input=10, num_output=10)
+            self._model = nn.Linear(self.num_input, self.num_output)
+
+        def forward(self, x):
+            return self._model(x)
+
+        @property
+        def get_offset(self):
+            return None
+
+    model = TestModel()
+
+    multi_model_overlap = cebra.models.LegacyMultiobjectiveModel(
+        model,
+        dimensions=(4, 6),
+        output_mode="overlapping",
+        append_last_dimension=True)
+    multi_model_separate = cebra.models.LegacyMultiobjectiveModel(
+        model,
+        dimensions=(4, 6),
+        output_mode="separate",
+        append_last_dimension=True)
+
+    x = torch.randn(5, 10)
+
+    assert model(x).shape == (5, 10)
+
+    assert model.num_output == multi_model_overlap.num_output
+    assert model.get_offset == multi_model_overlap.get_offset
+
+    first, second, third = multi_model_overlap(x)
+    assert first.shape == (5, 4)
+    assert second.shape == (5, 6)
+    assert third.shape == (5, 10)
+
+    first, second, third = multi_model_separate(x)
+    assert first.shape == (5, 4)
+    assert second.shape == (5, 2)
+    assert third.shape == (5, 4)
+
+
 @pytest.mark.parametrize("version,raises", [
     ["1.12", False],
     ["2.", False],
