Commit c39a94f

Author: Tom McCormick (committed)

Basic read/write support for ORC

1 parent: ad8263b

File tree: 3 files changed (+165, -2 lines)

pyiceberg/io/pyarrow.py

Lines changed: 62 additions & 2 deletions
@@ -63,6 +63,7 @@
 import pyarrow.dataset as ds
 import pyarrow.lib
 import pyarrow.parquet as pq
+import pyarrow.orc as orc
 from pyarrow import ChunkedArray
 from pyarrow._s3fs import S3RetryStrategy
 from pyarrow.fs import (

@@ -974,6 +975,8 @@ def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expressi
 def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
     if file_format == FileFormat.PARQUET:
         return ds.ParquetFileFormat(**kwargs)
+    elif file_format == FileFormat.ORC:
+        return ds.OrcFileFormat(**kwargs)
     else:
         raise ValueError(f"Unsupported file format: {file_format}")

@@ -1450,7 +1453,13 @@ def _task_to_record_batches(
     name_mapping: Optional[NameMapping] = None,
     partition_spec: Optional[PartitionSpec] = None,
 ) -> Iterator[pa.RecordBatch]:
-    arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
+    if task.file.file_format == FileFormat.PARQUET:
+        arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
+    elif task.file.file_format == FileFormat.ORC:
+        arrow_format = ds.OrcFileFormat()
+        # arrow_format = ds.OrcFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
+    else:
+        raise ValueError("Unsupported file format")
     with io.new_input(task.file.file_path).open() as fin:
         fragment = arrow_format.make_fragment(fin)
         physical_schema = fragment.physical_schema

@@ -2512,9 +2521,60 @@ def write_parquet(task: WriteTask) -> DataFile:

         return data_file

+    def write_orc(task: WriteTask) -> DataFile:
+        table_schema = table_metadata.schema()
+        if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
+            file_schema = sanitized_schema
+        else:
+            file_schema = table_schema
+
+        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
+        batches = [
+            _to_requested_schema(
+                requested_schema=file_schema,
+                file_schema=task.schema,
+                batch=batch,
+                downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
+                include_field_ids=True,
+            )
+            for batch in task.record_batches
+        ]
+        arrow_table = pa.Table.from_batches(batches)
+        file_path = location_provider.new_data_location(
+            data_file_name=task.generate_data_file_filename("orc"),
+            partition_key=task.partition_key,
+        )
+        fo = io.new_output(file_path)
+        with fo.create(overwrite=True) as fos:
+            orc.write_table(arrow_table, fos)
+        # You may want to add statistics extraction here if needed
+        data_file = DataFile.from_args(
+            content=DataFileContent.DATA,
+            file_path=file_path,
+            file_format=FileFormat.ORC,
+            partition=task.partition_key.partition if task.partition_key else Record(),
+            file_size_in_bytes=len(fo),
+            sort_order_id=None,
+            spec_id=table_metadata.default_spec_id,
+            equality_ids=None,
+            key_metadata=None,
+            # statistics=... (if you implement ORC stats)
+        )
+        return data_file
+
     executor = ExecutorFactory.get_or_create()
-    data_files = executor.map(write_parquet, tasks)
+    def dispatch(task: WriteTask) -> DataFile:
+        file_format = FileFormat(table_metadata.properties.get(
+            TableProperties.WRITE_FILE_FORMAT,
+            TableProperties.WRITE_FILE_FORMAT_DEFAULT))
+        if file_format == FileFormat.PARQUET:
+            return write_parquet(task)
+        elif file_format == FileFormat.ORC:
+            return write_orc(task)
+        else:
+            raise ValueError(f"Unsupported file format: {file_format}")

+    data_files = executor.map(dispatch, tasks)
     return iter(data_files)
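For context on the read path above: _task_to_record_batches now picks the PyArrow dataset format from the data file's metadata instead of assuming Parquet. Below is a minimal, standalone sketch of the same PyArrow primitives (ds.OrcFileFormat, make_fragment) used outside of Iceberg; the /tmp/example.orc path and the toy table are illustrative assumptions, not part of this commit.

# Sketch only: write a small ORC file, then read it back through the same
# dataset-fragment API that _task_to_record_batches uses for ORC data files.
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.orc as orc

orc.write_table(pa.table({"id": [1, 2, 3]}), "/tmp/example.orc")  # hypothetical path

arrow_format = ds.OrcFileFormat()
with pa.memory_map("/tmp/example.orc", "r") as source:
    fragment = arrow_format.make_fragment(source)
    print(fragment.physical_schema)   # Arrow schema recovered from the ORC footer
    for batch in fragment.to_batches():
        print(batch.num_rows)

Note that ds.OrcFileFormat appears to take no constructor options comparable to Parquet's pre_buffer/buffer_size, which is presumably why the equivalent line is left commented out in the hunk above.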
pyiceberg/table/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -161,6 +161,9 @@ class UpsertResult:


 class TableProperties:
+    WRITE_FILE_FORMAT = "write.format.default"
+    WRITE_FILE_FORMAT_DEFAULT = "parquet"
+
     PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
     PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024  # 128 MB
tests/io/test_pyarrow.py

Lines changed: 100 additions & 0 deletions
@@ -28,6 +28,7 @@
 import pyarrow
 import pyarrow as pa
 import pyarrow.parquet as pq
+import pyarrow.orc as orc
 import pytest
 from packaging import version
 from pyarrow.fs import AwsDefaultS3RetryStrategy, FileType, LocalFileSystem, S3FileSystem

@@ -2638,3 +2639,102 @@ def test_retry_strategy_not_found() -> None:
     io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: "pyiceberg.DoesNotExist"})
     with pytest.warns(UserWarning, match="Could not initialize S3 retry strategy: pyiceberg.DoesNotExist"):
         io.new_input("s3://bucket/path/to/file")
+
+
+def test_write_and_read_orc(tmp_path):
+    # Create a simple Arrow table
+    data = pa.table({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})
+    orc_path = tmp_path / 'test.orc'
+    orc.write_table(data, str(orc_path))
+    # Read it back
+    orc_file = orc.ORCFile(str(orc_path))
+    table_read = orc_file.read()
+    assert table_read.equals(data)
+
+
+def test_orc_file_format_integration(tmp_path):
+    # This test mimics a minimal integration with PyIceberg's FileFormat enum and pyarrow.orc
+    from pyiceberg.manifest import FileFormat
+    import pyarrow.dataset as ds
+    data = pa.table({'a': [10, 20], 'b': ['foo', 'bar']})
+    orc_path = tmp_path / 'iceberg.orc'
+    orc.write_table(data, str(orc_path))
+    # Use PyArrow dataset API to read as ORC
+    dataset = ds.dataset(str(orc_path), format=ds.OrcFileFormat())
+    table_read = dataset.to_table()
+    assert table_read.equals(data)
+
+
+def test_iceberg_write_and_read_orc(tmp_path):
+    """
+    Integration test: Write and read ORC via Iceberg API.
+    To run just this test:
+        pytest tests/io/test_pyarrow.py -k test_iceberg_write_and_read_orc
+    """
+    import pyarrow as pa
+    from pyiceberg.schema import Schema, NestedField
+    from pyiceberg.types import IntegerType, StringType
+    from pyiceberg.manifest import FileFormat, DataFileContent
+    from pyiceberg.table.metadata import TableMetadataV2
+    from pyiceberg.partitioning import PartitionSpec
+    from pyiceberg.io.pyarrow import write_file, PyArrowFileIO, ArrowScan
+    from pyiceberg.table import WriteTask, FileScanTask
+    import uuid
+
+    # Define schema and data
+    schema = Schema(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(2, "name", StringType(), required=False),
+    )
+    data = pa.table({"id": pa.array([1, 2, 3], type=pa.int32()), "name": ["a", "b", "c"]})
+
+    # Create table metadata
+    table_metadata = TableMetadataV2(
+        location=str(tmp_path),
+        last_column_id=2,
+        format_version=2,
+        schemas=[schema],
+        partition_specs=[PartitionSpec()],
+        properties={
+            "write.format.default": "orc",
+        }
+    )
+    io = PyArrowFileIO()
+
+    # Write ORC file using Iceberg API
+    write_uuid = uuid.uuid4()
+    tasks = [
+        WriteTask(
+            write_uuid=write_uuid,
+            task_id=0,
+            record_batches=data.to_batches(),
+            schema=schema,
+        )
+    ]
+    data_files = list(write_file(io, table_metadata, iter(tasks)))
+    assert len(data_files) == 1
+    data_file = data_files[0]
+    assert data_file.file_format == FileFormat.ORC
+    assert data_file.content == DataFileContent.DATA
+
+    # Read back using ArrowScan
+    scan = ArrowScan(
+        table_metadata=table_metadata,
+        io=io,
+        projected_schema=schema,
+        row_filter=AlwaysTrue(),
+        case_sensitive=True,
+    )
+    scan_task = FileScanTask(data_file=data_file)
+    table_read = scan.to_table([scan_task])
+
+    # Compare data ignoring schema metadata (like not null constraints)
+    assert table_read.num_rows == data.num_rows
+    assert table_read.num_columns == data.num_columns
+    assert table_read.column_names == data.column_names
+
+    # Compare actual column data values
+    for col_name in data.column_names:
+        original_values = data.column(col_name).to_pylist()
+        read_values = table_read.column(col_name).to_pylist()
+        assert original_values == read_values, f"Column {col_name} values don't match"
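One caveat about the new tests: pyarrow.orc is now imported at module level in test_pyarrow.py, so pyarrow builds that lack ORC support will fail at import time rather than skipping. A possible guard, not part of this commit, would be to resolve the module via pytest:

# Sketch only: skip the ORC tests instead of erroring when pyarrow has no ORC bindings.
import pytest

orc = pytest.importorskip("pyarrow.orc")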
