Skip to content

Commit c9fa5c8

Browse files
committed
Fix import errors
1 parent c5dc011 commit c9fa5c8

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

cebra/data/multi_session.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@
2626

2727
import literate_dataclasses as dataclasses
2828
import torch
29+
import torch.nn as nn
2930

3031
import cebra.data as cebra_data
3132
import cebra.distributions
3233
from cebra.data.datatypes import Batch
3334
from cebra.data.datatypes import BatchIndex
34-
from cebra.models import Model
3535

3636
__all__ = [
3737
"MultiSessionDataset",
@@ -105,7 +105,7 @@ def load_batch(self, index: BatchIndex) -> List[Batch]:
105105
) for session_id, session in enumerate(self.iter_sessions())
106106
]
107107

108-
def configure_for(self, model: "Model"):
108+
def configure_for(self, model: "cebra.models.Model"):
109109
"""Configure the dataset offset for the provided model.
110110
111111
Call this function before indexing the dataset. This sets the
@@ -114,9 +114,16 @@ def configure_for(self, model: "Model"):
114114
Args:
115115
model: The model to configure the dataset for.
116116
"""
117-
self.offset = model.get_offset()
118-
for session in self.iter_sessions():
119-
session.configure_for(model)
117+
if not isinstance(model, nn.ModuleList):
118+
raise ValueError(
119+
"The model must be a nn.ModuleList to configure the dataset.")
120+
if len(model) != self.num_sessions:
121+
raise ValueError(
122+
f"The model must have {self.num_sessions} sessions, but got {len(model)}."
123+
)
124+
125+
for i, session in enumerate(self.iter_sessions()):
126+
session.configure_for(model[i])
120127

121128

122129
@dataclasses.dataclass

tests/_utils_deprecated.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import cebra
1010
import cebra.integrations.sklearn.utils as sklearn_utils
1111
import cebra.models
12-
import cebra.solvers
1312

1413

1514
#NOTE: Deprecated: transform is now handled in the solver but the original
@@ -79,7 +78,7 @@ def cebra_transform_deprecated(cebra_model,
7978
# using the transform method of the model, and handling padding is implemented
8079
# directly in the base Solver. This method is kept for testing purposes.
8180
@torch.no_grad()
82-
def multiobjective_transform_deprecated(solver: cebra.solvers.Solver,
81+
def multiobjective_transform_deprecated(solver: "cebra.solvers.Solver",
8382
inputs: torch.Tensor) -> torch.Tensor:
8483
"""Transform the input data using the model.
8584

0 commit comments

Comments (0)