-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Fix: Support 'jpg' format in keras.utils.save_img() #21683
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 6 commits
bf4e0ce
b2a5668
d93f3c5
9ab0b4f
b652534
4a6a78e
524252e
97d00f5
2a24eaa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ __pycache__ | |
**/.vscode test/** | ||
**/.vscode-smoke/** | ||
**/.venv*/ | ||
venv | ||
bin/** | ||
build/** | ||
obj/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import os | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from keras.utils import img_to_array | ||
from keras.utils import load_img | ||
from keras.utils import save_img | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"shape, name", | ||
[ | ||
((50, 50, 3), "rgb.jpg"), | ||
((50, 50, 4), "rgba.jpg"), | ||
], | ||
) | ||
def test_save_jpg(tmp_path, shape, name): | ||
img = np.random.randint(0, 256, size=shape, dtype=np.uint8) | ||
path = tmp_path / name | ||
save_img(path, img, file_format="jpg") | ||
assert os.path.exists(path) | ||
|
||
# Check that the image was saved correctly and converted to RGB if needed. | ||
loaded_img = load_img(path) | ||
loaded_array = img_to_array(loaded_img) | ||
assert loaded_array.shape == (50, 50, 3) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -175,13 +175,35 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs): | |
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`. | ||
""" | ||
data_format = backend.standardize_data_format(data_format) | ||
|
||
# Determine output format | ||
_format = file_format | ||
if _format is None and isinstance(path, (str, pathlib.Path)): | ||
suffix = pathlib.Path(path).suffix.lower() | ||
if suffix.startswith("."): | ||
suffix = suffix[1:] | ||
_format = suffix | ||
|
||
# Normalize jpg → jpeg for both file_format and _format | ||
if file_format is not None and file_format.lower() == "jpg": | ||
file_format = "jpeg" | ||
if _format is not None and _format.lower() == "jpg": | ||
_format = "jpeg" | ||
|
||
# Convert array to PIL Image | ||
img = array_to_img(x, data_format=data_format, scale=scale) | ||
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"): | ||
|
||
# Handle RGBA → RGB if saving to JPEG | ||
if img.mode == "RGBA" and (_format in ("jpeg", "jpg")): | ||
|
||
warnings.warn( | ||
"The JPG format does not support RGBA images, converting to RGB." | ||
"The JPEG format does not support RGBA images, converting to RGB." | ||
) | ||
img = img.convert("RGB") | ||
img.save(path, format=file_format, **kwargs) | ||
|
||
# Finalize save format (explicit file_format wins) | ||
save_format = file_format if file_format is not None else _format | ||
|
||
|
||
img.save(path, format=save_format, **kwargs) | ||
|
||
|
||
@keras_export(["keras.utils.load_img", "keras.preprocessing.image.load_img"]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be breaking