@@ -260,59 +260,62 @@ def num_trainable_parameters(self) -> int:
260260 param .numel () for param in self .parameters () if param .requires_grad )
261261
262262
263- @register ("offset10-model" )
264- class Offset10Model (_OffsetModel , ConvolutionalModelMixin ):
265- """CEBRA model with a 10 sample receptive field."""
263+ @parametrize ("offset{n_offset}-model" ,
264+ n_offset = (5 , 10 , 15 , 18 , 20 , 31 , 36 , 40 , 50 ))
265+ class OffsetNModel (_OffsetModel , ConvolutionalModelMixin ):
266+ """CEBRA model with a `n_offset` sample receptive field.
266267
267- def __init__ (self , num_neurons , num_units , num_output , normalize = True ):
268+ n_offset: The size of the receptive field.
269+ """
270+
271+ def __init__ (self ,
272+ num_neurons ,
273+ num_units ,
274+ num_output ,
275+ n_offset ,
276+ normalize = True ):
268277 if num_units < 1 :
269278 raise ValueError (
270279 f"Hidden dimension needs to be at least 1, but got { num_units } ."
271280 )
281+
282+ self .n_offset = n_offset
283+
284+ def _compute_num_layers (n_offset ):
285+ """Compute the number of layers to add on top of the first and last conv layers."""
286+ return (n_offset - 4 ) // 2 + self .n_offset % 2
287+
288+ last_layer_kernel = 3 if (self .n_offset % 2 ) == 0 else 2
272289 super ().__init__ (
273290 nn .Conv1d (num_neurons , num_units , 2 ),
274291 nn .GELU (),
275- * self ._make_layers (num_units , num_layers = 3 ),
276- nn .Conv1d (num_units , num_output , 3 ),
292+ * self ._make_layers (num_units ,
293+ num_layers = _compute_num_layers (self .n_offset )),
294+ nn .Conv1d (num_units , num_output , last_layer_kernel ),
277295 num_input = num_neurons ,
278296 num_output = num_output ,
279297 normalize = normalize ,
280298 )
281299
282300 def get_offset (self ) -> cebra .data .datatypes .Offset :
283- """See :py:meth:`~.Model.get_offset`"""
284- return cebra .data .Offset (5 , 5 )
301+ """See `:py:meth:Model.get_offset`"""
302+ return cebra .data .Offset (self .n_offset // 2 ,
303+ self .n_offset // 2 + self .n_offset % 2 )
285304
286305
287306@register ("offset10-model-mse" )
288- class Offset10ModelMSE (Offset10Model ):
307+ class Offset10ModelMSE (OffsetNModel ):
289308 """Symmetric model with 10 sample receptive field, without normalization.
290309
291310 Suitable for use with InfoNCE metrics for Euclidean space.
292311 """
293312
294313 def __init__ (self , num_neurons , num_units , num_output , normalize = False ):
295- super ().__init__ (num_neurons , num_units , num_output , normalize )
296-
297-
298- @register ("offset5-model" )
299- class Offset5Model (_OffsetModel , ConvolutionalModelMixin ):
300- """CEBRA model with a 5 sample receptive field and output normalization."""
301-
302- def __init__ (self , num_neurons , num_units , num_output , normalize = True ):
303- super ().__init__ (
304- nn .Conv1d (num_neurons , num_units , 2 ),
305- nn .GELU (),
306- cebra_layers ._Skip (nn .Conv1d (num_units , num_units , 3 ), nn .GELU ()),
307- nn .Conv1d (num_units , num_output , 2 ),
308- num_input = num_neurons ,
309- num_output = num_output ,
310- normalize = normalize ,
311- )
312-
313- def get_offset (self ) -> cebra .data .datatypes .Offset :
314- """See :py:meth:`~.Model.get_offset`"""
315- return cebra .data .Offset (2 , 3 )
314+ super ().__init__ (num_neurons ,
315+ num_units ,
316+ num_output ,
317+ n_offset = 10 ,
318+ normalize = normalize )
316319
317320
318321@register ("offset1-model-mse" )
@@ -666,30 +669,6 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
666669 return cebra .data .Offset (0 , 1 )
667670
668671
669- @register ("offset36-model" )
670- class Offset36 (_OffsetModel , ConvolutionalModelMixin ):
671- """CEBRA model with a 10 sample receptive field."""
672-
673- def __init__ (self , num_neurons , num_units , num_output , normalize = True ):
674- if num_units < 1 :
675- raise ValueError (
676- f"Hidden dimension needs to be at least 1, but got { num_units } ."
677- )
678- super ().__init__ (
679- nn .Conv1d (num_neurons , num_units , 2 ),
680- nn .GELU (),
681- * self ._make_layers (num_units , num_layers = 16 ),
682- nn .Conv1d (num_units , num_output , 3 ),
683- num_input = num_neurons ,
684- num_output = num_output ,
685- normalize = normalize ,
686- )
687-
688- def get_offset (self ) -> cebra .data .datatypes .Offset :
689- """See `:py:meth:Model.get_offset`"""
690- return cebra .data .Offset (18 , 18 )
691-
692-
693672@_register_conditionally ("offset36-model-dropout" )
694673class Offset36Dropout (_OffsetModel , ConvolutionalModelMixin ):
695674 """CEBRA model with a 10 sample receptive field.
@@ -767,102 +746,6 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
767746 return cebra .data .Offset (18 , 18 )
768747
769748
770- @register ("offset40-model" )
771- class Offset40 (_OffsetModel , ConvolutionalModelMixin ):
772- """CEBRA model with a 40 samples receptive field."""
773-
774- def __init__ (self , num_neurons , num_units , num_output , normalize = True ):
775- if num_units < 1 :
776- raise ValueError (
777- f"Hidden dimension needs to be at least 1, but got { num_units } ."
778- )
779- super ().__init__ (
780- nn .Conv1d (num_neurons , num_units , 2 ),
781- nn .GELU (),
782- * self ._make_layers (num_units , 18 ),
783- nn .Conv1d (num_units , num_output , 3 ),
784- num_input = num_neurons ,
785- num_output = num_output ,
786- normalize = normalize ,
787- )
788-
789- def get_offset (self ) -> cebra .data .datatypes .Offset :
790- """See `:py:meth:Model.get_offset`"""
791- return cebra .data .Offset (20 , 20 )
792-
793-
794- @register ("offset50-model" )
795- class Offset50 (_OffsetModel , ConvolutionalModelMixin ):
796- """CEBRA model with a sample receptive field."""
797-
798- def __init__ (self , num_neurons , num_units , num_output , normalize = True ):
799- if num_units < 1 :
800- raise ValueError (
801- f"Hidden dimension needs to be at least 1, but got { num_units } ."
802- )
803- super ().__init__ (
804- nn .Conv1d (num_neurons , num_units , 2 ),
805- nn .GELU (),
806- * self ._make_layers (num_units , 23 ),
807- nn .Conv1d (num_units , num_output , 3 ),
808- num_input = num_neurons ,
809- num_output = num_output ,
810- normalize = normalize ,
811- )
812-
813- def get_offset (self ) -> cebra .data .datatypes .Offset :
814- """See `:py:meth:Model.get_offset`"""
815- return cebra .data .Offset (25 , 25 )
816-
817-
818- @register ("offset15-model" )
819- class Offset15Model (_OffsetModel , ConvolutionalModelMixin ):
820- """CEBRA model with a 15 sample receptive field."""
821-
822- def __init__ (self , num_neurons , num_units , num_output , normalize = True ):
823- if num_units < 1 :
824- raise ValueError (
825- f"Hidden dimension needs to be at least 1, but got { num_units } ."
826- )
827- super ().__init__ (
828- nn .Conv1d (num_neurons , num_units , 2 ),
829- nn .GELU (),
830- * self ._make_layers (num_units , num_layers = 6 ),
831- nn .Conv1d (num_units , num_output , 2 ),
832- num_input = num_neurons ,
833- num_output = num_output ,
834- normalize = normalize ,
835- )
836-
837- def get_offset (self ) -> cebra .data .datatypes .Offset :
838- """See `:py:meth:Model.get_offset`"""
839- return cebra .data .Offset (7 , 8 )
840-
841-
842- @register ("offset20-model" )
843- class Offset20Model (_OffsetModel , ConvolutionalModelMixin ):
844- """CEBRA model with a 15 sample receptive field."""
845-
846- def __init__ (self , num_neurons , num_units , num_output , normalize = True ):
847- if num_units < 1 :
848- raise ValueError (
849- f"Hidden dimension needs to be at least 1, but got { num_units } ."
850- )
851- super ().__init__ (
852- nn .Conv1d (num_neurons , num_units , 2 ),
853- nn .GELU (),
854- * self ._make_layers (num_units , num_layers = 8 ),
855- nn .Conv1d (num_units , num_output , 3 ),
856- num_input = num_neurons ,
857- num_output = num_output ,
858- normalize = normalize ,
859- )
860-
861- def get_offset (self ) -> cebra .data .datatypes .Offset :
862- """See `:py:meth:Model.get_offset`"""
863- return cebra .data .Offset (10 , 10 )
864-
865-
866749@register ("offset10-model-mse-tanh" )
867750class Offset10Model (_OffsetModel , ConvolutionalModelMixin ):
868751 """CEBRA model with a 10 sample receptive field."""
0 commit comments