Skip to content

Commit 84b531c

Browse files
authored
Fix up torch GPU failing test for mix up (#20666)
We need to make sure to use get any tensors places on cpu before using them in the tensorflow backend during preprocessing.
1 parent 9a3e173 commit 84b531c

File tree

1 file changed

+5
-0
lines changed
  • keras/src/layers/preprocessing/image_preprocessing

1 file changed

+5
-0
lines changed

keras/src/layers/preprocessing/image_preprocessing/mix_up.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from keras.src import ops
12
from keras.src.api_export import keras_export
23
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
34
BaseImagePreprocessingLayer,
@@ -114,6 +115,8 @@ def _mix_up_bounding_boxes(bounding_boxes, transformation):
114115
self.backend.set_backend("tensorflow")
115116

116117
permutation_order = transformation["permutation_order"]
118+
# Make sure we are on cpu for torch tensors.
119+
permutation_order = ops.convert_to_numpy(permutation_order)
117120

118121
boxes, labels = bounding_boxes["boxes"], bounding_boxes["labels"]
119122
boxes_for_mix_up = self.backend.numpy.take(
@@ -146,6 +149,8 @@ def transform_segmentation_masks(
146149
):
147150
def _mix_up_segmentation_masks(segmentation_masks, transformation):
148151
mix_weight = transformation["mix_weight"]
152+
# Make sure we are on cpu for torch tensors.
153+
mix_weight = ops.convert_to_numpy(mix_weight)
149154
permutation_order = transformation["permutation_order"]
150155
mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1])
151156
segmentation_masks_for_mix_up = self.backend.numpy.take(

0 commit comments

Comments
 (0)