DLTrainer is a PyTorch-based deep learning trainer designed to help beginners get started quickly.

Clone the repo:

    git clone https://github.com/xigua7105/DLTrainer.git
    cd DLTrainer

Environment setup: DLTrainer works with Python 3.8+ and PyTorch 2.0+.

    conda create -n DLTrainer python=3.8
    conda activate DLTrainer
    pip install -r requirements.txt

For example, to train ResNet-50 on the CIFAR-100 dataset:

    torchrun --nproc_per_node=1 --nnodes=1 --standalone train.py --c configs/resnet50-cifar100.yaml
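
torchrun starts `--nproc_per_node` processes on each node and exports the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables to every one of them. The sketch below shows how a launched script could read these values. It is not taken from DLTrainer's `train.py`, and `DistInfo` and `read_dist_info` are hypothetical names used only for illustration.

```python
import os
from typing import Mapping, NamedTuple


class DistInfo(NamedTuple):
    rank: int        # global index of this process
    local_rank: int  # index of this process on its own node
    world_size: int  # total number of processes in the job


def read_dist_info(env: Mapping[str, str] = os.environ) -> DistInfo:
    """Read the variables torchrun exports; fall back to a single process."""
    return DistInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


if __name__ == "__main__":
    print(read_dist_info())
```

With `--nproc_per_node=1 --nnodes=1`, as in the command above, there is a single process, so the world size is 1.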

A config file consists of six main parts: model, data, optim, loss, trainer, and logger. The parameters of each part are described below.

model:
- name (optional): A custom name for the network, such as convnet-lite, convnet-base, or convnet-large.
- task (optional): A brief description of the task.
- struct (required): The parameters that define the model architecture.

data:
- dir (required): The path to the dataset.
- dataset_type (required): The dataset type, such as IRDataset (image restoration dataset) or ImageFolderDefault (image classification dataset). You can define a dataset in dataset and register it in _register.
- is_multi_loader (required): Whether there are multiple training sets or test sets.
- train_transforms:
  - name (required): The name of a defined transform function. You can define and register it in transforms.
  - kwargs (optional): Keyword arguments passed to the transform.
- test_transforms (required):
- train_target_transforms (required):
- test_target_transforms (required):
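
The "define it, then register it" workflow mentioned for datasets and transforms is commonly built on a name-to-class registry, so that a string in the config such as `dataset_type` can select an implementation. The following is a minimal sketch of that pattern, not DLTrainer's actual `_register` code. `Registry`, `DATASETS`, and `ToyDataset` are hypothetical names.

```python
from typing import Callable, Dict


class Registry:
    """Maps a string name (as written in a config) to a class or function."""

    def __init__(self, kind: str) -> None:
        self.kind = kind
        self._items: Dict[str, Callable] = {}

    def register(self, obj: Callable) -> Callable:
        # Used as a decorator; the object's own name becomes the lookup key.
        name = obj.__name__
        if name in self._items:
            raise KeyError(f"{self.kind} '{name}' is already registered")
        self._items[name] = obj
        return obj

    def build(self, name: str, **kwargs):
        # Look up the registered object and call it with the given arguments.
        if name not in self._items:
            raise KeyError(f"unknown {self.kind} '{name}'; known: {sorted(self._items)}")
        return self._items[name](**kwargs)


DATASETS = Registry("dataset")  # hypothetical registry instance


@DATASETS.register
class ToyDataset:  # hypothetical stand-in for a real dataset class
    def __init__(self, dir: str) -> None:
        self.dir = dir


# The name would normally come from the config's dataset_type field.
dataset = DATASETS.build("ToyDataset", dir="data/toy")
```

A registry like this keeps the training script free of hard-coded class names: adding a new dataset only requires defining and registering it.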

optim:
- optimizer:
  - name (required): The optimizer, such as Adam or SGD. You can register optimizers in _register.
  - lr (required): The learning rate.
- scheduler:
  - name (required): The learning-rate scheduler, such as Cosine or MultiStepLR.

loss:
- loss_terms:
  - name (required): The loss function, such as CrossEntropyLoss or IRLoss. You can define and register it in loss.

trainer:
- name (required): The trainer type, such as CLSTrainer or IRTrainer. You can define a trainer in trainer and register it in _register.
- ckpt_dir (required): The folder where checkpoints are saved.
- batch_size (required): The global training batch size.
- batch_size_test (required): The global test batch size.
- num_workers_per_gpu (required): DataLoader parameter: the number of data-loading worker processes per GPU.
- drop_last (required): DataLoader parameter: whether to drop the last incomplete batch.
- pin_memory (required): DataLoader parameter: whether to load batches into pinned (page-locked) host memory.
- save_freq (required): How often a checkpoint is saved.
- amp (required): Whether to enable automatic mixed precision (AMP) to speed up training.
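
Because `batch_size` and `batch_size_test` are global values, each process in a multi-GPU run handles only a share of every batch. The README does not say how DLTrainer divides the batch, so the sketch below simply assumes an even split across processes. `per_process_batch_size` is a hypothetical helper, not a function from the project.

```python
def per_process_batch_size(global_batch_size: int, world_size: int) -> int:
    """Split a global batch size evenly across `world_size` processes."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    if global_batch_size % world_size != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible "
            f"by world size {world_size}"
        )
    return global_batch_size // world_size


# A global batch of 256 on 4 GPUs gives each process 64 samples per step.
per_gpu = per_process_batch_size(256, 4)
```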

logger:
- dir: The folder where logs are saved.
- log_freq: How often the log is written.
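
To show how the six parts fit together, here is a config written as a Python dictionary (the form a YAML file takes once loaded), plus a small check for the keys marked "(required)" above. Every value, including the contents of `struct` and the transform entries, is illustrative and not copied from `configs/resnet50-cifar100.yaml`. `REQUIRED` and `missing_keys` are hypothetical helpers that only check top-level keys in each section.

```python
from typing import Any, Dict, List

# Keys the parameter list marks as "(required)", grouped by section.
# Nested keys (for example optimizer.name) are left out of this simple check.
REQUIRED: Dict[str, List[str]] = {
    "model": ["struct"],
    "data": ["dir", "dataset_type", "is_multi_loader", "test_transforms",
             "train_target_transforms", "test_target_transforms"],
    "optim": [],
    "loss": [],
    "trainer": ["name", "ckpt_dir", "batch_size", "batch_size_test",
                "num_workers_per_gpu", "drop_last", "pin_memory",
                "save_freq", "amp"],
    "logger": [],
}


def missing_keys(cfg: Dict[str, Dict[str, Any]]) -> List[str]:
    """Return missing sections and missing required keys as dotted names."""
    missing: List[str] = []
    for section, keys in REQUIRED.items():
        if section not in cfg:
            missing.append(section)
            continue
        missing.extend(f"{section}.{k}" for k in keys if k not in cfg[section])
    return missing


# Illustrative values only; the real schema of each field may differ.
config: Dict[str, Dict[str, Any]] = {
    "model": {"name": "resnet50", "task": "image classification",
              "struct": {"num_classes": 100}},
    "data": {
        "dir": "data/cifar100",
        "dataset_type": "ImageFolderDefault",
        "is_multi_loader": False,
        "train_transforms": [{"name": "ToTensor", "kwargs": {}}],
        "test_transforms": [{"name": "ToTensor", "kwargs": {}}],
        "train_target_transforms": [],
        "test_target_transforms": [],
    },
    "optim": {"optimizer": {"name": "SGD", "lr": 0.1},
              "scheduler": {"name": "Cosine"}},
    "loss": {"loss_terms": [{"name": "CrossEntropyLoss"}]},
    "trainer": {"name": "CLSTrainer", "ckpt_dir": "checkpoints",
                "batch_size": 128, "batch_size_test": 256,
                "num_workers_per_gpu": 4, "drop_last": True,
                "pin_memory": True, "save_freq": 10, "amp": True},
    "logger": {"dir": "logs", "log_freq": 50},
}

problems = missing_keys(config)
```

Running a check like this before training starts surfaces a missing field immediately, instead of partway through a run.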
This project is released under the MIT license. Please see the LICENSE file for more information.