Commit cf5f5a2

revert unneeded changes in solver
1 parent 6b6af82

File tree: 1 file changed

cebra/solver/base.py

Lines changed: 8 additions & 36 deletions
@@ -31,9 +31,7 @@
 """
 
 import abc
-import logging
 import os
-import time
 from typing import Callable, Dict, List, Literal, Optional
 
 import literate_dataclasses as dataclasses
@@ -72,10 +70,6 @@ class Solver(abc.ABC, cebra.io.HasDevice):
     optimizer: torch.optim.Optimizer
     history: List = dataclasses.field(default_factory=list)
     decode_history: List = dataclasses.field(default_factory=list)
-    metadata: Dict = dataclasses.field(default_factory=lambda: ({
-        "timestamp": None,
-        "batches_seen": None,
-    }))
     log: Dict = dataclasses.field(default_factory=lambda: ({
         "pos": [],
         "neg": [],
@@ -84,8 +78,6 @@ class Solver(abc.ABC, cebra.io.HasDevice):
     }))
     tqdm_on: bool = True
 
-    #metrics: MetricCollection = None
-
     def __post_init__(self):
         cebra.io.HasDevice.__init__(self)
         self.best_loss = float("inf")
@@ -105,7 +97,6 @@ def state_dict(self) -> dict:
             "loss": torch.tensor(self.history),
             "decode": self.decode_history,
             "criterion": self.criterion.state_dict(),
-            "metadata": self.metadata,
             "version": cebra.__version__,
             "log": self.log,
         }
@@ -120,7 +111,7 @@ def load_state_dict(self, state_dict: dict, strict: bool = True):
         to partially load the state for all given keys.
         """
 
-        def _contains(key, strict=strict):
+        def _contains(key):
             if key in state_dict:
                 return True
             elif strict:
@@ -146,9 +137,6 @@ def _get(key):
             self.decode_history = _get("decode")
         if _contains("log"):
             self.log = _get("log")
-        # NOTE(stes): Added in CEBRA 0.6.0
-        if _contains("metadata", strict=False):
-            self.metadata = _get("metadata")
 
     @property
     def num_parameters(self) -> int:
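
For context beyond this diff: the `state_dict` / `load_state_dict` pair shown above is what makes solver checkpoints round-trippable. A minimal sketch of that round trip, assuming an already-configured `solver` instance and standard torch serialization (the file name is illustrative only):

    import torch

    # Persist the solver state; per the hunk above the dict carries
    # "loss", "decode", "criterion", "version" and "log" entries.
    torch.save(solver.state_dict(), "solver_checkpoint.pth")

    # Restore it later. With strict=False, keys absent from the saved dict
    # are simply skipped (partial load) rather than treated as an error.
    solver.load_state_dict(torch.load("solver_checkpoint.pth"), strict=False)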
@@ -163,26 +151,21 @@ def parameters(self):
         for parameter in self.criterion.parameters():
             yield parameter
 
-    def _get_loader(self, loader, **kwargs):
-        return ProgressBar(loader=loader,
-                           log_format="tqdm" if self.tqdm_on else "off",
-                           **kwargs)
-
-    def _update_metadata(self, num_steps):
-        self.metadata["timestamp"] = time.time()
-        self.metadata["batches_seen"] = num_steps
+    def _get_loader(self, loader):
+        return ProgressBar(
+            loader,
+            "tqdm" if self.tqdm_on else "off",
+        )
 
     def fit(self,
             loader: cebra.data.Loader,
             valid_loader: cebra.data.Loader = None,
             *,
             save_frequency: int = None,
             valid_frequency: int = None,
-            log_frequency: int = None,
             decode: bool = False,
             logdir: str = None,
-            save_hook: Callable[[int, "Solver"], None] = None,
-            logger: logging.Logger = None):
+            save_hook: Callable[[int, "Solver"], None] = None):
         """Train model for the specified number of steps.
 
         Args:
@@ -192,27 +175,20 @@ def fit(self,
             save_frequency: If not `None`, the frequency for automatically saving model checkpoints
                 to `logdir`.
             valid_frequency: The frequency for running validation on the ``valid_loader`` instance.
-            log_frequency: TODO
             logdir: The logging directory for writing model checkpoints. The checkpoints
                 can be read again using the `solver.load` function, or manually via loading the
                 state dict.
-            logger: TODO
 
         TODO:
             * Refine the API here. Drop the validation entirely, and implement this via a hook?
         """
 
         self.to(loader.device)
 
-        iterator = self._get_loader(loader,
-                                    logger=logger,
-                                    log_frequency=log_frequency)
-
         iterator = self._get_loader(loader)
         self.model.train()
         for num_steps, batch in iterator:
             stats = self.step(batch)
-            self._update_metadata(num_steps)
             iterator.set_description(stats)
 
             if save_frequency is None:
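
For illustration only, a hypothetical call against the `fit` signature after this revert; the `solver`, `loader` and `valid_loader` objects plus the frequencies and directory are assumed, not taken from the commit:

    solver.fit(
        loader,
        valid_loader,
        save_frequency=500,    # write a checkpoint to `logdir` every 500 steps
        valid_frequency=100,   # run validation every 100 steps
        logdir="runs/solver",  # re-readable via `solver.load` or the state dict
    )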
@@ -476,15 +452,11 @@ def step(self, batch: cebra.data.Batch) -> dict:
         self.optimizer.step()
         self.history.append(loss.item())
 
-        stats = dict(
+        return dict(
             behavior_pos=behavior_align.item(),
             behavior_neg=behavior_uniform.item(),
             behavior_total=behavior_loss.item(),
             time_pos=time_align.item(),
             time_neg=time_uniform.item(),
             time_total=time_loss.item(),
         )
-
-        for key, value in stats.items():
-            self.log[key].append(value)
-        return stats
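
With this revert, `step` returns the per-batch statistics directly instead of also appending them to `self.log`; in the training loop they feed `iterator.set_description(stats)`. A hypothetical direct use of the returned dict (`solver` and `batch` are assumed to exist):

    stats = solver.step(batch)
    # Key names as in the hunk above.
    print(f"behavior_total={stats['behavior_total']:.4f} "
          f"time_total={stats['time_total']:.4f}")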
