Skip to content
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ __pycache__
**/.vscode test/**
**/.vscode-smoke/**
**/.venv*/
venv
bin/**
build/**
obj/**
Expand Down
27 changes: 27 additions & 0 deletions integration_tests/test_save_img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

import numpy as np
import pytest

from keras.utils import img_to_array
from keras.utils import load_img
from keras.utils import save_img


@pytest.mark.parametrize(
"shape, name",
[
((50, 50, 3), "rgb.jpg"),
((50, 50, 4), "rgba.jpg"),
],
)
def test_save_jpg(tmp_path, shape, name):
img = np.random.randint(0, 256, size=shape, dtype=np.uint8)
path = tmp_path / name
save_img(path, img, file_format="jpg")
assert os.path.exists(path)

# Check that the image was saved correctly and converted to RGB if needed.
loaded_img = load_img(path)
loaded_array = img_to_array(loaded_img)
assert loaded_array.shape == (50, 50, 3)
12 changes: 10 additions & 2 deletions keras/src/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,18 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
"""
data_format = backend.standardize_data_format(data_format)

# Normalize jpg → jpeg
if file_format is not None and file_format.lower() == "jpg":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be breaking

file_format = "jpeg"

# Convert array to PIL Image
img = array_to_img(x, data_format=data_format, scale=scale)
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):

# Handle RGBA → RGB if saving to JPEG
if img.mode == "RGBA" and file_format in ("jpeg", "jpg"):
warnings.warn(
"The JPG format does not support RGBA images, converting to RGB."
"The JPEG format does not support RGBA images, converting to RGB."
)
img = img.convert("RGB")
img.save(path, format=file_format, **kwargs)
Expand Down