# Callbacks Documentation

`torchmanager/callbacks`
An empty basic training callback
- Methods:
    - `on_batch_end` - The callback when batch ends
        - Parameters:
            - batch: An `int` of batch index
            - summary: A `dict` of summary with name in `str` and value in `float`
    - `on_batch_start` - The callback when batch starts
        - Parameters:
            - batch: An `int` of batch index
    - `on_epoch_end` - The callback when epoch ends
        - Parameters:
            - epoch: An `int` of epoch index
            - summary: A `dict` of training summary with name in `str` and value in `float`
            - val_summary: A `dict` of validation summary with name in `str` and value in `float`
    - `on_epoch_start` - The callback when epoch starts
        - Parameters:
            - epoch: An `int` of epoch index
    - `on_train_end` - The callback when training ends
    - `on_train_start` - The callback when training starts
        - Parameters:
            - initial_epoch: An `int` of initial epoch index
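
A minimal sketch of a custom callback that overrides some of the hooks above, assuming the class is exposed as `torchmanager.callbacks.Callback`; the exact method signatures (defaults, keyword-only parameters) are inferred from the parameter lists and may differ from the actual base class:

```python
from typing import Optional

from torchmanager.callbacks import Callback


class LoggingCallback(Callback):
    """A hypothetical callback that prints progress at each epoch."""

    def on_epoch_start(self, epoch: int) -> None:
        print(f"Starting epoch {epoch}")

    def on_epoch_end(self, epoch: int, summary: Optional[dict] = None, val_summary: Optional[dict] = None) -> None:
        # `summary` holds training metrics; `val_summary` holds validation metrics
        print(f"Epoch {epoch}: train={summary}, val={val_summary}")
```
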
The early stop callback that raises a `StopTraining` error during training if the monitored metric has not improved for several steps
- extends: `Callback`
- Properties:
    - monitor: A `str` of monitored metric
    - monitor_type: A `MonitorType` of either `MIN` or `MAX` mode for the best model
    - steps: An `int` of steps to monitor
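
A usage sketch under stated assumptions: the class name `EarlyStop`, its access path, and its constructor arguments are inferred from the properties above and are not confirmed by this page.

```python
from torchmanager import callbacks

# Hypothetical: class name and argument names are assumptions inferred
# from the documented properties (monitor, monitor_type, steps).
early_stop = callbacks.EarlyStop(
    "val_loss",                              # monitor: summary name to watch
    monitor_type=callbacks.MonitorType.MIN,  # MIN: smaller is better
    steps=10,                                # stop after 10 steps without improvement
)
```

The callback would then be passed to training via `manager.fit(..., callbacks_list=[early_stop])`, as shown in the `LambdaDynamicWeight` example below.
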
The callback that wraps the last and best checkpoints in the checkpoints folder (saved as last.model and best_*.model) and the tensorboard logs in the data folder together into a wrapped *.exp file
- extends: `.callback.Callback`
- requires: `tensorboard` package
- Properties:
    - best_ckpts: A `list` of `.ckpt.BestCheckpoint` callbacks that record the best checkpoints
    - last_ckpt: A `.ckpt.LastCheckpoint` callback that records the last checkpoint
    - tensorboard: A `.ckpt.TensorBoard` callback that records data to tensorboard
A callback with frequency control
- extends: `Callback` - abstract class that needs an implementation of the `step` method
- Properties:
    - current_step: An `int` of the current step index
    - freq: A `WeightUpdateFreq` of the frequency type to update the weight
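
A minimal sketch of a concrete subclass; the import path and the exact signature of the abstract `step` method are assumptions, so check the source before relying on them.

```python
from torchmanager.callbacks import FrequencyCallback


class PrintingCallback(FrequencyCallback):
    """A hypothetical frequency-controlled callback."""

    def step(self, *args, **kwargs):
        # Called at the frequency configured by `freq`; the real abstract
        # signature may differ from this placeholder.
        print(f"current step: {self.current_step}")
```
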
An abstract dynamic weight callback that sets the weight dynamically
- extends: `.callback.FrequencyCallback` - abstract class that needs an implementation of the `step` method
A dynamic weight callback that sets the weight dynamically with a lambda function
- extends: `DynamicWeight`
- Targeting any object that conforms to the `.protocol.Weighted` protocol:

```python
from torchmanager import losses

loss_fn = losses.Loss(...)  # where `torchmanager.losses.Loss` conforms to the `.protocols.Weighted` protocol
```

- Passing a defined function into the `DynamicWeight` callback:

```python
from torchmanager.callbacks import LambdaDynamicWeight

def weight_fn(step: int) -> int: ...

dynamic_weight_callback = LambdaDynamicWeight(weight_fn, loss_fn)
```

- Or using a Python lambda function:

```python
dynamic_weight_callback = LambdaDynamicWeight(lambda e: ..., loss_fn)
```

- Add to the callbacks list and pass it to the `fit` function:

```python
from torchmanager import Manager

manager = Manager(..., loss_fn=loss_fn, ...)
callbacks_list = [..., dynamic_weight_callback]
manager.fit(..., callbacks_list=callbacks_list)
```
The callback to save the last checkpoint during training
- extends: `Callback`
- Properties:
    - ckpt_path: A `str` of checkpoint path
The callback to step the learning rate scheduler
- extends: `Callback`
- Parameters:
    - freq: An `_lr.LrScheduleFreq` of the frequency to update the learning rate
The enum of monitor types
The callback to save the best checkpoint for each epoch based on a monitored metric
- extends: `LastCheckpoint`
- Properties:
    - best_score: A `float` of the best score to be monitored
    - monitor: A `str` of the summary name to be monitored
    - monitor_type: A `MonitorType` of the monitor
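
A usage sketch showing how `MonitorType` ties in; the access path `torchmanager.callbacks.BestCheckpoint` and the constructor arguments below are assumptions inferred from the documented properties, not a confirmed signature.

```python
import torch
from torchmanager import callbacks

model = torch.nn.Linear(10, 2)  # any torch.nn.Module to be checkpointed

# Hypothetical constructor arguments: inferred from the documented
# properties (monitor, monitor_type), not a confirmed signature.
best_ckpt = callbacks.BestCheckpoint(
    "accuracy",                              # monitor: summary name to track
    model,                                   # object to checkpoint (assumption)
    "checkpoints/best.model",                # save path (assumption)
    monitor_type=callbacks.MonitorType.MAX,  # larger score is better
)
```
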
The callback to record summary to tensorboard for each epoch
- extends: `FrequencyCallback`
- Properties:
    - writer: A `tensorboard.SummaryWriter` to record scalars
- Methods:
    - `add_graph` - Add graph to TensorBoard
        - Parameters:
            - model: A `torch.nn.Module` to add
            - input_shape: An optional `tuple` of `int` for the inputs
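
A sketch of calling the documented `add_graph` method; the access path `torchmanager.callbacks.TensorBoard` and its constructor argument (a log directory) are assumptions, not confirmed by this page.

```python
import torch
from torchmanager import callbacks

model = torch.nn.Linear(10, 2)  # any torch.nn.Module

# Hypothetical constructor argument: a log directory for the SummaryWriter.
tensorboard_callback = callbacks.TensorBoard("data/logs")

# `add_graph` is documented above; `input_shape` describes the model inputs.
tensorboard_callback.add_graph(model, input_shape=(1, 10))
```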