@@ -122,59 +122,60 @@ def get_random_transformation(self, data, training=True, seed=None):
122
122
return h_start , w_start
123
123
124
124
def transform_images(self, images, transformation, training=True):
    """Crop `images` to a `self.height` x `self.width` window.

    Args:
        images: Image tensor. Its rank is read from `images.shape`; the
            spatial axes are located from `self.data_format`, counting
            from the last axis, so batched and unbatched inputs share one
            code path.
        transformation: Pair `(h_start, w_start)` giving the top-left
            corner of the crop window along the height and width axes.
        training: When falsy, `images` is returned exactly as received
            (no cast, crop or resize).

    Returns:
        When `training` is truthy: the cropped images cast to
        `self.compute_dtype`, resized to `(self.height, self.width)` if
        the crop did not already yield that size. Otherwise the input
        object itself.
    """
    # NOTE(review): the inference path hands back the input unchanged, so
    # its spatial size may differ from (self.height, self.width) -- confirm
    # that callers relying on a fixed output shape are fine with this.
    if not training:
        return images

    images = self.backend.cast(images, self.compute_dtype)
    crop_box_hstart, crop_box_wstart = transformation

    # Any value other than "channels_last" is treated as channels-first.
    # Indexing from the end keeps one expression valid with or without a
    # leading batch axis.
    h_axis = -3 if self.data_format == "channels_last" else -2
    w_axis = h_axis + 1

    # Equivalent to the literal `images[..., h0:h1, w0:w1, ...]` forms:
    # every non-spatial axis is taken whole.
    window = [slice(None)] * len(images.shape)
    window[h_axis] = slice(crop_box_hstart, crop_box_hstart + self.height)
    window[w_axis] = slice(crop_box_wstart, crop_box_wstart + self.width)
    images = images[tuple(window)]

    shape = self.backend.shape(images)
    new_height = shape[self.height_axis]
    new_width = shape[self.width_axis]
    size_is_known = isinstance(new_height, int) and isinstance(
        new_width, int
    )
    if (
        not size_is_known
        or new_height != self.height
        or new_width != self.width
    ):
        # Resize images if size mismatch or
        # if size mismatch cannot be determined
        # (in the case of a TF dynamic shape).
        images = self.backend.image.resize(
            images,
            size=(self.height, self.width),
            data_format=self.data_format,
        )
        # Resize may have upcasted the outputs
        images = self.backend.cast(images, self.compute_dtype)
    return images
179
180
180
181
def transform_labels (self , labels , transformation , training = True ):
@@ -197,56 +198,59 @@ def transform_bounding_boxes(
197
198
"labels": (num_boxes, num_classes),
198
199
}
199
200
"""
200
- h_start , w_start = transformation
201
- if not self .backend .is_tensor (bounding_boxes ["boxes" ]):
202
- bounding_boxes = densify_bounding_boxes (
203
- bounding_boxes , backend = self .backend
204
- )
205
- boxes = bounding_boxes ["boxes" ]
206
- # Convert to a standard xyxy as operations are done xyxy by default.
207
- boxes = convert_format (
208
- boxes = boxes ,
209
- source = self .bounding_box_format ,
210
- target = "xyxy" ,
211
- height = self .height ,
212
- width = self .width ,
213
- )
214
- h_start = self .backend .cast (h_start , boxes .dtype )
215
- w_start = self .backend .cast (w_start , boxes .dtype )
216
- if len (self .backend .shape (boxes )) == 3 :
217
- boxes = self .backend .numpy .stack (
218
- [
219
- self .backend .numpy .maximum (boxes [:, :, 0 ] - h_start , 0 ),
220
- self .backend .numpy .maximum (boxes [:, :, 1 ] - w_start , 0 ),
221
- self .backend .numpy .maximum (boxes [:, :, 2 ] - h_start , 0 ),
222
- self .backend .numpy .maximum (boxes [:, :, 3 ] - w_start , 0 ),
223
- ],
224
- axis = - 1 ,
225
- )
226
- else :
227
- boxes = self .backend .numpy .stack (
228
- [
229
- self .backend .numpy .maximum (boxes [:, 0 ] - h_start , 0 ),
230
- self .backend .numpy .maximum (boxes [:, 1 ] - w_start , 0 ),
231
- self .backend .numpy .maximum (boxes [:, 2 ] - h_start , 0 ),
232
- self .backend .numpy .maximum (boxes [:, 3 ] - w_start , 0 ),
233
- ],
234
- axis = - 1 ,
201
+
202
+ if training :
203
+ h_start , w_start = transformation
204
+ if not self .backend .is_tensor (bounding_boxes ["boxes" ]):
205
+ bounding_boxes = densify_bounding_boxes (
206
+ bounding_boxes , backend = self .backend
207
+ )
208
+ boxes = bounding_boxes ["boxes" ]
209
+ # Convert to a standard xyxy as operations are done xyxy by default.
210
+ boxes = convert_format (
211
+ boxes = boxes ,
212
+ source = self .bounding_box_format ,
213
+ target = "xyxy" ,
214
+ height = self .height ,
215
+ width = self .width ,
235
216
)
217
+ h_start = self .backend .cast (h_start , boxes .dtype )
218
+ w_start = self .backend .cast (w_start , boxes .dtype )
219
+ if len (self .backend .shape (boxes )) == 3 :
220
+ boxes = self .backend .numpy .stack (
221
+ [
222
+ self .backend .numpy .maximum (boxes [:, :, 0 ] - h_start , 0 ),
223
+ self .backend .numpy .maximum (boxes [:, :, 1 ] - w_start , 0 ),
224
+ self .backend .numpy .maximum (boxes [:, :, 2 ] - h_start , 0 ),
225
+ self .backend .numpy .maximum (boxes [:, :, 3 ] - w_start , 0 ),
226
+ ],
227
+ axis = - 1 ,
228
+ )
229
+ else :
230
+ boxes = self .backend .numpy .stack (
231
+ [
232
+ self .backend .numpy .maximum (boxes [:, 0 ] - h_start , 0 ),
233
+ self .backend .numpy .maximum (boxes [:, 1 ] - w_start , 0 ),
234
+ self .backend .numpy .maximum (boxes [:, 2 ] - h_start , 0 ),
235
+ self .backend .numpy .maximum (boxes [:, 3 ] - w_start , 0 ),
236
+ ],
237
+ axis = - 1 ,
238
+ )
236
239
237
- # Convert to user defined bounding box format
238
- boxes = convert_format (
239
- boxes = boxes ,
240
- source = "xyxy" ,
241
- target = self .bounding_box_format ,
242
- height = self .height ,
243
- width = self .width ,
244
- )
240
+ # Convert to user defined bounding box format
241
+ boxes = convert_format (
242
+ boxes = boxes ,
243
+ source = "xyxy" ,
244
+ target = self .bounding_box_format ,
245
+ height = self .height ,
246
+ width = self .width ,
247
+ )
245
248
246
- return {
247
- "boxes" : boxes ,
248
- "labels" : bounding_boxes ["labels" ],
249
- }
249
+ return {
250
+ "boxes" : boxes ,
251
+ "labels" : bounding_boxes ["labels" ],
252
+ }
253
+ return bounding_boxes
250
254
251
255
def transform_segmentation_masks (
252
256
self , segmentation_masks , transformation , training = True
0 commit comments