Skip to content

Commit 7b9ab6a

Browse files
authored
Fix: UpSampling2D bilinear set_image_data_format(channels_first) bug (#21456)
* The current approach transforms the tensor to channels_last, before passing it in ops.image.resize, which has been defined as channels_first if keras.backend.set_image_data_format has been called on channelsfirst in user's code. This creates a bug, a passed in tensor [16, 3, 224, 224] will return as shape [16, 448, 224, 448] instead of [16, 3, 448, 448]. Setting the data_format as the expected channels_last fixes that issue. * ruff format corrections * Added setup and teardown to ensure tests are self contained, backend returns to original state after each test
1 parent 90c8da6 commit 7b9ab6a

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

keras/src/layers/reshaping/up_sampling2d.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,12 @@ def _resize_images(
163163
shape[1] * height_factor,
164164
shape[2] * width_factor,
165165
)
166-
x = ops.image.resize(x, new_shape, interpolation=interpolation)
166+
x = ops.image.resize(
167+
x,
168+
new_shape,
169+
data_format="channels_last",
170+
interpolation=interpolation,
171+
)
167172
if data_format == "channels_first":
168173
x = ops.transpose(x, [0, 3, 1, 2])
169174

keras/src/layers/reshaping/up_sampling2d_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,18 @@
66
from keras.src import backend
77
from keras.src import layers
88
from keras.src import testing
9+
from keras.backend import set_image_data_format
910

1011

1112
class UpSampling2dTest(testing.TestCase):
13+
@classmethod
14+
def setUpClass(cls):
15+
cls.original_image_data_format = backend.image_data_format()
16+
17+
@classmethod
18+
def tearDownClass(cls):
19+
backend.set_image_data_format(cls.original_image_data_format)
20+
1221
@parameterized.product(
1322
data_format=["channels_first", "channels_last"],
1423
length_row=[2],
@@ -62,15 +71,22 @@ def test_upsampling_2d(self, data_format, length_row, length_col):
6271

6372
@parameterized.product(
6473
data_format=["channels_first", "channels_last"],
74+
use_set_image_data_format=[True, False],
6575
length_row=[2],
6676
length_col=[2, 3],
6777
)
6878
@pytest.mark.requires_trainable_backend
69-
def test_upsampling_2d_bilinear(self, data_format, length_row, length_col):
79+
def test_upsampling_2d_bilinear(
80+
self, data_format, use_set_image_data_format, length_row, length_col
81+
):
7082
num_samples = 2
7183
stack_size = 2
7284
input_num_row = 11
7385
input_num_col = 12
86+
87+
if use_set_image_data_format:
88+
set_image_data_format(data_format)
89+
7490
if data_format == "channels_first":
7591
inputs = np.random.rand(
7692
num_samples, stack_size, input_num_row, input_num_col
@@ -93,6 +109,7 @@ def test_upsampling_2d_bilinear(self, data_format, length_row, length_col):
93109
layer = layers.UpSampling2D(
94110
size=(length_row, length_col),
95111
data_format=data_format,
112+
interpolation="bilinear",
96113
)
97114
layer.build(inputs.shape)
98115
np_output = layer(inputs=backend.Variable(inputs))

0 commit comments

Comments
 (0)