Skip to content

Commit 1b4152d

Browse files
committed
update train and test files
1 parent 50abf3b commit 1b4152d

File tree

12 files changed

+31
-2
lines changed

12 files changed

+31
-2
lines changed

pymic/io/image_read_write.py

100644100755
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def save_nd_array_as_image(data, image_name, reference_name = None):
112112

113113
elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or
114114
image_name.endswith(".tif") or image_name.endswith(".png")):
115+
assert(data_dim == 2)
115116
save_array_as_rgb_image(data, image_name)
116117
else:
117118
raise ValueError("unsupported image format {0:}".format(

pymic/io/nifty_dataset.py

100644100755
File mode changed.

pymic/io/transform3d.py

100644100755
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,13 @@ class RandomCrop(object):
363363

364364
def __init__(self, output_size, fg_focus = False, fg_ratio = 0.0, mask_label = None, inverse = True):
365365
assert isinstance(output_size, (list, tuple))
366+
assert isinstance(mask_label, (list, tuple))
366367
self.output_size = output_size
367368
self.inverse = inverse
368369
self.fg_focus = fg_focus
369370
self.fg_ratio = fg_ratio
370371
self.mask_label = mask_label
372+
371373

372374
def __call__(self, sample):
373375
image = sample['image']

pymic/net2d/unet2d.py

100644100755
File mode changed.

pymic/net3d/unet2d5.py

100644100755
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,13 @@ def forward(self, x):
9999
params = {'in_chns':4,
100100
'feature_chns':[2, 8, 32, 48, 64],
101101
'class_num': 2,
102+
'dropout': True,
102103
'acti_func': 'leakyrelu',
103104
'leakyrelu_alpha': 0.01}
104105
Net = UNet2D5(params)
105106
Net = Net.double()
106107

107-
x = np.random.rand(4, 4, 32, 96, 96)
108+
x = np.random.rand(1, 4, 16, 128, 128)
108109
xt = torch.from_numpy(x)
109110
xt = torch.tensor(xt)
110111

pymic/net3d/unet3d.py

100644100755
File mode changed.

pymic/train_infer/train_infer.py

100644100755
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch.optim as optim
1313
import torch.nn.functional as F
1414

15+
from scipy import special
1516
from datetime import datetime
1617
from tensorboardX import SummaryWriter
1718
from pymic.io.image_read_write import save_nd_array_as_image
@@ -221,7 +222,7 @@ def __infer(self):
221222
device = torch.device(self.config['testing']['device_name'])
222223
self.net.to(device)
223224
# laod network parameters and set the network as evaluation mode
224-
self.checkpoint = torch.load(self.config['testing']['checkpoint_name'])
225+
self.checkpoint = torch.load(self.config['testing']['checkpoint_name'], map_location = device)
225226
self.net.load_state_dict(self.checkpoint['model_state_dict'])
226227

227228
if(self.config['testing']['evaluation_mode'] == True):
@@ -294,6 +295,8 @@ def test_time_dropout(m):
294295
for c in range(0, class_num):
295296
temp_prob = prob[c]
296297
prob_save_name = "{0:}_prob_{1:}.{2:}".format(save_prefix, c, save_format)
298+
if(len(temp_prob.shape) == 2):
299+
temp_prob = np.asarray(temp_prob * 255, np.uint8)
297300
save_nd_array_as_image(temp_prob, prob_save_name, root_dir + '/' + names[0])
298301

299302
avg_time = (time.time() - start_time) / len(self.test_loder)

pymic/util/__init__.py

Whitespace-only changes.

pymic/util/average_model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import torch
2+
3+
checkpoint_name1 = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_8000.pt"
4+
checkpoint1 = torch.load(checkpoint_name1)
5+
state_dict1 = checkpoint1['model_state_dict']
6+
7+
checkpoint_name2 = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_10000.pt"
8+
checkpoint2 = torch.load(checkpoint_name2)
9+
state_dict2 = checkpoint2['model_state_dict']
10+
11+
checkpoint_name3 = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_12000.pt"
12+
checkpoint3 = torch.load(checkpoint_name3)
13+
state_dict3 = checkpoint3['model_state_dict']
14+
15+
state_dict = {}
16+
for item in state_dict1:
17+
print(item)
18+
state_dict[item] = (state_dict1[item] + state_dict2[item] + state_dict3[item])/3
19+
20+
save_dict = {'model_state_dict': state_dict}
21+
save_name = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_avg.pt"
22+
torch.save(save_dict, save_name)

pymic/util/evaluation.py

100644100755
File mode changed.

0 commit comments

Comments
 (0)