Qisheng Robert He edited this page Jan 12, 2024
# torchmanager
- Contains callbacks for the `fit` method in the manager
- Contains organized configuration tools
- Contains wrapped loss functions
- Contains wrapped metric functions
## BaseManager

The basic manager.
- Properties:
  - compiled: A `bool` flag of whether the manager has been compiled
  - loss_fn: A `Callable` method that takes the truth and predictions in `torch.Tensor` and returns a loss `torch.Tensor`
  - metrics: A `dict` of metrics with a name in `str` and a `Callable` method that takes the truth and predictions in `torch.Tensor` and returns a metric `torch.Tensor`
  - model: A target `torch.nn.Module` to be trained
  - optimizer: A `torch.optim.Optimizer` to train the model
- Methods:
  - Constructor
    - Parameters:
      - loss_fn: An optional `Loss` object to calculate a single loss, or a `dict` of losses in `Loss` with their names in `str` to calculate multiple losses
      - metrics: An optional `dict` of metrics with a name in `str` and a `Metric` object to calculate the metric
      - model: An optional target `torch.nn.Module` to be trained
      - optimizer: An optional `torch.optim.Optimizer` to train the model
  - `compile`: Compiles the manager
    - Parameters:
      - loss_fn: A `Loss` object to calculate a single loss, or a `dict` of losses in `Loss` with their names in `str` to calculate multiple losses
      - metrics: A `dict` of metrics with a name in `str` and a `Metric` object to calculate the metric
      - optimizer: A `torch.optim.Optimizer` to train the model
  - `from_checkpoint` (classmethod): Loads a manager from a saved `Checkpoint`. The manager will not be compiled with a loss function and its metrics.
    - Returns: A loaded `Manager`
  - `to_checkpoint`: Converts the current manager to a checkpoint
    - Returns: A `Checkpoint` with its model in `Module` type
## DataManager

The manager to load data during training or testing.
- Methods:
  - `unpack_data`: Unpacks data into input and target
    - Parameters:
      - data: `Any` kind of data object
    - Returns: A `tuple` of `Any` kind of input and `Any` kind of target
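The unpack step above can be sketched without the library itself; the following stand-alone function mirrors the contract of `unpack_data` (the dict keys `"x"`/`"y"` are illustrative assumptions, not torchmanager's batch format):

```python
from typing import Any, Tuple


def unpack_data(data: Any) -> Tuple[Any, Any]:
    """Split a raw batch into (input, target), mirroring DataManager.unpack_data."""
    if isinstance(data, dict):
        # hypothetical dict-style batch: {"x": input, "y": target}
        return data["x"], data["y"]
    if isinstance(data, (tuple, list)) and len(data) == 2:
        # the common (input, target) pair
        return data[0], data[1]
    raise TypeError(f"Cannot unpack batch of type {type(data).__name__}")
```

Subclasses would override this hook when a dataset yields batches in a custom layout.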
## TestingManager

A testing manager, only used for testing.

- extends: `BaseManager`, `DataManager`
- Properties:
  - compiled_losses: The loss function in `Loss` that must exist
  - compiled_metrics: The `dict` of metrics in `Metric` that does not contain losses
- Methods:
  - `test`: Tests the target model
    - Parameters:
      - dataset: Either a `SizedIterable` or a `data.DataLoader` to load the dataset
      - device: An optional `torch.device` to test on
      - use_multi_gpus: A `bool` flag of whether to use multiple GPUs during testing
      - show_verbose: A `bool` flag of whether to show the progress bar during testing
    - Returns: A `dict` of validation summary
  - `test_step`: A single testing step
    - Parameters:
      - x_train: The testing data in `torch.Tensor`
      - y_train: The testing label in `torch.Tensor`
    - Returns: A `dict` of validation summary
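The returned validation summary is a `dict` mapping metric names to values averaged over the test batches; a framework-free sketch of that aggregation (the metric names and numbers below are illustrative):

```python
from typing import Dict, List


def summarize(batch_results: List[Dict[str, float]]) -> Dict[str, float]:
    """Average per-batch metric dicts into a single validation summary."""
    totals: Dict[str, float] = {}
    for result in batch_results:
        for name, value in result.items():
            totals[name] = totals.get(name, 0.0) + value
    # divide each accumulated total by the number of batches
    return {name: total / len(batch_results) for name, total in totals.items()}
```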
## Manager

A training manager.

- extends: `TestingManager`
- [Deprecation Warning]: Method `train` becomes protected from v1.0.2; the public method will be removed in v1.2.0. Override the `_train` method instead.
- Compile a model, optimizer, loss function, and metrics into the manager:
```python
import torch
from torchmanager import Manager, losses, metrics

class SomeModel(torch.nn.Module): ...

model = SomeModel()
optimizer = torch.optim.SGD(...)
loss_fn = losses.Loss(...)
metric_fns = {
    ...
}
manager = Manager(model, optimizer, loss_fn, metric_fns=metric_fns)
```
- Train using the `fit` method:

```python
from torch.utils.data import DataLoader, Dataset

dataset = Dataset(...)
dataset = DataLoader(dataset, ...)
epochs: int = ...
manager.fit(dataset, epochs, ...)
```
- Properties:
  - current_epoch: The `int` index of the current training epoch
  - compiled_optimizer: The `torch.optim.Optimizer` that must exist
- Methods:
  - `_train`: The single training step for an epoch
    - Parameters:
      - dataset: A `SizedIterable` training dataset
      - iterations: An optional `int` of total training iterations; must be smaller than the size of the dataset
      - device: A `torch.device` where the data is moved to; should be the same as the model's
      - use_multi_gpus: A `bool` flag of whether to use multiple GPUs
      - show_verbose: A `bool` flag of whether to show the progress bar
      - verbose_type: A `view.VerboseType` that controls the verbose display
      - callbacks_list: A `list` of callbacks in `Callback`
    - Returns: A summary `dict` with keys as `str` and values as `float`
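`_train` runs one epoch over the dataset, stopping early once the optional `iterations` cap is reached; a minimal dependency-free sketch of that bounded loop (the function and parameter names are illustrative, not torchmanager's internals):

```python
from typing import Callable, Dict, Iterable, Optional


def train_epoch(
    dataset: Iterable,
    step_fn: Callable[[object], Dict[str, float]],
    iterations: Optional[int] = None,
) -> Dict[str, float]:
    """Run step_fn over one epoch, capped at `iterations` batches,
    and return the last step's summary dict."""
    summary: Dict[str, float] = {}
    for i, batch in enumerate(dataset):
        if iterations is not None and i >= iterations:
            break  # honour the iteration cap, mirroring the `iterations` parameter
        summary = step_fn(batch)
    return summary
```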
  - `fit`: Training algorithm
    - Parameters:
      - training_dataset: Any kind of training dataset; must conform to `SizedIterable`
      - epochs: An optional `int` number of training epochs
      - iterations: An optional `int` number of training iterations
      - lr_scheduler: An optional `torch.optim.lr_scheduler._LRScheduler` to update the learning rate per epoch
      - is_dynamic_pruning: A `bool` flag of whether to use dynamic pruning
      - val_dataset: An optional validation dataset in `Any`
      - device: An optional `torch.device` where the data is moved to; a GPU will be used when available if not specified
      - use_multi_gpus: A `bool` flag of whether to use multiple GPUs
      - callbacks_list: A `list` of callbacks in `Callback`
      - **kwargs: Additional keyword arguments that will be passed to the `train` method
    - Returns: A trained `torch.nn.Module`
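Conceptually, `fit` repeats a train-then-validate cycle for `epochs` rounds; a simplified framework-free sketch of that structure (the callables and summary keys are illustrative, not torchmanager's API):

```python
from typing import Callable, Dict, List


def fit(
    train_one_epoch: Callable[[], Dict[str, float]],
    evaluate: Callable[[], Dict[str, float]],
    epochs: int,
) -> List[Dict[str, float]]:
    """Minimal epoch loop mirroring the structure of a fit-style training algorithm."""
    history = []
    for epoch in range(epochs):
        train_summary = train_one_epoch()   # one pass over the training data
        val_summary = evaluate()            # validation summary for this epoch
        history.append({"epoch": epoch, **train_summary, **val_summary})
    return history
```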
  - `train_step`: A single training step
    - Parameters:
      - x_train: The training data
      - y_train: The training label
    - Returns: A summary `dict` with keys as `str` and values as `float`
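A single training step follows the usual forward / loss / backward / update cycle; a dependency-free numeric sketch with a scalar linear model `y = w * x` (all names and the learning rate are illustrative, not torchmanager's implementation):

```python
from typing import Dict, Tuple


def train_step(w: float, x_train: float, y_train: float, lr: float = 0.1) -> Tuple[float, Dict[str, float]]:
    """One gradient-descent step on squared error, returning the updated
    weight and a summary dict with str keys and float values."""
    pred = w * x_train                       # forward pass
    loss = (pred - y_train) ** 2             # squared-error loss
    grad = 2.0 * (pred - y_train) * x_train  # d(loss)/dw
    w = w - lr * grad                        # parameter update
    return w, {"loss": float(loss)}
```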