Skip to content

Commit 9d982be

Browse files
committed
Revert unneeded updates to the solver
1 parent 2db1d22 commit 9d982be

File tree

3 files changed

+2
-121
lines changed

3 files changed

+2
-121
lines changed

cebra/solver/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
# pylint: disable=wrong-import-position
3838
from cebra.solver.base import *
39-
from cebra.solver.metrics import *
4039
from cebra.solver.multi_session import *
4140
from cebra.solver.multiobjective import *
4241
from cebra.solver.regularized import *

cebra/solver/metrics.py

Lines changed: 0 additions & 103 deletions
This file was deleted.

cebra/solver/util.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#
2222
"""Utility functions for solvers and their training loops."""
2323

24-
import logging
2524
from collections.abc import Iterable
2625
from typing import Dict
2726

@@ -30,7 +29,7 @@
3029

3130

3231
def _description(stats: Dict[str, float]):
33-
stats_str = [f"{key}: {value:.3f}" for key, value in stats.items()]
32+
stats_str = [f"{key}: {value: .4f}" for key, value in stats.items()]
3433
return " ".join(stats_str)
3534

3635

@@ -74,9 +73,7 @@ class ProgressBar:
7473
"Log and display values during training."
7574

7675
loader: Iterable
77-
logger: logging.Logger = None
78-
log_format: str = None
79-
log_frequency: int = None
76+
log_format: str
8077

8178
_valid_formats = ["tqdm", "off"]
8279

@@ -90,23 +87,13 @@ def __post_init__(self):
9087
raise ValueError(
9188
f"log_format must be one of {self._valid_formats}, "
9289
f"but got {self.log_formats}")
93-
self._stats = None
9490

9591
def __iter__(self):
9692
self.iterator = self.loader
9793
if self.use_tqdm:
9894
self.iterator = tqdm.tqdm(self.iterator)
9995
for num_batch, batch in enumerate(self.iterator):
10096
yield num_batch, batch
101-
self._log_message(num_batch, self._stats)
102-
self._log_message(num_batch, self._stats)
103-
104-
def _log_message(self, num_steps, stats):
105-
if self.logger is None:
106-
return
107-
if num_steps % self.log_frequency != 0:
108-
return
109-
self.logger.info(f"Train: Step {num_steps} {_description(stats)}")
11097

11198
def set_description(self, stats: Dict[str, float]):
11299
"""Update the progress bar description.
@@ -119,5 +106,3 @@ def set_description(self, stats: Dict[str, float]):
119106
"""
120107
if self.use_tqdm:
121108
self.iterator.set_description(_description(stats))
122-
123-
self._stats = stats

0 commit comments

Comments
 (0)