Skip to content

Commit 991995b

Browse files
committed
support ssl
add semi- and weakly-supervised learning
1 parent 0218ebc commit 991995b

File tree

6 files changed

+35
-37
lines changed

6 files changed

+35
-37
lines changed

pymic/io/nifty_dataset.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,6 @@ def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2,
9494
super(ClassificationDataset, self).__init__(root_dir,
9595
csv_file, modal_num, with_label, transform)
9696
self.class_num = class_num
97-
print("class number for ClassificationDataset", self.class_num)
98-
print("self.transform", self.transform)
9997

10098
def __getlabel__(self, idx):
10199
csv_keys = list(self.csv_items.keys())
@@ -127,7 +125,6 @@ def __getitem__(self, idx):
127125
sample['label'] = self.__getlabel__(idx)
128126
if (self.image_weight_idx is not None):
129127
sample['image_weight'] = self.__getweight__(idx)
130-
print("***transform", self.transform)
131128
if self.transform:
132129
sample = self.transform(sample)
133130
return sample

pymic/net_run/agent_abstract.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self, config, stage = 'train'):
4242
self.scheduler = None
4343
self.loss_dict = None
4444
self.transform_dict = None
45+
self.inferer = None
4546
self.tensor_type = config['dataset']['tensor_type']
4647
self.task_type = config['dataset']['task_type'] #cls, cls_mtbc, seg
4748
self.deterministic = config['training'].get('deterministic', True)
@@ -69,6 +70,9 @@ def set_optimizer(self, optimizer):
6970

7071
def set_scheduler(self, scheduler):
7172
self.scheduler = scheduler
73+
74+
def set_inferer(self, inferer):
75+
self.inferer = inferer
7276

7377
def get_checkpoint_name(self):
7478
ckpt_mode = self.config['testing']['ckpt_mode']

pymic/net_run/agent_cls.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,20 @@
77
import sys
88
import time
99
import random
10-
import scipy
1110
import torch
12-
import torchvision
13-
from torchvision import datasets, models, transforms
11+
from torchvision import transforms
1412
import numpy as np
1513
import torch.nn as nn
1614
import torch.optim as optim
1715
import torch.nn.functional as F
18-
from torch.optim import lr_scheduler
1916
import matplotlib.pyplot as plt
2017
from PIL import Image
21-
from scipy import special
2218
from datetime import datetime
2319
from tensorboardX import SummaryWriter
24-
from pymic.io.image_read_write import save_nd_array_as_image
2520
from pymic.io.nifty_dataset import ClassificationDataset
2621
from pymic.loss.loss_dict_cls import PyMICClsLossDict
2722
from pymic.net.net_dict_cls import TorchClsNetDict
28-
from pymic.net_run.get_optimizer import get_optimiser
2923
from pymic.transform.trans_dict import TransformDict
30-
from pymic.util.image_process import convert_label
31-
from pymic.util.parse_config import parse_config
3224
from pymic.net_run.agent_abstract import NetRunAgent
3325
import warnings
3426
warnings.filterwarnings('ignore', '.*output shape of zoom.*')
@@ -207,6 +199,8 @@ def train_valid(self):
207199
self.net.to(self.device)
208200

209201
ckpt_dir = self.config['training']['ckpt_save_dir']
202+
if(ckpt_dir[-1] == "/"):
203+
ckpt_dir = ckpt_dir[:-1]
210204
ckpt_prefx = ckpt_dir.split('/')[-1]
211205
iter_start = self.config['training']['iter_start']
212206
iter_max = self.config['training']['iter_max']
@@ -292,10 +286,7 @@ def infer(self):
292286
with torch.no_grad():
293287
for data in self.test_loader:
294288
names = data['names']
295-
if(names[0] != "data3_process/20190711_1005487059.png"):
296-
continue
297289
inputs = self.convert_tensor_type(data['image'])
298-
print("intensity mean and std", inputs.mean(), inputs.std())
299290
inputs = inputs.to(device)
300291

301292
start_time = time.time()

pymic/net_run/agent_seg.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -221,21 +221,22 @@ def training(self):
221221

222222
def validation(self):
223223
class_num = self.config['network']['class_num']
224-
infer_cfg = self.config['testing']
225-
infer_cfg['class_num'] = class_num
224+
if(self.inferer is None):
225+
infer_cfg = self.config['testing']
226+
infer_cfg['class_num'] = class_num
227+
self.inferer = Inferer(infer_cfg)
226228

227229
valid_loss_list = []
228230
valid_dice_list = []
229231
validIter = iter(self.valid_loader)
230232
with torch.no_grad():
231233
self.net.eval()
232-
infer_obj = Inferer(self.net, infer_cfg)
233234
for data in validIter:
234235
inputs = self.convert_tensor_type(data['image'])
235236
labels_prob = self.convert_tensor_type(data['label_prob'])
236237
inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device)
237238
batch_n = inputs.shape[0]
238-
outputs = infer_obj.run(inputs)
239+
outputs = self.inferer.run(self.net, inputs)
239240

