src/datasets/arrow_writer.py (3 additions, 1 deletion)
@@ -524,7 +524,9 @@ def _build_schema(self, inferred_schema: pa.Schema):
 
     def _build_writer(self, inferred_schema: pa.Schema):
         self._schema, self._features = self._build_schema(inferred_schema)
-        self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema)
+        self.pa_writer = pa.RecordBatchStreamWriter(
+            self.stream, self._schema, options=pa.ipc.IpcWriteOptions(allow_64bit=True)
+        )
 
     @property
     def schema(self):
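
For reviewers: `IpcWriteOptions(allow_64bit=True)` lifts the IPC writer's default refusal to serialize a buffer longer than 2**31 - 1 elements, which flushing the oversized column from GH 7821 needs even once the data lands in a `large_list`. A minimal, toy-scale sketch of the option in isolation (assumes only pyarrow; the in-memory sink and schema are illustrative, not from this PR):

import io

import pyarrow as pa

# a large_list column, as produced by the features.py fallback below
schema = pa.schema({"audio": pa.large_list(pa.uint16())})
batch = pa.record_batch({"audio": [[1, 2, 3], [4, 5]]}, schema=schema)

sink = io.BytesIO()
# allow_64bit=True permits buffers past 2**31 - 1 elements; a no-op at this size
with pa.RecordBatchStreamWriter(
    sink, schema, options=pa.ipc.IpcWriteOptions(allow_64bit=True)
) as writer:
    writer.write_batch(batch)

sink.seek(0)
print(pa.ipc.open_stream(sink).read_all())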
src/datasets/features/features.py (6 additions, 2 deletions)
@@ -1535,9 +1535,13 @@ def list_of_pa_arrays_to_pyarrow_listarray(l_arr: list[Optional[pa.Array]]) -> pa.ListArray:
         [0] + [len(arr) for arr in l_arr], dtype=object
     )  # convert to dtype object to allow None insertion
     offsets = np.insert(offsets, null_indices, None)
-    offsets = pa.array(offsets, type=pa.int32())
     values = pa.concat_arrays(l_arr)
-    return pa.ListArray.from_arrays(offsets, values)
+    try:
+        offsets = pa.array(offsets, type=pa.int32())
+        return pa.ListArray.from_arrays(offsets, values)
+    except pa.lib.ArrowInvalid:
+        offsets = pa.array(offsets, type=pa.int64())
+        return pa.LargeListArray.from_arrays(offsets, values)
 
 
 def list_of_np_array_to_pyarrow_listarray(l_arr: list[np.ndarray], type: pa.DataType = None) -> pa.ListArray:
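
For context on the fallback: `pa.ListArray` takes `int32` offsets while `pa.LargeListArray` takes `int64` ones, so once the flattened values run past 2**31 - 1 elements the `int32` conversion raises `ArrowInvalid` and the `except` branch switches list types. A small sketch of both paths (assumes only pyarrow and numpy):

import numpy as np
import pyarrow as pa

values = pa.array(np.arange(5, dtype=np.uint16))

# int32 offsets give a plain list<item: uint16>
small = pa.ListArray.from_arrays(pa.array([0, 2, 5], type=pa.int32()), values)
print(small.type)

# int64 offsets give a large_list<item: uint16>, usable past the int32 limit
large = pa.LargeListArray.from_arrays(pa.array([0, 2, 5], type=pa.int64()), values)
print(large.type)

# an offset beyond 2**31 - 1 is what trips the except branch above
try:
    pa.array([2**31], type=pa.int32())
except pa.lib.ArrowInvalid as e:
    print(e)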
tests/test_arrow_dataset.py (20 additions, 0 deletions)
@@ -4783,3 +4783,23 @@ def test_from_polars_save_to_disk_and_load_from_disk_round_trip_with_large_list(
 def test_polars_round_trip():
     ds = Dataset.from_dict({"x": [[1, 2], [3, 4, 5]], "y": ["a", "b"]})
     assert isinstance(Dataset.from_polars(ds.to_polars()), Dataset)
+
+
+def test_map_int32_overflow():
+    # GH: 7821
+    def process_batch(batch):
+        res = []
+        for _ in batch["id"]:
+            # 2**31 uint16 values (~4 GiB) overflow int32 list offsets
+            res.append(np.zeros(2**31, dtype=np.uint16))
+        return {"audio": res}
+
+    ds = Dataset.from_dict({"id": [0]})
+    mapped_ds = ds.map(
+        process_batch,
+        batched=True,
+        batch_size=1,
+        num_proc=None,  # map() rejects num_proc <= 0; None keeps the main process
+        remove_columns=ds.column_names,
+    )
+    assert isinstance(mapped_ds, Dataset)
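
The regression test needs roughly 4 GiB per example to trip the overflow; for a toy-scale look at the same code path, the patched helper can be called directly. It is internal API, so the import path below is assumed from the diff, not a public interface:

import numpy as np
import pyarrow as pa

from datasets.features.features import list_of_pa_arrays_to_pyarrow_listarray

# a None entry exercises the null-offset handling as well
arrs = [pa.array(np.zeros(3, dtype=np.uint16)), None]
out = list_of_pa_arrays_to_pyarrow_listarray(arrs)
print(out.type)  # list<item: uint16> at this scale; large_list past 2**31 - 1 values
print(out.to_pylist())  # the None entry comes back as a null list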