Skip to content

Commit 45f2e94

Browse files
committed
Parametrize offset models
1 parent e982248 commit 45f2e94

File tree

2 files changed

+34
-151
lines changed

2 files changed

+34
-151
lines changed

cebra/models/model.py

Lines changed: 33 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -260,59 +260,62 @@ def num_trainable_parameters(self) -> int:
260260
param.numel() for param in self.parameters() if param.requires_grad)
261261

262262

263-
@register("offset10-model")
264-
class Offset10Model(_OffsetModel, ConvolutionalModelMixin):
265-
"""CEBRA model with a 10 sample receptive field."""
263+
@parametrize("offset{n_offset}-model",
264+
n_offset=(5, 10, 15, 18, 20, 31, 36, 40, 50))
265+
class OffsetNModel(_OffsetModel, ConvolutionalModelMixin):
266+
"""CEBRA model with a `n_offset` sample receptive field.
266267
267-
def __init__(self, num_neurons, num_units, num_output, normalize=True):
268+
n_offset: The size of the receptive field.
269+
"""
270+
271+
def __init__(self,
272+
num_neurons,
273+
num_units,
274+
num_output,
275+
n_offset,
276+
normalize=True):
268277
if num_units < 1:
269278
raise ValueError(
270279
f"Hidden dimension needs to be at least 1, but got {num_units}."
271280
)
281+
282+
self.n_offset = n_offset
283+
284+
def _compute_num_layers(n_offset):
285+
"""Compute the number of layers to add on top of the first and last conv layers."""
286+
return (n_offset - 4) // 2 + self.n_offset % 2
287+
288+
last_layer_kernel = 3 if (self.n_offset % 2) == 0 else 2
272289
super().__init__(
273290
nn.Conv1d(num_neurons, num_units, 2),
274291
nn.GELU(),
275-
*self._make_layers(num_units, num_layers=3),
276-
nn.Conv1d(num_units, num_output, 3),
292+
*self._make_layers(num_units,
293+
num_layers=_compute_num_layers(self.n_offset)),
294+
nn.Conv1d(num_units, num_output, last_layer_kernel),
277295
num_input=num_neurons,
278296
num_output=num_output,
279297
normalize=normalize,
280298
)
281299

282300
def get_offset(self) -> cebra.data.datatypes.Offset:
283-
"""See :py:meth:`~.Model.get_offset`"""
284-
return cebra.data.Offset(5, 5)
301+
"""See `:py:meth:Model.get_offset`"""
302+
return cebra.data.Offset(self.n_offset // 2,
303+
self.n_offset // 2 + self.n_offset % 2)
285304

286305

287306
@register("offset10-model-mse")
288-
class Offset10ModelMSE(Offset10Model):
307+
class Offset10ModelMSE(OffsetNModel):
289308
"""Symmetric model with 10 sample receptive field, without normalization.
290309
291310
Suitable for use with InfoNCE metrics for Euclidean space.
292311
"""
293312

294313
def __init__(self, num_neurons, num_units, num_output, normalize=False):
295-
super().__init__(num_neurons, num_units, num_output, normalize)
296-
297-
298-
@register("offset5-model")
299-
class Offset5Model(_OffsetModel, ConvolutionalModelMixin):
300-
"""CEBRA model with a 5 sample receptive field and output normalization."""
301-
302-
def __init__(self, num_neurons, num_units, num_output, normalize=True):
303-
super().__init__(
304-
nn.Conv1d(num_neurons, num_units, 2),
305-
nn.GELU(),
306-
cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()),
307-
nn.Conv1d(num_units, num_output, 2),
308-
num_input=num_neurons,
309-
num_output=num_output,
310-
normalize=normalize,
311-
)
312-
313-
def get_offset(self) -> cebra.data.datatypes.Offset:
314-
"""See :py:meth:`~.Model.get_offset`"""
315-
return cebra.data.Offset(2, 3)
314+
super().__init__(num_neurons,
315+
num_units,
316+
num_output,
317+
n_offset=10,
318+
normalize=normalize)
316319

317320

318321
@register("offset1-model-mse")
@@ -666,30 +669,6 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
666669
return cebra.data.Offset(0, 1)
667670

668671