240241
# The tensors are on CPU when calculating loss for validation data
241242
loss = self.get_loss_value(data, outputs, labels_prob)
@@ -286,6 +287,8 @@ def train_valid(self):
286287
self.device = torch.device("cuda:{0:}".format(device_ids[0]))
287288
self.net.to(self.device)
288289
ckpt_dir = self.config['training']['ckpt_save_dir']
290+
if(ckpt_dir[-1] == "/"):
291+
ckpt_dir = ckpt_dir[:-1]
289292
ckpt_prefx = ckpt_dir.split('/')[-1]
290293
iter_start = self.config['training']['iter_start']
291294
iter_max = self.config['training']['iter_max']
@@ -392,9 +395,10 @@ def test_time_dropout(m):
392395
checkpoint = torch.load(ckpt_name, map_location = device)
393396
self.net.load_state_dict(checkpoint['model_state_dict'])
394397

395-
infer_cfg = self.config['testing']
396-
infer_cfg['class_num'] = self.config['network']['class_num']
397-
infer_obj = Inferer(self.net, infer_cfg)
398+
if(self.inferer is None):
399+
infer_cfg = self.config['testing']
400+
infer_cfg['class_num'] = self.config['network']['class_num']
401+
self.inferer = Inferer(infer_cfg)
398402
infer_time_list = []
399403
with torch.no_grad():
400404
for data in self.test_loader:
@@ -412,7 +416,7 @@ def test_time_dropout(m):
412416
# continue
413417
start_time = time.time()
414418

415-
pred = infer_obj.run(images)
419+
pred = self.inferer.run(self.net, images)
416420
# convert tensor to numpy
417421
if(isinstance(pred, (tuple, list))):
418422
pred = [item.cpu().numpy() for item in pred]
@@ -438,10 +442,11 @@ def infer_with_multiple_checkpoints(self):
438442
device_ids = self.config['testing']['gpus']
439443
device = torch.device("cuda:{0:}".format(device_ids[0]))
440444

445+
if(self.inferer is None):
446+
infer_cfg = self.config['testing']
447+
infer_cfg['class_num'] = self.config['network']['class_num']
448+
self.inferer = Inferer(infer_cfg)
441449
ckpt_names = self.config['testing']['ckpt_name']
442-
infer_cfg = self.config['testing']
443-
infer_cfg['class_num'] = self.config['network']['class_num']
444-
infer_obj = Inferer(self.net, infer_cfg)
445450
infer_time_list = []
446451
with torch.no_grad():
447452
for data in self.test_loader:
@@ -463,7 +468,7 @@ def infer_with_multiple_checkpoints(self):
463468
checkpoint = torch.load(ckpt_name, map_location = device)
464469
self.net.load_state_dict(checkpoint['model_state_dict'])
465470

466-
pred = infer_obj.run(images)
471+
pred = self.inferer.run(self.net, images)
467472
# convert tensor to numpy
468473
if(isinstance(pred, (tuple, list))):
469474
pred = [item.cpu().numpy() for item in pred]

pymic/net_run/infer_func.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
from torch.nn.functional import interpolate
1010

1111
class Inferer(object):
12-
def __init__(self, model, config):
13-
self.model = model
14-
self.config = config
12+
def __init__(self, config):
13+
self.config = config
1514

1615
def __infer(self, image):
1716
use_sw = self.config.get('sliding_window_enable', False)
@@ -131,8 +130,9 @@ def __infer_with_sliding_window(self, image):
131130
output_list[i] = output_list[i] / counter_i
132131
return output_list
133132

134-
def run(self, image):
135-
tta_mode = self.config.get('tta_mode', 0)
133+
def run(self, model, image):
134+
self.model = model
135+
tta_mode = self.config.get('tta_mode', 0)
136136
if(tta_mode == 0):
137137
outputs = self.__infer(image)
138138
elif(tta_mode == 1): # test time augmentation with flip in 2D

pymic/net_run_noise/cl.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,11 @@ def test_time_dropout(m):
7777
checkpoint = torch.load(ckpt_name, map_location = device)
7878
self.net.load_state_dict(checkpoint['model_state_dict'])
7979

80-
infer_cfg = self.config['testing']
81-
class_num = self.config['network']['class_num']
82-
infer_cfg['class_num'] = class_num
83-
infer_obj = Inferer(self.net, infer_cfg)
80+
if(self.inferer is None):
81+
infer_cfg = self.config['testing']
82+
class_num = self.config['network']['class_num']
83+
infer_cfg['class_num'] = class_num
84+
self.inferer = Inferer(infer_cfg)
8485
pred_list = []
8586
gt_list = []
8687
filename_list = []
@@ -102,7 +103,7 @@ def test_time_dropout(m):
102103
# save_nd_array_as_image(label_i, label_name, reference_name = None)
103104
# continue
104105

105-
pred = infer_obj.run(images)
106+
pred = self.inferer.run(self.net, images)
106107
# convert tensor to numpy
107108
if(isinstance(pred, (tuple, list))):
108109
pred = [item.cpu().numpy() for item in pred]

0 commit comments

Comments
 (0)