Skip to content

Commit aa7f2a9

Browse files
authored
Fix polars cast column image (#7800)
* fix polars cast_column issue
* remove debug statements
* cast large_strings to string for image handling
1 parent 63c933a commit aa7f2a9

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

src/datasets/features/image.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
215215
The Arrow types that can be converted to the Image pyarrow storage type are:
216216
217217
- `pa.string()` - it must contain the "path" data
218+
- `pa.large_string()` - it must contain the "path" data (will be cast to string if possible)
218219
- `pa.binary()` - it must contain the image bytes
219220
- `pa.struct({"bytes": pa.binary()})`
220221
- `pa.struct({"path": pa.string()})`
@@ -229,6 +230,15 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
229230
`pa.StructArray`: Array in the Image arrow storage type, that is
230231
`pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
231232
"""
233+
if pa.types.is_large_string(storage.type):
234+
try:
235+
storage = storage.cast(pa.string())
236+
except pa.ArrowInvalid as e:
237+
raise ValueError(
238+
f"Failed to cast large_string to string for Image feature. "
239+
f"This can happen if string values exceed 2GB. "
240+
f"Original error: {e}"
241+
) from e
232242
if pa.types.is_string(storage.type):
233243
bytes_array = pa.array([None] * len(storage), type=pa.binary())
234244
storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())

tests/features/test_image.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -320,6 +320,18 @@ def test_dataset_cast_to_image_features(shared_datadir, build_data):
320320
assert isinstance(item["image"], PIL.Image.Image)
321321

322322

323+
@require_pil
def test_dataset_cast_to_image_features_polars(shared_datadir):
    """Check that a Polars-backed column of image paths can be cast to `Image`.

    Builds a one-row Polars DataFrame holding a path to a JPEG fixture,
    converts it with `Dataset.from_polars`, casts the column to the `Image`
    feature and verifies that the decoded item is a `PIL.Image.Image`.

    NOTE(review): presumably the Polars string column arrives as
    `pa.large_string()`, which is the storage type this regression test is
    meant to cover - confirm against `Dataset.from_polars`.

    The test is skipped when Pillow (via `require_pil`) or Polars (via
    `pytest.importorskip`) is not installed.
    """
    import PIL.Image

    # Skip rather than fail on environments without the optional dependency.
    pl = pytest.importorskip("polars")
    image_path = str(shared_datadir / "test_image_rgb.jpg")
    df = pl.DataFrame({"image_path": [image_path]})
    dataset = Dataset.from_polars(df)
    item = dataset.cast_column("image_path", Image())[0]
    assert item.keys() == {"image_path"}
    assert isinstance(item["image_path"], PIL.Image.Image)
333+
334+
323335
@require_pil
324336
def test_dataset_concatenate_image_features(shared_datadir):
325337
# we use a different data structure between 1 and 2 to make sure they are compatible with each other

tests/test_download_manager.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ def test_download_manager_delete_extracted_files(xz_file):
131131
assert extracted_path == dl_manager.extracted_paths[xz_file]
132132
extracted_path = Path(extracted_path)
133133
parts = extracted_path.parts
134-
# import pdb; pdb.set_trace()
134+
135135
assert parts[-1] == hash_url_to_filename(str(xz_file), etag=None)
136136
assert parts[-2] == extracted_subdir
137137
assert extracted_path.exists()

0 commit comments

Comments
 (0)