Skip to content

Commit 322f474

Browse files
authored
JPEGSerializer: Fix serializer io.bytes image (#19369)
1 parent 10c3a71 commit 322f474

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed

src/lightning/data/streaming/data_processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
151151
s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
152152

153153
elif os.path.isfile(path):
154+
os.makedirs(os.path.dirname(local_path), exist_ok=True)
154155
shutil.copyfile(path, local_path)
155156
else:
156157
raise ValueError(f"The provided {input_dir.url} isn't supported.")

src/lightning/data/streaming/serializers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import io
1415
import os
1516
import pickle
1617
import tempfile
@@ -109,8 +110,16 @@ def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]:
109110
raise ValueError(
110111
"The JPEG Image's filename isn't defined. HINT: Open the image in your Dataset __getitem__ method."
111112
)
112-
with open(item.filename, "rb") as f:
113-
return f.read(), None
113+
if item.filename and os.path.exists(item.filename):
114+
# read the content of the file directly
115+
with open(item.filename, "rb") as f:
116+
return f.read(), None
117+
else:
118+
item_bytes = io.BytesIO()
119+
item.save(item_bytes, format="JPEG")
120+
item_bytes = item_bytes.getvalue()
121+
return item_bytes, None
122+
114123
raise TypeError(f"The provided itemect should be of type {JpegImageFile}. Found {item}.")
115124

116125
def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]:

tests/tests_data/streaming/test_serializer.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import io
1415
import os
1516
import sys
1617
from time import time
@@ -26,6 +27,7 @@
2627
_TORCH_DTYPES_MAPPING,
2728
_TORCH_VISION_AVAILABLE,
2829
IntSerializer,
30+
JPEGSerializer,
2931
NoHeaderNumpySerializer,
3032
NoHeaderTensorSerializer,
3133
NumpySerializer,
@@ -86,6 +88,27 @@ def test_pil_serializer(mode):
8688
assert np.array_equal(np_data, np_dec_data)
8789

8890

91+
@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
92+
def test_jpeg_serializer():
93+
serializer = JPEGSerializer()
94+
95+
from PIL import Image
96+
97+
array = np.random.randint(255, size=(28, 28, 3), dtype=np.uint8)
98+
img = Image.fromarray(array)
99+
img_bytes = io.BytesIO()
100+
img.save(img_bytes, format="JPEG")
101+
img_bytes = img_bytes.getvalue()
102+
103+
img = Image.open(io.BytesIO(img_bytes))
104+
105+
data, _ = serializer.serialize(img)
106+
assert isinstance(data, bytes)
107+
108+
deserialized_img = serializer.deserialize(data)
109+
assert deserialized_img.shape == torch.Size([3, 28, 28])
110+
111+
89112
@pytest.mark.flaky(reruns=3)
90113
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
91114
def test_tensor_serializer():

0 commit comments

Comments
 (0)