Skip to content

Commit 2b65f89

Browse files
committed
Simplify save_img: remove _format, normalize jpg→jpeg, add RGBA→RGB handling and tests
1 parent 45b1039 commit 2b65f89

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

integration_tests/test_save_img.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import os
2+
3+
import numpy as np
4+
import pytest
5+
6+
from keras.utils import img_to_array
7+
from keras.utils import load_img
8+
from keras.utils import save_img
9+
10+
11+
@pytest.mark.parametrize(
12+
"shape, name",
13+
[
14+
((50, 50, 3), "rgb.jpg"),
15+
((50, 50, 4), "rgba.jpg"),
16+
],
17+
)
18+
def test_save_jpg(tmp_path, shape, name):
19+
img = np.random.randint(0, 256, size=shape, dtype=np.uint8)
20+
path = tmp_path / name
21+
save_img(path, img, file_format="jpg")
22+
assert os.path.exists(path)
23+
24+
# Check that the image was saved correctly and converted to RGB if needed.
25+
loaded_img = load_img(path)
26+
loaded_array = img_to_array(loaded_img)
27+
assert loaded_array.shape == (50, 50, 3)

keras/src/utils/image_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,13 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
175175
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
176176
"""
177177
data_format = backend.standardize_data_format(data_format)
178+
# Normalize jpg → jpeg
179+
if file_format is not None and file_format.lower() == "jpg":
180+
file_format = "jpeg"
178181
img = array_to_img(x, data_format=data_format, scale=scale)
179-
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
182+
if img.mode == "RGBA" and file_format == "jpeg":
180183
warnings.warn(
181-
"The JPG format does not support RGBA images, converting to RGB."
184+
"The JPEG format does not support RGBA images, converting to RGB."
182185
)
183186
img = img.convert("RGB")
184187
img.save(path, format=file_format, **kwargs)

0 commit comments

Comments
 (0)