3131"""
3232
3333import abc
34- import logging
3534import os
36- import time
3735from typing import Callable , Dict , List , Literal , Optional
3836
3937import literate_dataclasses as dataclasses
@@ -72,10 +70,6 @@ class Solver(abc.ABC, cebra.io.HasDevice):
7270 optimizer : torch .optim .Optimizer
7371 history : List = dataclasses .field (default_factory = list )
7472 decode_history : List = dataclasses .field (default_factory = list )
75- metadata : Dict = dataclasses .field (default_factory = lambda : ({
76- "timestamp" : None ,
77- "batches_seen" : None ,
78- }))
7973 log : Dict = dataclasses .field (default_factory = lambda : ({
8074 "pos" : [],
8175 "neg" : [],
@@ -84,8 +78,6 @@ class Solver(abc.ABC, cebra.io.HasDevice):
8478 }))
8579 tqdm_on : bool = True
8680
87- #metrics: MetricCollection = None
88-
8981 def __post_init__ (self ):
9082 cebra .io .HasDevice .__init__ (self )
9183 self .best_loss = float ("inf" )
@@ -105,7 +97,6 @@ def state_dict(self) -> dict:
10597 "loss" : torch .tensor (self .history ),
10698 "decode" : self .decode_history ,
10799 "criterion" : self .criterion .state_dict (),
108- "metadata" : self .metadata ,
109100 "version" : cebra .__version__ ,
110101 "log" : self .log ,
111102 }
@@ -120,7 +111,7 @@ def load_state_dict(self, state_dict: dict, strict: bool = True):
120111 to partially load the state for all given keys.
121112 """
122113
123- def _contains (key , strict = strict ):
114+ def _contains (key ):
124115 if key in state_dict :
125116 return True
126117 elif strict :
@@ -146,9 +137,6 @@ def _get(key):
146137 self .decode_history = _get ("decode" )
147138 if _contains ("log" ):
148139 self .log = _get ("log" )
149- # NOTE(stes): Added in CEBRA 0.6.0
150- if _contains ("metadata" , strict = False ):
151- self .metadata = _get ("metadata" )
152140
153141 @property
154142 def num_parameters (self ) -> int :
@@ -163,26 +151,21 @@ def parameters(self):
163151 for parameter in self .criterion .parameters ():
164152 yield parameter
165153
166- def _get_loader (self , loader , ** kwargs ):
167- return ProgressBar (loader = loader ,
168- log_format = "tqdm" if self .tqdm_on else "off" ,
169- ** kwargs )
170-
171- def _update_metadata (self , num_steps ):
172- self .metadata ["timestamp" ] = time .time ()
173- self .metadata ["batches_seen" ] = num_steps
154+ def _get_loader (self , loader ):
155+ return ProgressBar (
156+ loader ,
157+ "tqdm" if self .tqdm_on else "off" ,
158+ )
174159
175160 def fit (self ,
176161 loader : cebra .data .Loader ,
177162 valid_loader : cebra .data .Loader = None ,
178163 * ,
179164 save_frequency : int = None ,
180165 valid_frequency : int = None ,
181- log_frequency : int = None ,
182166 decode : bool = False ,
183167 logdir : str = None ,
184- save_hook : Callable [[int , "Solver" ], None ] = None ,
185- logger : logging .Logger = None ):
168+ save_hook : Callable [[int , "Solver" ], None ] = None ):
186169 """Train model for the specified number of steps.
187170
188171 Args:
@@ -192,27 +175,20 @@ def fit(self,
192175 save_frequency: If not `None`, the frequency for automatically saving model checkpoints
193176 to `logdir`.
194177 valid_frequency: The frequency for running validation on the ``valid_loader`` instance.
195- log_frequency: TODO
196178 logdir: The logging directory for writing model checkpoints. The checkpoints
197179 can be read again using the `solver.load` function, or manually via loading the
198180 state dict.
199- logger: TODO
200181
201182 TODO:
202183 * Refine the API here. Drop the validation entirely, and implement this via a hook?
203184 """
204185
205186 self .to (loader .device )
206187
207- iterator = self ._get_loader (loader ,
208- logger = logger ,
209- log_frequency = log_frequency )
210-
211188 iterator = self ._get_loader (loader )
212189 self .model .train ()
213190 for num_steps , batch in iterator :
214191 stats = self .step (batch )
215- self ._update_metadata (num_steps )
216192 iterator .set_description (stats )
217193
218194 if save_frequency is None :
@@ -476,15 +452,11 @@ def step(self, batch: cebra.data.Batch) -> dict:
476452 self .optimizer .step ()
477453 self .history .append (loss .item ())
478454
479- stats = dict (
455+ return dict (
480456 behavior_pos = behavior_align .item (),
481457 behavior_neg = behavior_uniform .item (),
482458 behavior_total = behavior_loss .item (),
483459 time_pos = time_align .item (),
484460 time_neg = time_uniform .item (),
485461 time_total = time_loss .item (),
486462 )
487-
488- for key , value in stats .items ():
489- self .log [key ].append (value )
490- return stats
0 commit comments