
Commit dd31bce

Bump pyarrow to 8.0.0 (#5620)

* bump pyarrow for pandas 2.0
* bump to 8.0.0
* remove all the pyarrow 7 related checks
* update ci
* minor
* albert's comment

1 parent a0a35c5

File tree

9 files changed: +16 additions, −68 deletions

.github/conda/meta.yaml (2 additions, 2 deletions)

@@ -15,7 +15,7 @@ requirements:
     - python
     - pip
     - numpy >=1.17
-    - pyarrow >=6.0.0
+    - pyarrow >=8.0.0
     - python-xxhash
     - dill
     - pandas
@@ -32,7 +32,7 @@ requirements:
     - python
     - pip
     - numpy >=1.17
-    - pyarrow >=6.0.0
+    - pyarrow >=8.0.0
     - python-xxhash
     - dill
     - pandas

.github/workflows/benchmarks.yaml (3 additions, 3 deletions)

@@ -19,8 +19,8 @@ jobs:
           pip install setuptools wheel
           pip install -e .[benchmarks]
 
-          # pyarrow==6.0.0
-          pip install pyarrow==6.0.0
+          # pyarrow==8.0.0
+          pip install pyarrow==8.0.0
 
           dvc repro --force
 
@@ -29,7 +29,7 @@ jobs:
 
           python ./benchmarks/format.py report.json report.md
 
-          echo "<details>\n<summary>Show benchmarks</summary>\n\nPyArrow==6.0.0\n" > final_report.md
+          echo "<details>\n<summary>Show benchmarks</summary>\n\nPyArrow==8.0.0\n" > final_report.md
           cat report.md >> final_report.md
 
           # pyarrow

.github/workflows/ci.yml (1 addition, 1 deletion)

@@ -63,7 +63,7 @@ jobs:
         run: pip install --upgrade pyarrow huggingface-hub dill
       - name: Install depencencies (minimum versions)
         if: ${{ matrix.deps_versions != 'deps-latest' }}
-        run: pip install pyarrow==6.0.1 huggingface-hub==0.2.0 transformers dill==0.3.1.1
+        run: pip install pyarrow==8.0.0 huggingface-hub==0.2.0 transformers dill==0.3.1.1
       - name: Test with pytest
         run: |
           python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/

setup.py (2 additions, 2 deletions)

@@ -110,8 +110,8 @@
     # We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
     "numpy>=1.17",
     # Backend and serialization.
-    # Minimum 6.0.0 to support wrap_array which is needed for ArrayND features
-    "pyarrow>=6.0.0",
+    # Minimum 8.0.0 to be able to use .to_reader()
+    "pyarrow>=8.0.0",
     # For smart caching dataset processing
     "dill>=0.3.0,<0.3.7",  # tmp pin until next 0.3.7 release: see https://github.com/huggingface/datasets/pull/5166
     # For performance gains with apache arrow
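
For context on the new floor, here is a minimal sketch of the `pa.Table.to_reader()` call that motivates it. The table is a toy example; the method itself was added to `pa.Table` in pyarrow 8.0.0:

```python
import pyarrow as pa

# Toy table for illustration only.
table = pa.table({"id": [1, 2, 3, 4], "text": ["a", "b", "c", "d"]})

# to_reader() streams the table as a pyarrow.RecordBatchReader
# without copying the underlying data.
reader = table.to_reader(max_chunksize=2)
for batch in reader:
    print(batch.num_rows)  # 2, then 2
```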

src/datasets/__init__.py (2 additions, 2 deletions)

@@ -30,9 +30,9 @@
     "To use `datasets`, Python>=3.7 is required, and the current version of Python doesn't match this condition."
 )
 
-if version.parse(pyarrow.__version__).major < 6:
+if version.parse(pyarrow.__version__).major < 8:
     raise ImportWarning(
-        "To use `datasets`, the module `pyarrow>=6.0.0` is required, and the current version of `pyarrow` doesn't match this condition.\n"
+        "To use `datasets`, the module `pyarrow>=8.0.0` is required, and the current version of `pyarrow` doesn't match this condition.\n"
         "If you are running this in a Google Colab, you should probably just restart the runtime to use the right version of `pyarrow`."
     )
 
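
As a side note, the guard compares only the major version component, so pre-release builds behave as expected. A small sketch with hypothetical version strings:

```python
from packaging import version

# Only the major component matters for the import-time guard.
print(version.parse("8.0.0").major)         # 8 -> passes the check above
print(version.parse("7.0.0.dev123").major)  # 7 -> would raise the ImportWarning
```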

src/datasets/arrow_dataset.py (1 addition, 1 deletion)

@@ -2292,7 +2292,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):
             drop_last_batch (:obj:`bool`, default `False`): Whether a last batch smaller than the batch_size should be
                 dropped
         """
-        if self._indices is None and config.PYARROW_VERSION.major >= 8:
+        if self._indices is None:
             # Fast iteration
             # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
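
With pyarrow 8 as the minimum, the fast path no longer needs a version gate; only the presence of an indices mapping matters. A rough sketch of what batch-wise streaming looks like at the pyarrow level (toy table, not the actual `Dataset.iter()` body):

```python
import pyarrow as pa

table = pa.table({"x": list(range(10))})

# Stream record batches straight off the table instead of slicing
# row ranges -- this is what makes the no-indices path fast.
for batch in table.to_reader(max_chunksize=4):
    print(batch.num_rows)  # 4, 4, 2 (the trailing 2 is the smaller last batch)
```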

src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py (2 additions, 15 deletions)

@@ -16,19 +16,6 @@
 logger = datasets.utils.logging.get_logger(__name__)
 
 
-if datasets.config.PYARROW_VERSION.major >= 7:
-
-    def pa_table_to_pylist(table):
-        return table.to_pylist()
-
-else:
-
-    def pa_table_to_pylist(table):
-        keys = table.column_names
-        values = table.to_pydict().values()
-        return [{k: v for k, v in zip(keys, row_values)} for row_values in zip(*values)]
-
-
 def count_path_segments(path):
     return path.replace("\\", "/").count("/")
 
@@ -310,7 +297,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, add_labels):
                 metadata_dict = {
                     os.path.normpath(file_name).replace("\\", "/"): sample_metadata
                     for file_name, sample_metadata in zip(
-                        pa_file_name_array.to_pylist(), pa_table_to_pylist(pa_metadata_table)
+                        pa_file_name_array.to_pylist(), pa_metadata_table.to_pylist()
                     )
                 }
             else:
@@ -376,7 +363,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, add_labels):
                 metadata_dict = {
                     os.path.normpath(file_name).replace("\\", "/"): sample_metadata
                     for file_name, sample_metadata in zip(
-                        pa_file_name_array.to_pylist(), pa_table_to_pylist(pa_metadata_table)
+                        pa_file_name_array.to_pylist(), pa_metadata_table.to_pylist()
                     )
                 }
             else:
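
The removed shim existed only because `pa.Table.to_pylist()` arrived in pyarrow 7. A quick sketch of the call that now replaces it (toy metadata table):

```python
import pyarrow as pa

metadata = pa.table({"file_name": ["a.png", "b.png"], "label": [0, 1]})

# to_pylist() returns one dict per row -- exactly what the removed
# pa_table_to_pylist backport emulated via to_pydict() and zip().
print(metadata.to_pylist())
# [{'file_name': 'a.png', 'label': 0}, {'file_name': 'b.png', 'label': 1}]
```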

src/datasets/packaged_modules/json/json.py (1 addition, 20 deletions)

@@ -15,25 +15,6 @@
 logger = datasets.utils.logging.get_logger(__name__)
 
 
-if datasets.config.PYARROW_VERSION.major >= 7:
-
-    def pa_table_from_pylist(mapping):
-        return pa.Table.from_pylist(mapping)
-
-else:
-
-    def pa_table_from_pylist(mapping):
-        # Copied from: https://github.com/apache/arrow/blob/master/python/pyarrow/table.pxi#L5193
-        arrays = []
-        names = []
-        if mapping:
-            names = list(mapping[0].keys())
-        for n in names:
-            v = [row[n] if n in row else None for row in mapping]
-            arrays.append(v)
-        return pa.Table.from_arrays(arrays, names)
-
-
 @dataclass
 class JsonConfig(datasets.BuilderConfig):
     """BuilderConfig for JSON."""
@@ -156,7 +137,7 @@ def _generate_tables(self, files):
                 # If possible, parse the file as a list of json objects and exit the loop
                 if isinstance(dataset, list):  # list is the only sequence type supported in JSON
                     try:
-                        pa_table = pa_table_from_pylist(dataset)
+                        pa_table = pa.Table.from_pylist(dataset)
                     except (pa.ArrowInvalid, AttributeError) as e:
                         logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                         raise ValueError(f"Not able to read records in the JSON file at {file}.") from None
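
Similarly, `pa.Table.from_pylist()` (pyarrow>=7) replaces the copied-in fallback. A minimal sketch, including the null-filling behavior the fallback reproduced via `from_arrays`:

```python
import pyarrow as pa

rows = [{"a": 1, "b": "x"}, {"a": 2}]  # second record lacks "b"

# Missing keys become nulls, matching the removed backport.
table = pa.Table.from_pylist(rows)
print(table.column("b").to_pylist())  # ['x', None]
```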

src/datasets/table.py (2 additions, 22 deletions)

@@ -148,20 +148,6 @@ def fast_slice(self, offset=0, length=None) -> pa.Table:
         return pa.Table.from_batches(batches, schema=self._schema)
 
 
-class _RecordBatchReader:
-    def __init__(self, table: "Table", max_chunksize: Optional[int] = None):
-        self.table = table
-        self.max_chunksize = max_chunksize
-
-    def __iter__(self):
-        for batch in self.table._batches:
-            if self.max_chunksize is None or len(batch) <= self.max_chunksize:
-                yield batch
-            else:
-                for offset in range(0, len(batch), self.max_chunksize):
-                    yield batch.slice(offset, self.max_chunksize)
-
-
 class Table(IndexedTableMixin):
     """
     Wraps a pyarrow Table by using composition.
@@ -359,10 +345,8 @@ def to_reader(self, max_chunksize: Optional[int] = None):
             on the chunk layout of individual columns.
 
         Returns:
-            `pyarrow.RecordBatchReader` if pyarrow>=8.0.0, otherwise a `pyarrow.RecordBatch` iterable
+            `pyarrow.RecordBatchReader`
         """
-        if config.PYARROW_VERSION.major < 8:
-            return _RecordBatchReader(self, max_chunksize=max_chunksize)
         return self.table.to_reader(max_chunksize=max_chunksize)
 
     def field(self, *args, **kwargs):
@@ -816,11 +800,7 @@ def from_pylist(cls, mapping, *args, **kwargs):
         Returns:
             `datasets.table.Table`
         """
-        try:
-            return cls(pa.Table.from_pylist(mapping, *args, **kwargs))
-        except AttributeError:  # pyarrow <7 does not have from_pylist, so we convert and use from_pydict
-            mapping = {k: [r.get(k) for r in mapping] for k in mapping[0]} if mapping else {}
-            return cls(pa.Table.from_pydict(mapping, *args, **kwargs))
+        return cls(pa.Table.from_pylist(mapping, *args, **kwargs))
 
     @classmethod
     def from_batches(cls, *args, **kwargs):
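
With pyarrow 8 as the floor, the hand-rolled `_RecordBatchReader` is redundant: `pa.Table.to_reader()` handles `max_chunksize` natively and returns a real `pyarrow.RecordBatchReader`. A short sketch (toy table):

```python
import pyarrow as pa

table = pa.table({"v": list(range(5))})

# Native chunking: batches of at most 2 rows, last one smaller.
reader = table.to_reader(max_chunksize=2)
print([b.num_rows for b in reader])  # [2, 2, 1]
print(type(reader))                  # <class 'pyarrow.lib.RecordBatchReader'>
```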
