1+ # -*- coding: utf-8 -*-
2+ from __future__ import print_function , division
3+
4+ import torch
5+ import json
6+ import math
7+ import random
8+ import numpy as np
9+ from scipy import ndimage
10+ from pymic .transform .abstract_transform import AbstractTransform
11+ from pymic .util .image_process import *
12+
13+
14+ class CropWithBoundingBox (AbstractTransform ):
15+ """Crop the image (shape [C, D, H, W] or [C, H, W]) based on bounding box
16+ """
17+ def __init__ (self , params ):
18+ """
19+ start (None or tuple/list): The start index along each spatial axis.
20+ if None, calculate the start index automatically so that
21+ the cropped region is centered at the non-zero region.
22+ output_size (None or tuple/list): Desired spatial output size.
23+ if None, set it as the size of bounding box of non-zero region
24+ """
25+ self .start = params ['CropWithBoundingBox_start' .lower ()]
26+ self .output_size = params ['CropWithBoundingBox_output_size' .lower ()]
27+ self .inverse = params ['CropWithBoundingBox_inverse' .lower ()]
28+
29+ def __call__ (self , sample ):
30+ image = sample ['image' ]
31+ input_shape = image .shape
32+ input_dim = len (input_shape ) - 1
33+ bb_min , bb_max = get_ND_bounding_box (image )
34+ bb_min , bb_max = bb_min [1 :], bb_max [1 :]
35+ if (self .start is None ):
36+ if (self .output_size is None ):
37+ crop_min , crop_max = bb_min , bb_max
38+ else :
39+ assert (len (self .output_size ) == input_dim )
40+ crop_min = [int ((bb_min [i ] + bb_max [i ] + 1 )/ 2 ) - int (self .output_size [i ]/ 2 ) \
41+ for i in range (input_dim )]
42+ crop_min = [max (0 , crop_min [i ]) for i in range (input_dim )]
43+ crop_max = [crop_min [i ] + self .output_size [i ] for i in range (input_dim )]
44+ else :
45+ assert (len (self .start ) == input_dim )
46+ crop_min = self .start
47+ if (self .output_size is None ):
48+ assert (len (self .output_size ) == input_dim )
49+ crop_max = [crop_min [i ] + bb_max [i ] - bb_min [i ] \
50+ for i in range (input_dim )]
51+ else :
52+ crop_max = [crop_min [i ] + self .output_size [i ] for i in range (input_dim )]
53+ crop_min = [0 ] + crop_min
54+ crop_max = list (input_shape [0 :1 ]) + crop_max
55+ sample ['CropWithBoundingBox_Param' ] = json .dumps ((input_shape , crop_min , crop_max ))
56+
57+ image_t = crop_ND_volume_with_bounding_box (image , crop_min , crop_max )
58+ sample ['image' ] = image_t
59+
60+ if ('label' in sample and sample ['label' ].shape [1 :] == image .shape [1 :]):
61+ label = sample ['label' ]
62+ crop_max [0 ] = label .shape [0 ]
63+ label = crop_ND_volume_with_bounding_box (label , crop_min , crop_max )
64+ sample ['label' ] = label
65+ if ('weight' in sample and sample ['weight' ].shape [1 :] == image .shape [1 :]):
66+ weight = sample ['weight' ]
67+ crop_max [0 ] = weight .shape [0 ]
68+ weight = crop_ND_volume_with_bounding_box (weight , crop_min , crop_max )
69+ sample ['weight' ] = weight
70+ return sample
71+
72+ def inverse_transform_for_prediction (self , sample ):
73+ ''' rescale sample['predict'] (5D or 4D) to the original spatial shape.
74+ assume batch size is 1, otherwise scale may be different for
75+ different elemenets in the batch.
76+
77+ origin_shape is a 4D or 3D vector as saved in __call__().'''
78+ if (isinstance (sample ['CropWithBoundingBox_Param' ], list ) or \
79+ isinstance (sample ['CropWithBoundingBox_Param' ], tuple )):
80+ params = json .loads (sample ['CropWithBoundingBox_Param' ][0 ])
81+ else :
82+ params = json .loads (sample ['CropWithBoundingBox_Param' ])
83+ origin_shape = params [0 ]
84+ crop_min = params [1 ]
85+ crop_max = params [2 ]
86+ predict = sample ['predict' ]
87+ if (isinstance (predict , tuple ) or isinstance (predict , list )):
88+ output_predict = []
89+ for predict_i in predict :
90+ origin_shape = list (predict_i .shape [:2 ]) + origin_shape [1 :]
91+ output_predict_i = np .zeros (origin_shape , predict_i .dtype )
92+ crop_min = [0 , 0 ] + crop_min [1 :]
93+ crop_max = list (predict_i .shape [:2 ]) + crop_max [1 :]
94+ output_predict_i = set_ND_volume_roi_with_bounding_box_range (output_predict_i ,
95+ crop_min , crop_max , predict_i )
96+ output_predict .append (output_predict_i )
97+ else :
98+ origin_shape = list (predict .shape [:2 ]) + origin_shape [1 :]
99+ output_predict = np .zeros (origin_shape , predict .dtype )
100+ crop_min = [0 , 0 ] + crop_min [1 :]
101+ crop_max = list (predict .shape [:2 ]) + crop_max [1 :]
102+ output_predict = set_ND_volume_roi_with_bounding_box_range (output_predict ,
103+ crop_min , crop_max , predict )
104+
105+ sample ['predict' ] = output_predict
106+ return sample
107+
108+ class RandomCrop (object ):
109+ """Randomly crop the input image (shape [C, D, H, W] or [C, H, W])
110+ """
111+ def __init__ (self , params ):
112+ """
113+ output_size (tuple or list): Desired output size [D, H, W] or [H, W].
114+ the output channel is the same as the input channel.
115+ foreground_focus (bool): If true, allow crop around the foreground.
116+ foreground_ratio (float): Specifying the probability of foreground
117+ focus cropping when foreground_focus is true.
118+ mask_label (None, or tuple / list): Specifying the foreground labels for foreground
119+ focus cropping
120+ """
121+ self .output_size = params ['RandomCrop_output_size' .lower ()]
122+ self .fg_focus = params ['RandomCrop_foreground_focus' .lower ()]
123+ self .fg_ratio = params ['RandomCrop_foreground_ratio' .lower ()]
124+ self .mask_label = params ['RandomCrop_mask_label' .lower ()]
125+ self .inverse = params ['RandomCrop_inverse' .lower ()]
126+ assert isinstance (self .output_size , (list , tuple ))
127+ if (self .mask_label is not None ):
128+ assert isinstance (self .mask_label , (list , tuple ))
129+
130+ def __call__ (self , sample ):
131+ image = sample ['image' ]
132+ input_shape = image .shape
133+ input_dim = len (input_shape ) - 1
134+
135+ assert (input_dim == len (self .output_size ))
136+ crop_margin = [input_shape [i + 1 ] - self .output_size [i ]\
137+ for i in range (input_dim )]
138+ crop_min = [random .randint (0 , item ) for item in crop_margin ]
139+ if (self .fg_focus and random .random () < self .fg_ratio ):
140+ label = sample ['label' ]
141+ mask = np .zeros_like (label )
142+ for temp_lab in self .mask_label :
143+ mask = np .maximum (mask , label == temp_lab )
144+ if (mask .sum () == 0 ):
145+ bb_min = [0 ] * (input_dim + 1 )
146+ bb_max = mask .shape
147+ else :
148+ bb_min , bb_max = get_ND_bounding_box (mask )
149+ bb_min , bb_max = bb_min [1 :], bb_max [1 :]
150+ crop_min = [random .randint (bb_min [i ], bb_max [i ]) - int (self .output_size [i ]/ 2 ) \
151+ for i in range (input_dim )]
152+ crop_min = [max (0 , item ) for item in crop_min ]
153+ crop_min = [min (crop_min [i ], input_shape [i + 1 ] - self .output_size [i ]) \
154+ for i in range (input_dim )]
155+
156+ crop_max = [crop_min [i ] + self .output_size [i ] \
157+ for i in range (input_dim )]
158+ crop_min = [0 ] + crop_min
159+ crop_max = list (input_shape [0 :1 ]) + crop_max
160+ sample ['RandomCrop_Param' ] = json .dumps ((input_shape , crop_min , crop_max ))
161+
162+ image_t = crop_ND_volume_with_bounding_box (image , crop_min , crop_max )
163+ sample ['image' ] = image_t
164+
165+ if ('label' in sample and sample ['label' ].shape [1 :] == image .shape [1 :]):
166+ label = sample ['label' ]
167+ crop_max [0 ] = label .shape [0 ]
168+ label = crop_ND_volume_with_bounding_box (label , crop_min , crop_max )
169+ sample ['label' ] = label
170+ if ('weight' in sample and sample ['weight' ].shape [1 :] == image .shape [1 :]):
171+ weight = sample ['weight' ]
172+ crop_max [0 ] = weight .shape [0 ]
173+ weight = crop_ND_volume_with_bounding_box (weight , crop_min , crop_max )
174+ sample ['weight' ] = weight
175+ return sample
176+
177+ def inverse_transform_for_prediction (self , sample ):
178+ ''' rescale sample['predict'] (5D or 4D) to the original spatial shape.
179+ assume batch size is 1, otherwise scale may be different for
180+ different elemenets in the batch.
181+
182+ origin_shape is a 4D or 3D vector as saved in __call__().'''
183+ if (isinstance (sample ['RandomCrop_Param' ], list ) or \
184+ isinstance (sample ['RandomCrop_Param' ], tuple )):
185+ params = json .loads (sample ['RandomCrop_Param' ][0 ])
186+ else :
187+ params = json .loads (sample ['RandomCrop_Param' ])
188+ origin_shape = params [0 ]
189+ crop_min = params [1 ]
190+ crop_max = params [2 ]
191+ predict = sample ['predict' ]
192+ origin_shape = list (predict .shape [:2 ]) + origin_shape [1 :]
193+ output_predict = np .zeros (origin_shape , predict .dtype )
194+ crop_min = [0 , 0 ] + crop_min [1 :]
195+ crop_max = list (predict .shape [:2 ]) + crop_max [1 :]
196+ output_predict = set_ND_volume_roi_with_bounding_box_range (output_predict ,
197+ crop_min , crop_max , predict )
198+
199+ sample ['predict' ] = output_predict
200+ return sample
0 commit comments