669-
@register("offset36-model")
670-
class Offset36(_OffsetModel, ConvolutionalModelMixin):
671-
"""CEBRA model with a 10 sample receptive field."""
672-
673-
def __init__(self, num_neurons, num_units, num_output, normalize=True):
674-
if num_units < 1:
675-
raise ValueError(
676-
f"Hidden dimension needs to be at least 1, but got {num_units}."
677-
)
678-
super().__init__(
679-
nn.Conv1d(num_neurons, num_units, 2),
680-
nn.GELU(),
681-
*self._make_layers(num_units, num_layers=16),
682-
nn.Conv1d(num_units, num_output, 3),
683-
num_input=num_neurons,
684-
num_output=num_output,
685-
normalize=normalize,
686-
)
687-
688-
def get_offset(self) -> cebra.data.datatypes.Offset:
689-
"""See `:py:meth:Model.get_offset`"""
690-
return cebra.data.Offset(18, 18)
691-
692-
693672
@_register_conditionally("offset36-model-dropout")
694673
class Offset36Dropout(_OffsetModel, ConvolutionalModelMixin):
695674
"""CEBRA model with a 10 sample receptive field.
@@ -767,102 +746,6 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
767746
return cebra.data.Offset(18, 18)
768747

769748

770-
@register("offset40-model")
771-
class Offset40(_OffsetModel, ConvolutionalModelMixin):
772-
"""CEBRA model with a 40 samples receptive field."""
773-
774-
def __init__(self, num_neurons, num_units, num_output, normalize=True):
775-
if num_units < 1:
776-
raise ValueError(
777-
f"Hidden dimension needs to be at least 1, but got {num_units}."
778-
)
779-
super().__init__(
780-
nn.Conv1d(num_neurons, num_units, 2),
781-
nn.GELU(),
782-
*self._make_layers(num_units, 18),
783-
nn.Conv1d(num_units, num_output, 3),
784-
num_input=num_neurons,
785-
num_output=num_output,
786-
normalize=normalize,
787-
)
788-
789-
def get_offset(self) -> cebra.data.datatypes.Offset:
790-
"""See `:py:meth:Model.get_offset`"""
791-
return cebra.data.Offset(20, 20)
792-
793-
794-
@register("offset50-model")
795-
class Offset50(_OffsetModel, ConvolutionalModelMixin):
796-
"""CEBRA model with a sample receptive field."""
797-
798-
def __init__(self, num_neurons, num_units, num_output, normalize=True):
799-
if num_units < 1:
800-
raise ValueError(
801-
f"Hidden dimension needs to be at least 1, but got {num_units}."
802-
)
803-
super().__init__(
804-
nn.Conv1d(num_neurons, num_units, 2),
805-
nn.GELU(),
806-
*self._make_layers(num_units, 23),
807-
nn.Conv1d(num_units, num_output, 3),
808-
num_input=num_neurons,
809-
num_output=num_output,
810-
normalize=normalize,
811-
)
812-
813-
def get_offset(self) -> cebra.data.datatypes.Offset:
814-
"""See `:py:meth:Model.get_offset`"""
815-
return cebra.data.Offset(25, 25)
816-
817-
818-
@register("offset15-model")
819-
class Offset15Model(_OffsetModel, ConvolutionalModelMixin):
820-
"""CEBRA model with a 15 sample receptive field."""
821-
822-
def __init__(self, num_neurons, num_units, num_output, normalize=True):
823-
if num_units < 1:
824-
raise ValueError(
825-
f"Hidden dimension needs to be at least 1, but got {num_units}."
826-
)
827-
super().__init__(
828-
nn.Conv1d(num_neurons, num_units, 2),
829-
nn.GELU(),
830-
*self._make_layers(num_units, num_layers=6),
831-
nn.Conv1d(num_units, num_output, 2),
832-
num_input=num_neurons,
833-
num_output=num_output,
834-
normalize=normalize,
835-
)
836-
837-
def get_offset(self) -> cebra.data.datatypes.Offset:
838-
"""See `:py:meth:Model.get_offset`"""
839-
return cebra.data.Offset(7, 8)
840-
841-
842-
@register("offset20-model")
843-
class Offset20Model(_OffsetModel, ConvolutionalModelMixin):
844-
"""CEBRA model with a 15 sample receptive field."""
845-
846-
def __init__(self, num_neurons, num_units, num_output, normalize=True):
847-
if num_units < 1:
848-
raise ValueError(
849-
f"Hidden dimension needs to be at least 1, but got {num_units}."
850-
)
851-
super().__init__(
852-
nn.Conv1d(num_neurons, num_units, 2),
853-
nn.GELU(),
854-
*self._make_layers(num_units, num_layers=8),
855-
nn.Conv1d(num_units, num_output, 3),
856-
num_input=num_neurons,
857-
num_output=num_output,
858-
normalize=normalize,
859-
)
860-
861-
def get_offset(self) -> cebra.data.datatypes.Offset:
862-
"""See `:py:meth:Model.get_offset`"""
863-
return cebra.data.Offset(10, 10)
864-
865-
866749
@register("offset10-model-mse-tanh")
867750
class Offset10Model(_OffsetModel, ConvolutionalModelMixin):
868751
"""CEBRA model with a 10 sample receptive field."""

docs/source/usage.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ We provide a set of pre-defined models. You can access (and search) a list of av
175175

176176
.. testoutput::
177177

178-
['offset10-model', 'offset10-model-mse', 'offset5-model', 'offset1-model-mse']
178+
['offset5-model', 'offset10-model', 'offset15-model', 'offset18-model']
179179

180180
Then, you can choose the one that fits best with your needs and provide it to the CEBRA model as the :py:attr:`~.CEBRA.model_architecture` parameter.
181181

0 commit comments

Comments
 (0)