Skip to content

Commit 6aeb126

Browse files
authored
Fan out writing to multiple Parquet files (#444)
* bin pack write * add write target file size config * test * add test for multiple data files * parquet writer write once * parallelize write tasks * refactor * chunk correctly using to_batches * change variable names * get rid of assert * configure PackingIterator * add more tests * rewrite set_properties * set int property
1 parent 4c1cfdc commit 6aeb126

File tree

5 files changed

+192
-45
lines changed

5 files changed

+192
-45
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,54 +1761,67 @@ def data_file_statistics_from_parquet_metadata(
17611761

17621762

17631763
def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
1764-
task = next(tasks)
1765-
1766-
try:
1767-
_ = next(tasks)
1768-
# If there are more tasks, raise an exception
1769-
raise NotImplementedError("Only unpartitioned writes are supported: https://github.com/apache/iceberg-python/issues/208")
1770-
except StopIteration:
1771-
pass
1772-
1773-
parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
1774-
1775-
file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
17761764
schema = table_metadata.schema()
17771765
arrow_file_schema = schema.as_arrow()
1766+
parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
17781767

1779-
fo = io.new_output(file_path)
17801768
row_group_size = PropertyUtil.property_as_int(
17811769
properties=table_metadata.properties,
17821770
property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
17831771
default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
17841772
)
1785-
with fo.create(overwrite=True) as fos:
1786-
with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
1787-
writer.write_table(task.df, row_group_size=row_group_size)
1788-
1789-
statistics = data_file_statistics_from_parquet_metadata(
1790-
parquet_metadata=writer.writer.metadata,
1791-
stats_columns=compute_statistics_plan(schema, table_metadata.properties),
1792-
parquet_column_mapping=parquet_path_to_id_mapping(schema),
1793-
)
1794-
data_file = DataFile(
1795-
content=DataFileContent.DATA,
1796-
file_path=file_path,
1797-
file_format=FileFormat.PARQUET,
1798-
partition=Record(),
1799-
file_size_in_bytes=len(fo),
1800-
# After this has been fixed:
1801-
# https://github.com/apache/iceberg-python/issues/271
1802-
# sort_order_id=task.sort_order_id,
1803-
sort_order_id=None,
1804-
# Just copy these from the table for now
1805-
spec_id=table_metadata.default_spec_id,
1806-
equality_ids=None,
1807-
key_metadata=None,
1808-
**statistics.to_serialized_dict(),
1809-
)
18101773

1811-
return iter([data_file])
1774+
def write_parquet(task: WriteTask) -> DataFile:
1775+
file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
1776+
fo = io.new_output(file_path)
1777+
with fo.create(overwrite=True) as fos:
1778+
with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
1779+
writer.write(pa.Table.from_batches(task.record_batches), row_group_size=row_group_size)
1780+
1781+
statistics = data_file_statistics_from_parquet_metadata(
1782+
parquet_metadata=writer.writer.metadata,
1783+
stats_columns=compute_statistics_plan(schema, table_metadata.properties),
1784+
parquet_column_mapping=parquet_path_to_id_mapping(schema),
1785+
)
1786+
data_file = DataFile(
1787+
content=DataFileContent.DATA,
1788+
file_path=file_path,
1789+
file_format=FileFormat.PARQUET,
1790+
partition=Record(),
1791+
file_size_in_bytes=len(fo),
1792+
# After this has been fixed:
1793+
# https://github.com/apache/iceberg-python/issues/271
1794+
# sort_order_id=task.sort_order_id,
1795+
sort_order_id=None,
1796+
# Just copy these from the table for now
1797+
spec_id=table_metadata.default_spec_id,
1798+
equality_ids=None,
1799+
key_metadata=None,
1800+
**statistics.to_serialized_dict(),
1801+
)
1802+
1803+
return data_file
1804+
1805+
executor = ExecutorFactory.get_or_create()
1806+
data_files = executor.map(write_parquet, tasks)
1807+
1808+
return iter(data_files)
1809+
1810+
1811+
def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[pa.RecordBatch]]:
1812+
from pyiceberg.utils.bin_packing import PackingIterator
1813+
1814+
avg_row_size_bytes = tbl.nbytes / tbl.num_rows
1815+
target_rows_per_file = target_file_size // avg_row_size_bytes
1816+
batches = tbl.to_batches(max_chunksize=target_rows_per_file)
1817+
bin_packed_record_batches = PackingIterator(
1818+
items=batches,
1819+
target_weight=target_file_size,
1820+
lookback=len(batches), # ignore lookback
1821+
weight_func=lambda x: x.nbytes,
1822+
largest_bin_first=False,
1823+
)
1824+
return bin_packed_record_batches
18121825

18131826

18141827
def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:

pyiceberg/table/__init__.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,9 @@ class TableProperties:
215215

216216
PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX = "write.parquet.bloom-filter-enabled.column"
217217

218+
WRITE_TARGET_FILE_SIZE_BYTES = "write.target-file-size-bytes"
219+
WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT = 512 * 1024 * 1024 # 512 MB
220+
218221
DEFAULT_WRITE_METRICS_MODE = "write.metadata.metrics.default"
219222
DEFAULT_WRITE_METRICS_MODE_DEFAULT = "truncate(16)"
220223

@@ -2486,7 +2489,7 @@ def _add_and_move_fields(
24862489
class WriteTask:
24872490
write_uuid: uuid.UUID
24882491
task_id: int
2489-
df: pa.Table
2492+
record_batches: List[pa.RecordBatch]
24902493
sort_order_id: Optional[int] = None
24912494

24922495
# Later to be extended with partition information
@@ -2521,17 +2524,27 @@ def _dataframe_to_data_files(
25212524
Returns:
25222525
An iterable that supplies datafiles that represent the table.
25232526
"""
2524-
from pyiceberg.io.pyarrow import write_file
2527+
from pyiceberg.io.pyarrow import bin_pack_arrow_table, write_file
25252528

25262529
if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0:
25272530
raise ValueError("Cannot write to partitioned tables")
25282531

25292532
counter = itertools.count(0)
25302533
write_uuid = write_uuid or uuid.uuid4()
25312534

2535+
target_file_size = PropertyUtil.property_as_int(
2536+
properties=table_metadata.properties,
2537+
property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
2538+
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
2539+
)
2540+
25322541
# This is an iter, so we don't have to materialize everything every time
25332542
# This will be more relevant when we start doing partitioned writes
2534-
yield from write_file(io=io, table_metadata=table_metadata, tasks=iter([WriteTask(write_uuid, next(counter), df)]))
2543+
yield from write_file(
2544+
io=io,
2545+
table_metadata=table_metadata,
2546+
tasks=iter([WriteTask(write_uuid, next(counter), batches) for batches in bin_pack_arrow_table(df, target_file_size)]), # type: ignore
2547+
)
25352548

25362549

25372550
def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]:

tests/conftest.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import socket
3131
import string
3232
import uuid
33-
from datetime import datetime
33+
from datetime import date, datetime
3434
from pathlib import Path
3535
from random import choice
3636
from tempfile import TemporaryDirectory
@@ -1987,3 +1987,60 @@ def spark() -> SparkSession:
19871987
)
19881988

19891989
return spark
1990+
1991+
1992+
TEST_DATA_WITH_NULL = {
1993+
'bool': [False, None, True],
1994+
'string': ['a', None, 'z'],
1995+
# Go over the 16 bytes to kick in truncation
1996+
'string_long': ['a' * 22, None, 'z' * 22],
1997+
'int': [1, None, 9],
1998+
'long': [1, None, 9],
1999+
'float': [0.0, None, 0.9],
2000+
'double': [0.0, None, 0.9],
2001+
'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
2002+
'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
2003+
'date': [date(2023, 1, 1), None, date(2023, 3, 1)],
2004+
# Not supported by Spark
2005+
# 'time': [time(1, 22, 0), None, time(19, 25, 0)],
2006+
# Not natively supported by Arrow
2007+
# 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes],
2008+
'binary': [b'\01', None, b'\22'],
2009+
'fixed': [
2010+
uuid.UUID('00000000-0000-0000-0000-000000000000').bytes,
2011+
None,
2012+
uuid.UUID('11111111-1111-1111-1111-111111111111').bytes,
2013+
],
2014+
}
2015+
2016+
2017+
@pytest.fixture(scope="session")
2018+
def pa_schema() -> "pa.Schema":
2019+
import pyarrow as pa
2020+
2021+
return pa.schema([
2022+
("bool", pa.bool_()),
2023+
("string", pa.string()),
2024+
("string_long", pa.string()),
2025+
("int", pa.int32()),
2026+
("long", pa.int64()),
2027+
("float", pa.float32()),
2028+
("double", pa.float64()),
2029+
("timestamp", pa.timestamp(unit="us")),
2030+
("timestamptz", pa.timestamp(unit="us", tz="UTC")),
2031+
("date", pa.date32()),
2032+
# Not supported by Spark
2033+
# ("time", pa.time64("us")),
2034+
# Not natively supported by Arrow
2035+
# ("uuid", pa.fixed(16)),
2036+
("binary", pa.large_binary()),
2037+
("fixed", pa.binary(16)),
2038+
])
2039+
2040+
2041+
@pytest.fixture(scope="session")
2042+
def arrow_table_with_null(pa_schema: "pa.Schema") -> "pa.Table":
2043+
import pyarrow as pa
2044+
2045+
"""PyArrow table with all kinds of columns"""
2046+
return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema)

tests/integration/test_writes.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from pyiceberg.catalog.sql import SqlCatalog
3838
from pyiceberg.exceptions import NoSuchTableError
3939
from pyiceberg.schema import Schema
40-
from pyiceberg.table import Table, _dataframe_to_data_files
40+
from pyiceberg.table import Table, TableProperties, _dataframe_to_data_files
4141
from pyiceberg.typedef import Properties
4242
from pyiceberg.types import (
4343
BinaryType,
@@ -383,6 +383,47 @@ def get_current_snapshot_id(identifier: str) -> int:
383383
assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore
384384

385385

386+
@pytest.mark.integration
387+
def test_write_bin_pack_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
388+
identifier = "default.write_bin_pack_data_files"
389+
tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])
390+
391+
def get_data_files_count(identifier: str) -> int:
392+
return spark.sql(
393+
f"""
394+
SELECT *
395+
FROM {identifier}.files
396+
"""
397+
).count()
398+
399+
# writes 1 data file since the table is smaller than default target file size
400+
assert arrow_table_with_null.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
401+
tbl.overwrite(arrow_table_with_null)
402+
assert get_data_files_count(identifier) == 1
403+
404+
# writes 1 data file as long as table is smaller than default target file size
405+
bigger_arrow_tbl = pa.concat_tables([arrow_table_with_null] * 10)
406+
assert bigger_arrow_tbl.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
407+
tbl.overwrite(bigger_arrow_tbl)
408+
assert get_data_files_count(identifier) == 1
409+
410+
# writes multiple data files once target file size is overridden
411+
target_file_size = arrow_table_with_null.nbytes
412+
tbl = tbl.transaction().set_properties({TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: target_file_size}).commit_transaction()
413+
assert str(target_file_size) == tbl.properties.get(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES)
414+
assert target_file_size < bigger_arrow_tbl.nbytes
415+
tbl.overwrite(bigger_arrow_tbl)
416+
assert get_data_files_count(identifier) == 10
417+
418+
# writes half the number of data files when target file size doubles
419+
target_file_size = arrow_table_with_null.nbytes * 2
420+
tbl = tbl.transaction().set_properties({TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: target_file_size}).commit_transaction()
421+
assert str(target_file_size) == tbl.properties.get(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES)
422+
assert target_file_size < bigger_arrow_tbl.nbytes
423+
tbl.overwrite(bigger_arrow_tbl)
424+
assert get_data_files_count(identifier) == 5
425+
426+
386427
@pytest.mark.integration
387428
@pytest.mark.parametrize("format_version", [1, 2])
388429
@pytest.mark.parametrize(

tests/io/test_pyarrow.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,15 @@
6464
_ConvertToArrowSchema,
6565
_primitive_to_physical,
6666
_read_deletes,
67+
bin_pack_arrow_table,
6768
expression_to_pyarrow,
6869
project_table,
6970
schema_to_pyarrow,
7071
)
7172
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
7273
from pyiceberg.partitioning import PartitionSpec
7374
from pyiceberg.schema import Schema, make_compatible_name, visit
74-
from pyiceberg.table import FileScanTask, Table
75+
from pyiceberg.table import FileScanTask, Table, TableProperties
7576
from pyiceberg.table.metadata import TableMetadataV2
7677
from pyiceberg.typedef import UTF8
7778
from pyiceberg.types import (
@@ -1710,3 +1711,25 @@ def test_stats_aggregator_update_max(vals: List[Any], primitive_type: PrimitiveT
17101711
stats.update_max(val)
17111712

17121713
assert stats.current_max == expected_result
1714+
1715+
1716+
def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
1717+
# default packs to 1 bin since the table is small
1718+
bin_packed = bin_pack_arrow_table(
1719+
arrow_table_with_null, target_file_size=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
1720+
)
1721+
assert len(list(bin_packed)) == 1
1722+
1723+
# as long as table is smaller than default target size, it should pack to 1 bin
1724+
bigger_arrow_tbl = pa.concat_tables([arrow_table_with_null] * 10)
1725+
assert bigger_arrow_tbl.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
1726+
bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT)
1727+
assert len(list(bin_packed)) == 1
1728+
1729+
# unless we override the target size to be smaller
1730+
bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes)
1731+
assert len(list(bin_packed)) == 10
1732+
1733+
# and will produce half the number of files if we double the target size
1734+
bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes * 2)
1735+
assert len(list(bin_packed)) == 5

0 commit comments

Comments
 (0)