Skip to content

Commit b564af7

Browse files
authored
Don't duplicate data when encoding audio or image (#4187)
* don't duplicate data in audio * don't duplicate data in image * one more comment
1 parent 966d3bc commit b564af7

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

src/datasets/features/audio.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from dataclasses import dataclass, field
23
from io import BytesIO
34
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union
@@ -70,11 +71,16 @@ def encode_example(self, value: Union[str, dict]) -> dict:
7071
raise ImportError("To support encoding audio data, please install 'soundfile'.") from err
7172
if isinstance(value, str):
7273
return {"bytes": None, "path": value}
73-
elif isinstance(value, dict) and "array" in value:
74+
elif "array" in value:
75+
# convert the audio array to wav bytes
7476
buffer = BytesIO()
7577
sf.write(buffer, value["array"], value["sampling_rate"], format="wav")
76-
return {"bytes": buffer.getvalue(), "path": value.get("path")}
78+
return {"bytes": buffer.getvalue(), "path": None}
79+
elif value.get("path") is not None and os.path.isfile(value["path"]):
80+
# we set "bytes": None to not duplicate the data if they're already available locally
81+
return {"bytes": None, "path": value.get("path")}
7782
elif value.get("bytes") is not None or value.get("path") is not None:
83+
# store the audio bytes, and path is used to infer the audio format using the file extension
7884
return {"bytes": value.get("bytes"), "path": value.get("path")}
7985
else:
8086
raise ValueError(

src/datasets/features/image.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from dataclasses import dataclass, field
23
from io import BytesIO
34
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union
@@ -69,11 +70,17 @@ def encode_example(self, value: Union[str, dict, np.ndarray, "PIL.Image.Image"])
6970
if isinstance(value, str):
7071
return {"path": value, "bytes": None}
7172
elif isinstance(value, np.ndarray):
73+
# convert the image array to png bytes
7274
image = PIL.Image.fromarray(value.astype(np.uint8))
7375
return {"path": None, "bytes": image_to_bytes(image)}
7476
elif isinstance(value, PIL.Image.Image):
77+
# convert the PIL image to bytes (default format is png)
7578
return encode_pil_image(value)
79+
elif value.get("path") is not None and os.path.isfile(value["path"]):
80+
# we set "bytes": None to not duplicate the data if they're already available locally
81+
return {"bytes": None, "path": value.get("path")}
7682
elif value.get("bytes") is not None or value.get("path") is not None:
83+
# store the image bytes, and path is used to infer the image format using the file extension
7784
return {"bytes": value.get("bytes"), "path": value.get("path")}
7885
else:
7986
raise ValueError(

0 commit comments

Comments
 (0)