Skip to content
Open
15 changes: 15 additions & 0 deletions integration_tests/test_save_img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np
import os
from keras.utils import save_img

def test_save_jpg_rgb(tmp_path):
img = np.random.randint(0, 256, size=(50, 50, 3), dtype=np.uint8)
path = tmp_path / "rgb.jpg"
save_img(path, img, file_format="jpg")
assert os.path.exists(path)

def test_save_jpg_rgba(tmp_path):
img = np.random.randint(0, 256, size=(50, 50, 4), dtype=np.uint8)
path = tmp_path / "rgba.jpg"
save_img(path, img, file_format="jpg")
assert os.path.exists(path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These tests are a good start, but they can be improved. The two test functions are very similar and can be combined into a single parameterized test using pytest.mark.parametrize for better maintainability. Also, the assertions only check for file existence. It would be more robust to load the saved image and verify its content, especially to confirm that RGBA images are correctly converted to RGB.

import os

import numpy as np
import pytest
from keras.utils import img_to_array, load_img, save_img


@pytest.mark.parametrize(
    "shape, name",
    [
        ((50, 50, 3), "rgb.jpg"),
        ((50, 50, 4), "rgba.jpg"),
    ],
)
def test_save_jpg(tmp_path, shape, name):
    img = np.random.randint(0, 256, size=shape, dtype=np.uint8)
    path = tmp_path / name
    save_img(path, img, file_format="jpg")
    assert os.path.exists(path)

    # Check that the image was saved correctly and converted to RGB if needed.
    loaded_img = load_img(path)
    loaded_array = img_to_array(loaded_img)
    assert loaded_array.shape == (50, 50, 3)

4 changes: 3 additions & 1 deletion keras/src/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,10 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
"""
data_format = backend.standardize_data_format(data_format)
if file_format is not None and file_format.lower() == 'jpg':
file_format = 'jpeg'
img = array_to_img(x, data_format=data_format, scale=scale)
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
if img.mode == "RGBA" and file_format in ["jpeg", "jpg"]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This change correctly handles file_format='jpg', but it misses the case where file_format is None and the path ends with .jpg. This will cause a crash when saving an RGBA image. The logic can be made more robust by also inspecting the path extension to determine if the output is a JPEG. Also, the condition on line 181 becomes partially redundant after converting jpg to jpeg.

Suggested change
if file_format is not None and file_format.lower() == 'jpg':
file_format = 'jpeg'
img = array_to_img(x, data_format=data_format, scale=scale)
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
if img.mode == "RGBA" and file_format in ["jpeg", "jpg"]:
_format = file_format
if _format is None and isinstance(path, (str, pathlib.Path)):
_format = pathlib.Path(path).suffix[1:].lower()
if file_format is not None and file_format.lower() == 'jpg':
file_format = 'jpeg'
img = array_to_img(x, data_format=data_format, scale=scale)
if img.mode == "RGBA" and _format in ("jpeg", "jpg"):

warnings.warn(
"The JPG format does not support RGBA images, converting to RGB."
)
Expand Down
Loading