diff --git a/.gitignore b/.gitignore index afd700b49952..416f213f2c82 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__ **/.vscode test/** **/.vscode-smoke/** **/.venv*/ +venv bin/** build/** obj/** diff --git a/integration_tests/test_save_img.py b/integration_tests/test_save_img.py new file mode 100644 index 000000000000..6ec7951564cb --- /dev/null +++ b/integration_tests/test_save_img.py @@ -0,0 +1,27 @@ +import os + +import numpy as np +import pytest + +from keras.utils import img_to_array +from keras.utils import load_img +from keras.utils import save_img + + +@pytest.mark.parametrize( + "shape, name", + [ + ((50, 50, 3), "rgb.jpg"), + ((50, 50, 4), "rgba.jpg"), + ], +) +def test_save_jpg(tmp_path, shape, name): + img = np.random.randint(0, 256, size=shape, dtype=np.uint8) + path = tmp_path / name + save_img(path, img, file_format="jpg") + assert os.path.exists(path) + + # Check that the image was saved correctly and converted to RGB if needed. + loaded_img = load_img(path) + loaded_array = img_to_array(loaded_img) + assert loaded_array.shape == (50, 50, 3) diff --git a/keras/src/utils/image_utils.py b/keras/src/utils/image_utils.py index ca8289c9f9b7..4ff3e3454c6c 100644 --- a/keras/src/utils/image_utils.py +++ b/keras/src/utils/image_utils.py @@ -175,13 +175,28 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs): **kwargs: Additional keyword arguments passed to `PIL.Image.save()`. """ data_format = backend.standardize_data_format(data_format) + + # Determine the actual save format (normalize jpg → jpeg for internal use) + save_format = file_format + if file_format is not None and file_format.lower() == "jpg": + save_format = "jpeg" + elif file_format is None and isinstance(path, (str, pathlib.Path)): + # Infer format from file extension + ext = pathlib.Path(path).suffix[1:].lower() + if ext == "jpg": + save_format = "jpeg" + + # Convert array to PIL Image img = array_to_img(x, data_format=data_format, scale=scale) - if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"): + + # Handle RGBA → RGB if saving to JPEG + if img.mode == "RGBA" and save_format == "jpeg": warnings.warn( - "The JPG format does not support RGBA images, converting to RGB." + "The JPEG format does not support RGBA images, converting to RGB." ) img = img.convert("RGB") - img.save(path, format=file_format, **kwargs) + + img.save(path, format=save_format, **kwargs) @keras_export(["keras.utils.load_img", "keras.preprocessing.image.load_img"])