Skip to content

Commit 64d1db8

Browse files
committed
Make xCEBRA compatible with the batched inference & padding in solver
1 parent a1218aa commit 64d1db8

File tree

2 files changed: +16 −4 lines changed

cebra/solver/base.py

Lines changed: 1 addition & 1 deletion
@@ -599,7 +599,7 @@ def transform(self,
             session_id: The session ID, an :py:class:`int` between 0 and
                 the number of sessions -1 for multisession, and set to
                 ``None`` for single session.
-            batch_size: If not None, batched inference will be applied.
+            batch_size: If not None, batched inference will not be applied.

         Returns:
             The output embedding.

cebra/solver/multiobjective.py

Lines changed: 15 additions & 3 deletions
@@ -53,6 +53,7 @@
 import cebra.data
 import cebra.io
 import cebra.models
+import cebra.solver.single_session as cebra_solver_single
 from cebra.solver import register
 from cebra.solver.base import Solver
 from cebra.solver.schedulers import Scheduler
@@ -187,7 +188,7 @@ def _process_info(self, info):


 @dataclasses.dataclass
-class MultiobjectiveSolverBase(Solver):
+class MultiobjectiveSolverBase(cebra_solver_single.SingleSessionSolver):

     feature_ranges: List[slice] = None
     renormalize: bool = None
@@ -209,6 +210,13 @@ def __post_init__(self):
             renormalize=self.renormalize,
         )

+    def parameters(self, session_id: Optional[int] = None):
+        """Iterate over all parameters."""
+        super().parameters(session_id=session_id)
+
+        for parameter in self.regularizer.parameters():
+            yield parameter
+
     def fit(self,
             loader: cebra.data.Loader,
             valid_loader: cebra.data.Loader = None,
@@ -241,6 +249,7 @@ def _run_validation():
                 save_hook(solver=self, step=num_steps)
             return stats_val

+        self._set_fitted_params(loader)
         self.to(loader.device)

         iterator = self._get_loader(loader,
@@ -393,11 +402,14 @@ def validation(
         logger=None,
         weights_loss: Optional[List[float]] = None,
     ):
+        loader.dataset.configure_for(self.model)
+        iterator = self._get_loader(loader)
+
         self.model.eval()
         total_loss = Meter()

         losses_dict = {}
-        for _, batch in enumerate(loader):
+        for _, batch in iterator:
             predictions = self._inference(batch)
             losses = self.criterion(predictions)

@@ -445,7 +457,7 @@ def validation(
         return stats_val

     @torch.no_grad()
-    def transform(self, inputs: torch.Tensor) -> torch.Tensor:
+    def transform_deprecated(self, inputs: torch.Tensor) -> torch.Tensor:
         offset = self.model.get_offset()
         self.model.eval()
         X = inputs.cpu().numpy()

0 commit comments

Comments
 (0)