Skip to content

Commit 0d03c25

Browse files
committed
fix output number
1 parent ea24a5d commit 0d03c25

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

pymic/net_run/infer_func.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def volume_infer(image, net, device, class_num,
2121
outputs = net(image)
2222
if(isinstance(outputs, tuple) or isinstance(outputs, list)):
2323
outputs = [item.cpu().numpy() for item in outputs]
24+
outputs = outputs[:output_num]
2425
else:
2526
outputs = outputs.cpu().numpy()
2627
else:
@@ -101,6 +102,8 @@ def volume_infer_by_patch(image, net, device, class_num,
101102
if(not(isinstance(out_mini_batch, tuple) or isinstance(out_mini_batch, list))):
102103
out_mini_batch = [out_mini_batch]
103104
out_mini_batch = [item.cpu().numpy() for item in out_mini_batch]
105+
106+
# use a mask to store overlapping regions
104107
mask_mini_batch = np.ones_like(out_mini_batch[0])
105108
for batch_idx in range(batch_start_idx, batch_end_idx):
106109
crop_start = sub_image_starts[batch_idx]

pymic/net_run/net_run_agent.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,11 @@ def test_time_dropout(m):
370370

371371
prob_list = [scipy.special.softmax(predict[0], axis = 0) for predict in predict_list]
372372
if(multi_pred_avg):
373+
if(output_num == 1):
374+
raise ValueError("multiple predictions expected, but output_num was set to 1")
375+
if(output_num != len(prob_list)):
376+
raise ValueError("expected output_num was set to {0:}, but {1:} outputs obtained".format(
377+
output_dir, len(prob_list)))
373378
prob_stack = np.asarray(prob_list, np.float32)
374379
prob = np.mean(prob_stack, axis = 0)
375380
var = np.var(prob_stack, axis = 0)

0 commit comments

Comments
 (0)