
Commit 3dd958b

Add training status condition during image processing (#20677)
* Add training status condition during image processing
* Revert "Add training status condition during image processing"
  This reverts commit 8fc5ae2.
* Reapply "Add training status condition during image processing"
  This reverts commit 25a4bd1.
* Revert center_crop
1 parent 0d3ba37 commit 3dd958b

File tree

8 files changed: +361 -333 lines changed

keras/src/layers/preprocessing/image_preprocessing/equalization.py

Lines changed: 33 additions & 22 deletions
```diff
@@ -170,40 +170,51 @@ def _apply_equalization(self, channel, hist):
         )
         return self.backend.numpy.take(lookup_table, indices)
 
-    def transform_images(self, images, transformations=None, **kwargs):
-        images = self.backend.cast(images, self.compute_dtype)
-
-        if self.data_format == "channels_first":
-            channels = []
-            for i in range(self.backend.core.shape(images)[-3]):
-                channel = images[..., i, :, :]
-                equalized = self._equalize_channel(channel, self.value_range)
-                channels.append(equalized)
-            equalized_images = self.backend.numpy.stack(channels, axis=-3)
-        else:
-            channels = []
-            for i in range(self.backend.core.shape(images)[-1]):
-                channel = images[..., i]
-                equalized = self._equalize_channel(channel, self.value_range)
-                channels.append(equalized)
-            equalized_images = self.backend.numpy.stack(channels, axis=-1)
-
-        return self.backend.cast(equalized_images, self.compute_dtype)
+    def transform_images(self, images, transformation, training=True):
+        if training:
+            images = self.backend.cast(images, self.compute_dtype)
+
+            if self.data_format == "channels_first":
+                channels = []
+                for i in range(self.backend.core.shape(images)[-3]):
+                    channel = images[..., i, :, :]
+                    equalized = self._equalize_channel(
+                        channel, self.value_range
+                    )
+                    channels.append(equalized)
+                equalized_images = self.backend.numpy.stack(channels, axis=-3)
+            else:
+                channels = []
+                for i in range(self.backend.core.shape(images)[-1]):
+                    channel = images[..., i]
+                    equalized = self._equalize_channel(
+                        channel, self.value_range
+                    )
+                    channels.append(equalized)
+                equalized_images = self.backend.numpy.stack(channels, axis=-1)
+
+            return self.backend.cast(equalized_images, self.compute_dtype)
+        return images
 
     def compute_output_shape(self, input_shape):
         return input_shape
 
     def compute_output_spec(self, inputs, **kwargs):
         return inputs
 
-    def transform_bounding_boxes(self, bounding_boxes, **kwargs):
+    def transform_bounding_boxes(
+        self,
+        bounding_boxes,
+        transformation,
+        training=True,
+    ):
         return bounding_boxes
 
-    def transform_labels(self, labels, transformations=None, **kwargs):
+    def transform_labels(self, labels, transformation, training=True):
         return labels
 
     def transform_segmentation_masks(
-        self, segmentation_masks, transformations, **kwargs
+        self, segmentation_masks, transformation, training=True
     ):
         return segmentation_masks
 
```
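The shape of this change repeats across all three files: the whole transform body moves under `if training:`, and the method otherwise returns its input unchanged. A minimal sketch of that pattern, using a hypothetical standalone class (not the actual Keras base layer) and a toy brightness transform in place of equalization:

```python
# Minimal sketch of the training-gated transform pattern introduced by this
# commit. GatedTransform and its toy brightness op are hypothetical; the real
# layer dispatches through self.backend and performs histogram equalization.
import numpy as np


class GatedTransform:
    def transform_images(self, images, transformation=None, training=True):
        if training:
            # Augment only during training.
            return np.clip(images * 1.5, 0.0, 255.0)
        # At inference the layer is the identity.
        return images


layer = GatedTransform()
x = np.full((2, 4, 4, 3), 100.0)
assert layer.transform_images(x, training=True).max() == 150.0
assert layer.transform_images(x, training=False) is x  # passed through untouched
```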

keras/src/layers/preprocessing/image_preprocessing/random_crop.py

Lines changed: 101 additions & 97 deletions
```diff
@@ -122,59 +122,60 @@ def get_random_transformation(self, data, training=True, seed=None):
         return h_start, w_start
 
     def transform_images(self, images, transformation, training=True):
-        images = self.backend.cast(images, self.compute_dtype)
-        crop_box_hstart, crop_box_wstart = transformation
-        crop_height = self.height
-        crop_width = self.width
+        if training:
+            images = self.backend.cast(images, self.compute_dtype)
+            crop_box_hstart, crop_box_wstart = transformation
+            crop_height = self.height
+            crop_width = self.width
 
-        if self.data_format == "channels_last":
-            if len(images.shape) == 4:
-                images = images[
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                    :,
-                ]
-            else:
-                images = images[
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                    :,
-                ]
-        else:
-            if len(images.shape) == 4:
-                images = images[
-                    :,
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                ]
+            if self.data_format == "channels_last":
+                if len(images.shape) == 4:
+                    images = images[
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                        :,
+                    ]
+                else:
+                    images = images[
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                        :,
+                    ]
             else:
-                images = images[
-                    :,
-                    crop_box_hstart : crop_box_hstart + crop_height,
-                    crop_box_wstart : crop_box_wstart + crop_width,
-                ]
+                if len(images.shape) == 4:
+                    images = images[
+                        :,
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                    ]
+                else:
+                    images = images[
+                        :,
+                        crop_box_hstart : crop_box_hstart + crop_height,
+                        crop_box_wstart : crop_box_wstart + crop_width,
+                    ]
 
-        shape = self.backend.shape(images)
-        new_height = shape[self.height_axis]
-        new_width = shape[self.width_axis]
-        if (
-            not isinstance(new_height, int)
-            or not isinstance(new_width, int)
-            or new_height != self.height
-            or new_width != self.width
-        ):
-            # Resize images if size mismatch or
-            # if size mismatch cannot be determined
-            # (in the case of a TF dynamic shape).
-            images = self.backend.image.resize(
-                images,
-                size=(self.height, self.width),
-                data_format=self.data_format,
-            )
-            # Resize may have upcasted the outputs
-            images = self.backend.cast(images, self.compute_dtype)
+            shape = self.backend.shape(images)
+            new_height = shape[self.height_axis]
+            new_width = shape[self.width_axis]
+            if (
+                not isinstance(new_height, int)
+                or not isinstance(new_width, int)
+                or new_height != self.height
+                or new_width != self.width
+            ):
+                # Resize images if size mismatch or
+                # if size mismatch cannot be determined
+                # (in the case of a TF dynamic shape).
+                images = self.backend.image.resize(
+                    images,
+                    size=(self.height, self.width),
+                    data_format=self.data_format,
+                )
+                # Resize may have upcasted the outputs
+                images = self.backend.cast(images, self.compute_dtype)
         return images
 
     def transform_labels(self, labels, transformation, training=True):
```
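The gated block itself is plain slicing along the spatial axes, with a resize fallback when the cropped size cannot be confirmed statically (e.g. a TF dynamic shape). A rough NumPy illustration of the `channels_last`, batched case only; the real method goes through `self.backend` and also covers `channels_first` and unbatched inputs:

```python
# Rough NumPy illustration of the channels_last crop slice used above.
# crop_channels_last is a hypothetical helper, not part of Keras.
import numpy as np


def crop_channels_last(images, h_start, w_start, crop_h, crop_w):
    # images: (batch, height, width, channels)
    return images[:, h_start : h_start + crop_h, w_start : w_start + crop_w, :]


images = np.random.rand(2, 32, 32, 3)
patch = crop_channels_last(images, h_start=5, w_start=7, crop_h=16, crop_w=16)
assert patch.shape == (2, 16, 16, 3)
```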
```diff
@@ -197,56 +198,59 @@ def transform_bounding_boxes(
             "labels": (num_boxes, num_classes),
         }
         """
-        h_start, w_start = transformation
-        if not self.backend.is_tensor(bounding_boxes["boxes"]):
-            bounding_boxes = densify_bounding_boxes(
-                bounding_boxes, backend=self.backend
-            )
-        boxes = bounding_boxes["boxes"]
-        # Convert to a standard xyxy as operations are done xyxy by default.
-        boxes = convert_format(
-            boxes=boxes,
-            source=self.bounding_box_format,
-            target="xyxy",
-            height=self.height,
-            width=self.width,
-        )
-        h_start = self.backend.cast(h_start, boxes.dtype)
-        w_start = self.backend.cast(w_start, boxes.dtype)
-        if len(self.backend.shape(boxes)) == 3:
-            boxes = self.backend.numpy.stack(
-                [
-                    self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
-                    self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
-                ],
-                axis=-1,
-            )
-        else:
-            boxes = self.backend.numpy.stack(
-                [
-                    self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
-                    self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
-                    self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
-                ],
-                axis=-1,
+
+        if training:
+            h_start, w_start = transformation
+            if not self.backend.is_tensor(bounding_boxes["boxes"]):
+                bounding_boxes = densify_bounding_boxes(
+                    bounding_boxes, backend=self.backend
+                )
+            boxes = bounding_boxes["boxes"]
+            # Convert to a standard xyxy as operations are done xyxy by default.
+            boxes = convert_format(
+                boxes=boxes,
+                source=self.bounding_box_format,
+                target="xyxy",
+                height=self.height,
+                width=self.width,
             )
+            h_start = self.backend.cast(h_start, boxes.dtype)
+            w_start = self.backend.cast(w_start, boxes.dtype)
+            if len(self.backend.shape(boxes)) == 3:
+                boxes = self.backend.numpy.stack(
+                    [
+                        self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
+                    ],
+                    axis=-1,
+                )
+            else:
+                boxes = self.backend.numpy.stack(
+                    [
+                        self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
+                        self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
+                    ],
+                    axis=-1,
+                )
 
-        # Convert to user defined bounding box format
-        boxes = convert_format(
-            boxes=boxes,
-            source="xyxy",
-            target=self.bounding_box_format,
-            height=self.height,
-            width=self.width,
-        )
+            # Convert to user defined bounding box format
+            boxes = convert_format(
+                boxes=boxes,
+                source="xyxy",
+                target=self.bounding_box_format,
+                height=self.height,
+                width=self.width,
+            )
 
-        return {
-            "boxes": boxes,
-            "labels": bounding_boxes["labels"],
-        }
+            return {
+                "boxes": boxes,
+                "labels": bounding_boxes["labels"],
+            }
+        return bounding_boxes
 
     def transform_segmentation_masks(
         self, segmentation_masks, transformation, training=True
```
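When it does run, the box update is a translation by the crop offset with a clamp at zero: each coordinate becomes `maximum(coord - offset, 0)` in xyxy space, with offsets paired per coordinate exactly as in the stack above. A small NumPy check of that arithmetic (hypothetical helper; boxes assumed to be a `(num_boxes, 4)` array):

```python
# Illustrative check of the crop-offset shift applied to xyxy boxes above.
# shift_boxes is a hypothetical helper; offsets follow the same
# (h_start, w_start, h_start, w_start) pairing as the diff.
import numpy as np


def shift_boxes(boxes, h_start, w_start):
    offsets = np.array([h_start, w_start, h_start, w_start], dtype=boxes.dtype)
    return np.maximum(boxes - offsets, 0)


boxes = np.array([[10.0, 20.0, 50.0, 60.0]])
shifted = shift_boxes(boxes, h_start=15, w_start=5)
# The first coordinate clamps to 0 because it falls outside the crop window.
assert (shifted == np.array([[0.0, 15.0, 35.0, 55.0]])).all()
```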

keras/src/layers/preprocessing/image_preprocessing/random_flip.py

Lines changed: 37 additions & 36 deletions
```diff
@@ -101,9 +101,6 @@ def transform_bounding_boxes(
         transformation,
         training=True,
     ):
-        if backend_utils.in_tf_graph():
-            self.backend.set_backend("tensorflow")
-
         def _flip_boxes_horizontal(boxes):
             x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1)
             outputs = self.backend.numpy.concatenate(
```
```diff
@@ -134,46 +131,50 @@ def _transform_xyxy(boxes, box_flips):
             )
             return bboxes
 
-        flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)
+        if training:
+            if backend_utils.in_tf_graph():
+                self.backend.set_backend("tensorflow")
 
-        if self.data_format == "channels_first":
-            height_axis = -2
-            width_axis = -1
-        else:
-            height_axis = -3
-            width_axis = -2
+            flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)
 
-        input_height, input_width = (
-            transformation["input_shape"][height_axis],
-            transformation["input_shape"][width_axis],
-        )
+            if self.data_format == "channels_first":
+                height_axis = -2
+                width_axis = -1
+            else:
+                height_axis = -3
+                width_axis = -2
 
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source=self.bounding_box_format,
-            target="rel_xyxy",
-            height=input_height,
-            width=input_width,
-        )
+            input_height, input_width = (
+                transformation["input_shape"][height_axis],
+                transformation["input_shape"][width_axis],
+            )
 
-        bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source=self.bounding_box_format,
+                target="rel_xyxy",
+                height=input_height,
+                width=input_width,
+            )
 
-        bounding_boxes = clip_to_image_size(
-            bounding_boxes=bounding_boxes,
-            height=input_height,
-            width=input_width,
-            bounding_box_format="xyxy",
-        )
+            bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)
 
-        bounding_boxes = convert_format(
-            bounding_boxes,
-            source="rel_xyxy",
-            target=self.bounding_box_format,
-            height=input_height,
-            width=input_width,
-        )
+            bounding_boxes = clip_to_image_size(
+                bounding_boxes=bounding_boxes,
+                height=input_height,
+                width=input_width,
+                bounding_box_format="xyxy",
+            )
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="rel_xyxy",
+                target=self.bounding_box_format,
+                height=input_height,
+                width=input_width,
+            )
 
-        self.backend.reset()
+            self.backend.reset()
 
         return bounding_boxes
 
```
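Net effect across all three layers: transforms mutate data only when `training=True` and act as the identity otherwise, and the TF-graph backend switch in `random_flip` now only engages when the transform actually runs. A hedged usage sketch against the public Keras 3 API (training-mode outputs are random, so only the inference path is asserted):

```python
# Usage sketch: augmentation layers should be the identity at inference.
# Assumes Keras 3 with the standard RandomFlip layer available.
import numpy as np
import keras

x = np.random.rand(2, 8, 8, 3).astype("float32")
layer = keras.layers.RandomFlip("horizontal_and_vertical")

y_infer = layer(x, training=False)  # no flip applied
np.testing.assert_allclose(np.asarray(y_infer), x)

y_train = layer(x, training=True)  # may flip; shape is preserved
assert y_train.shape == x.shape
```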
