Skip to content
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ __pycache__
**/.vscode test/**
**/.vscode-smoke/**
**/.venv*/
venv
bin/**
build/**
obj/**
Expand Down
27 changes: 27 additions & 0 deletions integration_tests/test_save_img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

import numpy as np
import pytest

from keras.utils import img_to_array
from keras.utils import load_img
from keras.utils import save_img


@pytest.mark.parametrize(
"shape, name",
[
((50, 50, 3), "rgb.jpg"),
((50, 50, 4), "rgba.jpg"),
],
)
def test_save_jpg(tmp_path, shape, name):
img = np.random.randint(0, 256, size=shape, dtype=np.uint8)
path = tmp_path / name
save_img(path, img, file_format="jpg")
assert os.path.exists(path)

# Check that the image was saved correctly and converted to RGB if needed.
loaded_img = load_img(path)
loaded_array = img_to_array(loaded_img)
assert loaded_array.shape == (50, 50, 3)
28 changes: 25 additions & 3 deletions keras/src/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,35 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
"""
data_format = backend.standardize_data_format(data_format)

# Determine output format
_format = file_format
if _format is None and isinstance(path, (str, pathlib.Path)):
suffix = pathlib.Path(path).suffix.lower()
if suffix.startswith("."):
suffix = suffix[1:]
_format = suffix

# Normalize jpg → jpeg for both file_format and _format
if file_format is not None and file_format.lower() == "jpg":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be breaking

file_format = "jpeg"
if _format is not None and _format.lower() == "jpg":
_format = "jpeg"

# Convert array to PIL Image
img = array_to_img(x, data_format=data_format, scale=scale)
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):

# Handle RGBA → RGB if saving to JPEG
if img.mode == "RGBA" and (_format in ("jpeg", "jpg")):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're checking _format, but you're no longer verifying file_format.

warnings.warn(
"The JPG format does not support RGBA images, converting to RGB."
"The JPEG format does not support RGBA images, converting to RGB."
)
img = img.convert("RGB")
img.save(path, format=file_format, **kwargs)

# Finalize save format (explicit file_format wins)
save_format = file_format if file_format is not None else _format
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be easier to do this line 186 instead of having to check both line 197 and lines 188-190.

in fact, you don't need _format at all, use file_format lines 181-185.


img.save(path, format=save_format, **kwargs)


@keras_export(["keras.utils.load_img", "keras.preprocessing.image.load_img"])
Expand Down