
Commit 19ad4ab

Author: Tom McCormick (committed)
Basic read/write support for ORC
1 parent 8042d82 commit 19ad4ab

File tree

3 files changed: +165 -2 lines changed

pyiceberg/io/pyarrow.py

Lines changed: 62 additions & 2 deletions
@@ -63,6 +63,7 @@
 import pyarrow.dataset as ds
 import pyarrow.lib
 import pyarrow.parquet as pq
+import pyarrow.orc as orc
 from pyarrow import ChunkedArray
 from pyarrow._s3fs import S3RetryStrategy
 from pyarrow.fs import (
@@ -973,6 +974,8 @@ def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression:
 def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
     if file_format == FileFormat.PARQUET:
         return ds.ParquetFileFormat(**kwargs)
+    elif file_format == FileFormat.ORC:
+        return ds.OrcFileFormat(**kwargs)
     else:
         raise ValueError(f"Unsupported file format: {file_format}")

@@ -1431,7 +1434,13 @@ def _task_to_record_batches(
     name_mapping: Optional[NameMapping] = None,
     partition_spec: Optional[PartitionSpec] = None,
 ) -> Iterator[pa.RecordBatch]:
-    arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
+    if task.file.file_format == FileFormat.PARQUET:
+        arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
+    elif task.file.file_format == FileFormat.ORC:
+        # OrcFileFormat does not take the pre_buffer/buffer_size read options used for Parquet.
+        arrow_format = ds.OrcFileFormat()
+    else:
+        raise ValueError(f"Unsupported file format: {task.file.file_format}")
     with io.new_input(task.file.file_path).open() as fin:
         fragment = arrow_format.make_fragment(fin)
         physical_schema = fragment.physical_schema
@@ -2498,9 +2507,60 @@ def write_parquet(task: WriteTask) -> DataFile:
 
         return data_file
 
+    def write_orc(task: WriteTask) -> DataFile:
+        table_schema = table_metadata.schema()
+        if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
+            file_schema = sanitized_schema
+        else:
+            file_schema = table_schema
+
+        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
+        batches = [
+            _to_requested_schema(
+                requested_schema=file_schema,
+                file_schema=task.schema,
+                batch=batch,
+                downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
+                include_field_ids=True,
+            )
+            for batch in task.record_batches
+        ]
+        arrow_table = pa.Table.from_batches(batches)
+        file_path = location_provider.new_data_location(
+            data_file_name=task.generate_data_file_filename("orc"),
+            partition_key=task.partition_key,
+        )
+        fo = io.new_output(file_path)
+        with fo.create(overwrite=True) as fos:
+            orc.write_table(arrow_table, fos)
+        # Column statistics are not extracted yet; this could later mirror the Parquet write path.
+        data_file = DataFile.from_args(
+            content=DataFileContent.DATA,
+            file_path=file_path,
+            file_format=FileFormat.ORC,
+            partition=task.partition_key.partition if task.partition_key else Record(),
+            file_size_in_bytes=len(fo),
+            sort_order_id=None,
+            spec_id=table_metadata.default_spec_id,
+            equality_ids=None,
+            key_metadata=None,
+            # statistics=...  (once ORC column statistics are implemented)
+        )
+        return data_file
+
     executor = ExecutorFactory.get_or_create()
-    data_files = executor.map(write_parquet, tasks)
+    def dispatch(task: WriteTask) -> DataFile:
+        file_format = FileFormat(table_metadata.properties.get(
+            TableProperties.WRITE_FILE_FORMAT,
+            TableProperties.WRITE_FILE_FORMAT_DEFAULT))
+        if file_format == FileFormat.PARQUET:
+            return write_parquet(task)
+        elif file_format == FileFormat.ORC:
+            return write_orc(task)
+        else:
+            raise ValueError(f"Unsupported file format: {file_format}")
 
+    data_files = executor.map(dispatch, tasks)
     return iter(data_files)
 
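
On the read side, the change to _task_to_record_batches selects ds.OrcFileFormat() for ORC data files and then reuses the existing fragment-based scan. A minimal standalone sketch of that pattern (not part of this commit), using a hypothetical local path and pyarrow.OSFile instead of pyiceberg's FileIO:

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.orc as orc

# Write a small ORC file so there is something to read back (hypothetical path).
table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
orc.write_table(table, "/tmp/example.orc")

# Mirror the read path above: build a fragment from an open file handle
# and stream its contents back as Arrow record batches.
arrow_format = ds.OrcFileFormat()
with pa.OSFile("/tmp/example.orc", "rb") as fin:
    fragment = arrow_format.make_fragment(fin)
    print(fragment.physical_schema)
    for batch in fragment.to_batches():
        print(batch.num_rows)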

pyiceberg/table/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -161,6 +161,9 @@ class UpsertResult:
 
 
 class TableProperties:
+    WRITE_FILE_FORMAT = "write.format.default"
+    WRITE_FILE_FORMAT_DEFAULT = "parquet"
+
     PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
     PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024  # 128 MB
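
With the new WRITE_FILE_FORMAT / WRITE_FILE_FORMAT_DEFAULT properties, a table opts into ORC writes through "write.format.default", which the new dispatch() helper reads at write time. A hedged sketch of the user-facing flow (not part of this commit), assuming a catalog named "local" is already configured and using the standard pyiceberg create/append API:

import pyarrow as pa
from pyiceberg.catalog import load_catalog

# Assumption: a catalog named "local" is configured (e.g. in .pyiceberg.yaml).
catalog = load_catalog("local")

table = catalog.create_table(
    identifier="default.orc_demo",  # hypothetical table name
    schema=pa.schema([("id", pa.int64()), ("name", pa.string())]),
    properties={"write.format.default": "orc"},  # property added in this commit
)

# With this patch applied, appends on such a table are dispatched to
# write_orc() instead of write_parquet().
table.append(pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}))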

tests/io/test_pyarrow.py

Lines changed: 100 additions & 0 deletions
@@ -28,6 +28,7 @@
 import pyarrow
 import pyarrow as pa
 import pyarrow.parquet as pq
+import pyarrow.orc as orc
 import pytest
 from packaging import version
 from pyarrow.fs import AwsDefaultS3RetryStrategy, FileType, LocalFileSystem, S3FileSystem
@@ -2654,3 +2655,102 @@ def test_retry_strategy_not_found() -> None:
     io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: "pyiceberg.DoesNotExist"})
     with pytest.warns(UserWarning, match="Could not initialize S3 retry strategy: pyiceberg.DoesNotExist"):
         io.new_input("s3://bucket/path/to/file")
+
+
+def test_write_and_read_orc(tmp_path):
+    # Create a simple Arrow table
+    data = pa.table({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})
+    orc_path = tmp_path / 'test.orc'
+    orc.write_table(data, str(orc_path))
+    # Read it back
+    orc_file = orc.ORCFile(str(orc_path))
+    table_read = orc_file.read()
+    assert table_read.equals(data)
+
+
+def test_orc_file_format_integration(tmp_path):
+    # This test mimics a minimal integration with PyIceberg's FileFormat enum and pyarrow.orc
+    from pyiceberg.manifest import FileFormat
+    import pyarrow.dataset as ds
+    data = pa.table({'a': [10, 20], 'b': ['foo', 'bar']})
+    orc_path = tmp_path / 'iceberg.orc'
+    orc.write_table(data, str(orc_path))
+    # Use the PyArrow dataset API to read the file as ORC
+    dataset = ds.dataset(str(orc_path), format=ds.OrcFileFormat())
+    table_read = dataset.to_table()
+    assert table_read.equals(data)
+
+
+def test_iceberg_write_and_read_orc(tmp_path):
+    """
+    Integration test: write and read ORC via the Iceberg API.
+    To run just this test:
+        pytest tests/io/test_pyarrow.py -k test_iceberg_write_and_read_orc
+    """
+    import pyarrow as pa
+    from pyiceberg.schema import Schema, NestedField
+    from pyiceberg.types import IntegerType, StringType
+    from pyiceberg.manifest import FileFormat, DataFileContent
+    from pyiceberg.table.metadata import TableMetadataV2
+    from pyiceberg.partitioning import PartitionSpec
+    from pyiceberg.io.pyarrow import write_file, PyArrowFileIO, ArrowScan
+    from pyiceberg.table import WriteTask, FileScanTask
+    import uuid
+
+    # Define schema and data
+    schema = Schema(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(2, "name", StringType(), required=False),
+    )
+    data = pa.table({"id": pa.array([1, 2, 3], type=pa.int32()), "name": ["a", "b", "c"]})
+
+    # Create table metadata
+    table_metadata = TableMetadataV2(
+        location=str(tmp_path),
+        last_column_id=2,
+        format_version=2,
+        schemas=[schema],
+        partition_specs=[PartitionSpec()],
+        properties={
+            "write.format.default": "orc",
+        },
+    )
+    io = PyArrowFileIO()
+
+    # Write an ORC file using the Iceberg API
+    write_uuid = uuid.uuid4()
+    tasks = [
+        WriteTask(
+            write_uuid=write_uuid,
+            task_id=0,
+            record_batches=data.to_batches(),
+            schema=schema,
+        )
+    ]
+    data_files = list(write_file(io, table_metadata, iter(tasks)))
+    assert len(data_files) == 1
+    data_file = data_files[0]
+    assert data_file.file_format == FileFormat.ORC
+    assert data_file.content == DataFileContent.DATA
+
+    # Read back using ArrowScan
+    scan = ArrowScan(
+        table_metadata=table_metadata,
+        io=io,
+        projected_schema=schema,
+        row_filter=AlwaysTrue(),
+        case_sensitive=True,
+    )
+    scan_task = FileScanTask(data_file=data_file)
+    table_read = scan.to_table([scan_task])
+
+    # Compare data, ignoring schema metadata (such as not-null constraints)
+    assert table_read.num_rows == data.num_rows
+    assert table_read.num_columns == data.num_columns
+    assert table_read.column_names == data.column_names
+
+    # Compare actual column data values
+    for col_name in data.column_names:
+        original_values = data.column(col_name).to_pylist()
+        read_values = table_read.column(col_name).to_pylist()
+        assert original_values == read_values, f"Column {col_name} values don't match"
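
Beyond these tests, an ORC data file produced by write_orc() can be inspected directly with pyarrow.orc for a quick manual check. A small sketch (the file path is hypothetical):

import pyarrow.orc as orc

# Hypothetical path to a data file produced by the new write path.
orc_file = orc.ORCFile("/tmp/warehouse/data/00000-0-example.orc")
print(orc_file.schema)  # Arrow schema stored in the ORC file
print(orc_file.nrows)   # total number of rows
print(orc_file.read())  # full contents as a pyarrow.Table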
