
Commit 143476c

update config for classification
1 parent 0e034a7 commit 143476c

8 files changed: +62 −31 lines changed

classification/AntBee/README.md

Lines changed: 28 additions & 10 deletions
@@ -6,18 +6,27 @@
In this example, we finetune a pretrained resnet18 for classification of images with two categories: ant and bee. This example is a PyMIC implementation of PyTorch's "Transfer Learning for Computer Vision" tutorial. The original tutorial can be found [here][torch_tutorial]. In PyMIC's implementation, we only need to edit the configuration file to run the code.

## Data and preprocessing
-1. The dataset contains about 120 training images each for ants and bees. There are 75 validation images for each class. Download the data from [here][data_link] and extract it.
-2. Set `AntBee_root` according to your computer in `write_csv_files.py`, where `AntBee_root` should be the path of `hymenoptera_data` based on the dataset you extracted.
-3. Run `python write_csv_files.py` to create two csv files storing the paths and labels of training and validation images. They are `train_data.csv` and `valid_data.csv` and saved in `./config`.
+1. The dataset contains about 120 training images each for ants and bees. There are 75 validation images for each class. Download the data from [here][data_link] and extract it to `PyMIC_data`. The paths of the training and validation sets should then be `PyMIC_data/hymenoptera_data/train` and `PyMIC_data/hymenoptera_data/val`, respectively.
+2. Run `python write_csv_files.py` to create two CSV files storing the paths and labels of the training and validation images. They are `train_data.csv` and `valid_data.csv`, saved in `./config`.

[torch_tutorial]:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
[data_link]:https://download.pytorch.org/tutorial/hymenoptera_data.zip

## Finetuning all layers of resnet18
-1. Here we use resnet18 for finetuning, and update all the layers. Open the configure file `config/train_test_ce1.cfg`. In the `network` section we can find details for the network. In the `dataset` section, set the value of `root_dir` as your `AntBee_root`. Then start to train by running:
+1. Here we use resnet18 for finetuning and update all the layers. Open the configuration file `config/train_test_ce1.cfg`. In the `network` section we can find the details of the network; `update_layers = 0` means updating all the layers.
+```bash
+# type of network
+net_type = resnet18
+pretrain = True
+input_chns = 3
+# finetune all the layers
+update_layers = 0
+```
+
+Then start to train by running:

```bash
-pymic_net_run train config/train_test_ce1.cfg
+pymic_run train config/train_test_ce1.cfg
```

2. During training or after training, run `tensorboard --logdir model/resnet18_ce1` and you will see a link in the output, such as `http://your-computer:6006`. Open the link in the browser and you can observe the average loss and accuracy during the training stage, as shown in the following images, where the blue and red curves are for the training set and the validation set, respectively. The iteration number that obtained the highest accuracy on the validation set was 400; it may differ depending on the hardware environment. After training, you can find the trained models in `./model/resnet18_ce1`.
@@ -26,24 +35,33 @@ pymic_net_run train config/train_test_ce1.cfg
![avg_acc](./picture/acc.png)

## Testing and evaluation
-1. Run the following command to obtain classification results of testing images. By default we use the best performing checkpoint based on the validation set. You can set `ckpt_mode` to 0 in `config/train_test.cfg` to use the latest checkpoint.
+1. Run the following command to obtain classification results of the testing images. By default, we use the best-performing checkpoint based on the validation set. You can set `ckpt_mode` to 0 in `config/train_test_ce1.cfg` to use the latest checkpoint.

```bash
mkdir result
-pymic_net_run test config/train_test_ce1.cfg
+pymic_run test config/train_test_ce1.cfg
```

2. Then run the following command to obtain quantitative evaluation results in terms of accuracy.

```bash
-pymic_evaluate_cls config/evaluation.cfg
+pymic_eval_cls config/evaluation.cfg
```

-The obtained accuracy by default setting should be around 0.9412, and the AUC will be 0.973.
+The obtained accuracy with the default setting should be around 0.9412, and the AUC will be around 0.976.

3. Run `python show_roc.py` to show the receiver operating characteristic curve.

![roc](./picture/roc.png)

## Finetuning the last layer of resnet18
-Similarly to the above example, we further try to only finetune the last layer of resnet18 for the same classification task. Use a different configure file `config/train_test_ce2.cfg` for training and testing, where `update_layers = -1` in the `network` section means updating the last layer only. Edit `config/evaluation.cfg` accordinly for evaluation. The iteration number obtained the highest accuracy on the validation set was 400 in our testing machine, and the accuracy was around 0.9543. The AUC was 0.981.
+Similar to the above example, we further try to finetune only the last layer of resnet18 for the same classification task. Use a different configuration file, `config/train_test_ce2.cfg`, for training and testing, where `update_layers = -1` in the `network` section means updating the last layer only:
+```bash
+net_type = resnet18
+pretrain = True
+input_chns = 3
+# finetune the last layer only
+update_layers = -1
+```
+
+Edit `config/evaluation.cfg` accordingly for evaluation.
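
Taken together, the updated AntBee example can be reproduced with the commands below. This is a minimal sketch based on the README steps above, assuming it is run from `classification/AntBee` and that the dataset has been extracted so that `../../PyMIC_data/hymenoptera_data` exists, as in the updated configuration files.

```bash
# create the csv files listing image paths and labels
python write_csv_files.py

# finetune all layers of resnet18 (update_layers = 0)
pymic_run train config/train_test_ce1.cfg

# optionally monitor the loss and accuracy curves
tensorboard --logdir model/resnet18_ce1

# classify the testing images with the best validation checkpoint
mkdir result
pymic_run test config/train_test_ce1.cfg

# accuracy / AUC and the ROC curve
pymic_eval_cls config/evaluation.cfg
python show_roc.py
```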

classification/AntBee/config/train_test_ce1.cfg

Lines changed: 4 additions & 3 deletions
@@ -3,7 +3,7 @@
tensor_type = float

task_type = cls
-root_dir = /home/guotai/disk2t/projects/torch_project/transfer_learning/hymenoptera_data
+root_dir = ../../PyMIC_data/hymenoptera_data
train_csv = config/train_data.csv
valid_csv = config/valid_data.csv
test_csv = config/valid_data.csv
@@ -57,17 +57,18 @@ momentum = 0.9
weight_decay = 1e-5

# for lr scheduler (MultiStepLR)
+lr_scheduler = MultiStepLR
lr_gamma = 0.1
lr_milestones = [500, 1000]

ckpt_save_dir = model/resnet18_ce1
-ckpt_save_prefix = resnet18
+ckpt_prefix = resnet18

# iteration
iter_start = 0
iter_max = 1500
iter_valid = 100
-iter_save = 500
+iter_save = 1500

[testing]
# list of gpus
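
For readability, the scheduler and checkpoint options of `config/train_test_ce1.cfg` after this patch read roughly as follows (only the lines visible in the diff are shown; the rest of the file is unchanged):

```bash
# lr scheduler (MultiStepLR)
lr_scheduler = MultiStepLR
lr_gamma = 0.1
lr_milestones = [500, 1000]

ckpt_save_dir = model/resnet18_ce1
ckpt_prefix = resnet18

# iteration
iter_start = 0
iter_max = 1500
iter_valid = 100
iter_save = 1500
```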

classification/AntBee/config/train_test_ce2.cfg

Lines changed: 4 additions & 3 deletions
@@ -3,7 +3,7 @@
tensor_type = float

task_type = cls
-root_dir = /home/guotai/disk2t/projects/torch_project/transfer_learning/hymenoptera_data
+root_dir = ../../PyMIC_data/hymenoptera_data
train_csv = config/train_data.csv
valid_csv = config/valid_data.csv
test_csv = config/valid_data.csv
@@ -58,17 +58,18 @@ momentum = 0.9
weight_decay = 1e-5

# for lr scheduler (MultiStepLR)
+lr_scheduler = MultiStepLR
lr_gamma = 0.1
lr_milestones = [500, 1000]

ckpt_save_dir = model/resnet18_ce2
-ckpt_save_prefix = resnet18
+ckpt_prefix = resnet18

# iteration
iter_start = 0
iter_max = 1500
iter_valid = 100
-iter_save = 500
+iter_save = 1500

[testing]
# list of gpus

classification/AntBee/write_csv_files.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def get_evaluation_image_pairs(test_csv, gt_seg_csv):

if __name__ == "__main__":
    # create csv files for the AntBee dataset
-    AntBee_root = '/home/guotai/disk2t/projects/torch_project/transfer_learning/hymenoptera_data'
+    AntBee_root = '../../PyMIC_data/hymenoptera_data'
    create_csv_file(AntBee_root, 'train', 'config/train_data.csv')
    create_csv_file(AntBee_root, 'val', 'config/valid_data.csv')
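
With the hard-coded path replaced by a relative one, the CSV files can be regenerated without editing the script. A minimal sketch, assuming the class sub-folders follow the layout of the PyTorch tutorial dataset:

```bash
# expected layout after extracting the data (see the AntBee README above)
ls ../../PyMIC_data/hymenoptera_data/train   # ants/  bees/
ls ../../PyMIC_data/hymenoptera_data/val     # ants/  bees/

# regenerate the csv files; outputs go to ./config
python write_csv_files.py
ls config/train_data.csv config/valid_data.csv
```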

classification/CHNCXR/README.md

Lines changed: 17 additions & 8 deletions
@@ -6,18 +6,27 @@
In this example, we finetune a pretrained resnet18 and vgg16 for classification of X-ray images with two categories: normal and tuberculosis.

## Data and preprocessing
-1. We use the Shenzhen Hospital X-ray Set for this experiment. This dataset contains images in JPEG format. There are 326 normal x-rays and 336 abnormal x-rays showing various manifestations of tuberculosis. Download the dataset from [here][data_link] and extract it, and the folder name will be "ChinaSet_AllFiles/CXR_png".
+1. We use the Shenzhen Hospital X-ray Set for this experiment. This [dataset][data_link] contains images in JPEG format. There are 326 normal x-rays and 336 abnormal x-rays showing various manifestations of tuberculosis. The images are available in `PyMIC_data/CHNCXR`.

[data_link]:https://lhncbc.nlm.nih.gov/publication/pub9931

-2. Set `image_dir` according to your computer in `write_csv_files.py`, where `image_dir` should be the path of "CXR_png" based on the dataset you extracted.
-3. Run `python write_csv_files.py` to randomly split the entire dataset into 70% for training, 10% for validation and 20% for testing. The output files are `cxr_train.csv`, `cxr_valid.csv` and `cxr_test.csv` under folder `./config`.
+2. Run `python write_csv_files.py` to randomly split the entire dataset into 70% for training, 10% for validation and 20% for testing. The output files are `cxr_train.csv`, `cxr_valid.csv` and `cxr_test.csv` under the folder `./config`.

## Finetuning resnet18
-1. First, we use resnet18 for finetuning, and update all the layers. Open the configure file `config/net_resnet18.cfg`. In the `dataset` section, set the value of `root_dir` as your path of "CXR_png". Then start to train by running:
+1. First, we use resnet18 for finetuning and update all the layers. The configuration file is `config/net_resnet18.cfg`. The network settings are:
+
+```bash
+net_type = resnet18
+pretrain = True
+input_chns = 3
+# finetune all the layers
+update_layers = 0
+```
+
+Start to train by running:

```bash
-pymic_net_run train config/net_resnet18.cfg
+pymic_run train config/net_resnet18.cfg
```

2. During training or after training, run `tensorboard --logdir model/resnet18` and you will see a link in the output, such as `http://your-computer:6006`. Open the link in the browser and you can observe the average loss and accuracy during the training stage, as shown in the following images, where the blue and red curves are for the training set and the validation set, respectively. The iteration number that obtained the highest accuracy on the validation set was 1800; it may differ depending on the hardware environment. After training, you can find the trained models in `./model/resnet18`.
@@ -30,13 +39,13 @@ pymic_net_run train config/net_resnet18.cfg

```bash
mkdir result
-pymic_net_run test config/net_resnet18.cfg
+pymic_run test config/net_resnet18.cfg
```

2. Then run the following command to obtain quantitative evaluation results in terms of accuracy.

```bash
-pymic_evaluate_cls config/evaluation.cfg
+pymic_eval_cls config/evaluation.cfg
```

The obtained accuracy with the default setting should be around 0.8571, and the AUC is 0.94.
@@ -47,4 +56,4 @@ The obtained accuracy by default setting should be around 0.8571, and the AUC is


## Finetuning vgg16
-Similarly to the above example, we further try to finetune vgg16 for the same classification task. Use a different configure file `config/net_vg16.cfg` for training and testing. Edit `config/evaluation.cfg` accordinly for evaluation. The iteration number for the highest accuracy on the validation set was 2300, and the accuracy will be around 0.8797.
+Similar to the above example, we further try to finetune vgg16 for the same classification task. Use a different configuration file, `config/net_vgg16.cfg`, for training and testing. Edit `config/evaluation.cfg` accordingly for evaluation. The iteration number with the highest accuracy on the validation set was 2300, and the accuracy will be around 0.8797.
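
For reference, a minimal sketch of the CHNCXR workflow with the renamed commands, assuming it is run from `classification/CHNCXR` and the images are available under `../../PyMIC_data/CHNCXR/CXR_png` as set in the updated configuration files:

```bash
# randomly split the data into training/validation/testing csv files
python write_csv_files.py

# finetune resnet18 (all layers), then test and evaluate it
pymic_run train config/net_resnet18.cfg
mkdir result
pymic_run test config/net_resnet18.cfg
pymic_eval_cls config/evaluation.cfg

# repeat with vgg16 (edit config/evaluation.cfg accordingly before evaluating)
pymic_run train config/net_vgg16.cfg
pymic_run test config/net_vgg16.cfg
pymic_eval_cls config/evaluation.cfg
```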

classification/CHNCXR/config/net_resnet18.cfg

Lines changed: 3 additions & 2 deletions
@@ -3,7 +3,7 @@
tensor_type = float

task_type = cls
-root_dir = /home/guotai/disk2t/data/lung/ChinaSet_AllFiles/CXR_png
+root_dir = ../../PyMIC_data/CHNCXR/CXR_png
train_csv = config/cxr_train.csv
valid_csv = config/cxr_valid.csv
test_csv = config/cxr_test.csv
@@ -55,11 +55,12 @@ momentum = 0.9
weight_decay = 1e-5

# for lr scheduler (MultiStepLR)
+lr_scheduler = MultiStepLR
lr_gamma = 0.1
lr_milestones = [1500, 3000]

ckpt_save_dir = model/resnet18
-ckpt_save_prefix = resnet18
+ckpt_prefix = resnet18

# iteration
iter_start = 0
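
After this change, the data-related options of `config/net_resnet18.cfg` read as below (taken from the lines shown in the diff; the enclosing `[dataset]` section header is assumed from the README's description):

```bash
[dataset]
task_type = cls
root_dir = ../../PyMIC_data/CHNCXR/CXR_png
train_csv = config/cxr_train.csv
valid_csv = config/cxr_valid.csv
test_csv = config/cxr_test.csv
```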

classification/CHNCXR/config/net_vgg16.cfg

Lines changed: 4 additions & 3 deletions
@@ -3,7 +3,7 @@
tensor_type = float

task_type = cls
-root_dir = /home/guotai/disk2t/data/lung/ChinaSet_AllFiles/CXR_png
+root_dir = ../../PyMIC_data/CHNCXR/CXR_png
train_csv = config/cxr_train.csv
valid_csv = config/cxr_valid.csv
test_csv = config/cxr_test.csv
@@ -55,11 +55,12 @@ momentum = 0.9
weight_decay = 1e-5

# for lr scheduler (MultiStepLR)
+lr_scheduler = MultiStepLR
lr_gamma = 0.1
lr_milestones = [1500, 3000]

-ckpt_save_dir = model/vgg16
-ckpt_save_prefix = vgg16
+ckpt_save_dir = model/vgg16
+ckpt_prefix = vgg16

# iteration
iter_start = 0

classification/CHNCXR/write_csv_files.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def random_split_dataset():

if __name__ == "__main__":
    # create csv files for the CHNCXR dataset
-    image_dir = '/home/guotai/disk2t/data/lung/ChinaSet_AllFiles/CXR_png'
+    image_dir = '../../PyMIC_data/CHNCXR/CXR_png'
    output_csv = 'config/cxr_all.csv'
    create_csv_file(image_dir, output_csv)
