 
 import literate_dataclasses as dataclasses
 import torch
+import torch.nn as nn
 
 import cebra.data as cebra_data
 import cebra.distributions
 from cebra.data.datatypes import Batch
 from cebra.data.datatypes import BatchIndex
-from cebra.models import Model
 
 __all__ = [
     "MultiSessionDataset",
@@ -105,7 +105,7 @@ def load_batch(self, index: BatchIndex) -> List[Batch]:
             ) for session_id, session in enumerate(self.iter_sessions())
         ]
 
-    def configure_for(self, model: "Model"):
+    def configure_for(self, model: "cebra.models.Model"):
         """Configure the dataset offset for the provided model.
 
         Call this function before indexing the dataset. This sets the
@@ -114,9 +114,16 @@ def configure_for(self, model: "Model"):
         Args:
             model: The model to configure the dataset for.
         """
-        self.offset = model.get_offset()
-        for session in self.iter_sessions():
-            session.configure_for(model)
+        if not isinstance(model, nn.ModuleList):
+            raise ValueError(
+                "The model must be a nn.ModuleList to configure the dataset.")
+        if len(model) != self.num_sessions:
+            raise ValueError(
+                f"The model must have {self.num_sessions} sessions, but got {len(model)}."
+            )
+
+        for i, session in enumerate(self.iter_sessions()):
+            session.configure_for(model[i])
 
 
 @dataclasses.dataclass
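Usage sketch (not part of this commit): after this change, configure_for expects one model per session, wrapped in an nn.ModuleList whose length matches num_sessions. The snippet below assumes an existing MultiSessionDataset instance named dataset and a hypothetical build_model helper that returns one torch.nn.Module per session.

    import torch.nn as nn

    # One module per session; order must follow dataset.iter_sessions().
    models = nn.ModuleList(
        [build_model(session) for session in dataset.iter_sessions()])

    # Raises ValueError if `models` is not an nn.ModuleList or if its length
    # differs from dataset.num_sessions; otherwise each session is configured
    # with its corresponding model.
    dataset.configure_for(models)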