
Commit 52ff684

V3: Fix invalid downcasting for nanos (apache#2397)
# Rationale for this change

It looks like we always downcast Arrow nanosecond types to microseconds. cc @sungwy @kevinjqliu

## Are these changes tested?

## Are there any user-facing changes?
1 parent 159b2f3 commit 52ff684
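
For context, a minimal PyArrow illustration (not part of this commit) of what an unconditional 'ns' to 'us' downcast does to a nanosecond-precision value:

```python
import pyarrow as pa

# One nanosecond past the epoch: representable at 'ns' precision, but not at 'us'.
ns = pa.array([1], type=pa.timestamp("ns"))

# A forced downcast (safe=False) silently truncates the sub-microsecond part.
us = ns.cast(pa.timestamp("us"), safe=False)

print(ns[0])  # 1970-01-01 00:00:00.000000001
print(us[0])  # 1970-01-01 00:00:00
```

This is the precision loss the change avoids for V3 tables, whose schemas can carry `timestamp_ns` columns.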

2 files changed: 82 additions & 13 deletions


pyiceberg/io/pyarrow.py

Lines changed: 13 additions & 7 deletions
@@ -149,7 +149,7 @@
     visit,
     visit_with_partner,
 )
-from pyiceberg.table import TableProperties
+from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties
 from pyiceberg.table.locations import load_location_provider
 from pyiceberg.table.metadata import TableMetadata
 from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
@@ -1487,17 +1487,20 @@ def _task_to_record_batches(
     name_mapping: Optional[NameMapping] = None,
     partition_spec: Optional[PartitionSpec] = None,
     format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
+    downcast_ns_timestamp_to_us: Optional[bool] = None,
 ) -> Iterator[pa.RecordBatch]:
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with io.new_input(task.file.file_path).open() as fin:
         fragment = arrow_format.make_fragment(fin)
         physical_schema = fragment.physical_schema
-        # In V1 and V2 table formats, we only support Timestamp 'us' in Iceberg Schema
-        # Hence it is reasonable to always cast 'ns' timestamp to 'us' on read.
-        # When V3 support is introduced, we will update `downcast_ns_timestamp_to_us` flag based on
-        # the table format version.
+
+        # For V1 and V2, we only support Timestamp 'us' in the Iceberg schema, so it is reasonable to always cast 'ns' timestamps to 'us' on read.
+        # For V3 this has to be set explicitly to avoid down-casting nanosecond timestamps by default.
+        downcast_ns_timestamp_to_us = (
+            downcast_ns_timestamp_to_us if downcast_ns_timestamp_to_us is not None else format_version <= 2
+        )
         file_schema = pyarrow_to_schema(
-            physical_schema, name_mapping, downcast_ns_timestamp_to_us=True, format_version=format_version
+            physical_schema, name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, format_version=format_version
         )
 
         # Apply column projection rules: https://iceberg.apache.org/spec/#column-projection
@@ -1555,7 +1558,7 @@ def _task_to_record_batches(
                 projected_schema,
                 file_project_schema,
                 current_batch,
-                downcast_ns_timestamp_to_us=True,
+                downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
                 projected_missing_fields=projected_missing_fields,
             )
 
@@ -1586,6 +1589,7 @@ class ArrowScan:
     _bound_row_filter: BooleanExpression
     _case_sensitive: bool
     _limit: Optional[int]
+    _downcast_ns_timestamp_to_us: Optional[bool]
     """Scan the Iceberg Table and create an Arrow construct.
 
     Attributes:
@@ -1612,6 +1616,7 @@ def __init__(
         self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
         self._case_sensitive = case_sensitive
         self._limit = limit
+        self._downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE)
 
     @property
     def _projected_field_ids(self) -> Set[int]:
@@ -1728,6 +1733,7 @@ def _record_batches_from_scan_tasks_and_deletes(
                 self._table_metadata.name_mapping(),
                 self._table_metadata.specs().get(task.file.spec_id),
                 self._table_metadata.format_version,
+                self._downcast_ns_timestamp_to_us,
             )
             for batch in batches:
                 if self._limit is not None:
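
The reading default now hinges on the small rule the diff above adds to `_task_to_record_batches`: an explicit `downcast_ns_timestamp_to_us` argument wins, otherwise V1/V2 tables keep the old always-downcast behaviour and V3 tables preserve nanoseconds. A standalone sketch of that rule (the helper name `resolve_downcast` is illustrative, not part of the codebase):

```python
from typing import Optional

def resolve_downcast(downcast_ns_timestamp_to_us: Optional[bool], format_version: int) -> bool:
    # An explicit caller choice always wins; otherwise only V1/V2 tables
    # downcast 'ns' timestamps to 'us' on read, and V3 keeps nanoseconds.
    return downcast_ns_timestamp_to_us if downcast_ns_timestamp_to_us is not None else format_version <= 2

assert resolve_downcast(None, 1) is True   # V1: default downcast
assert resolve_downcast(None, 2) is True   # V2: default downcast
assert resolve_downcast(None, 3) is False  # V3: keep nanosecond precision
assert resolve_downcast(True, 3) is True   # explicit override still applies
```

`ArrowScan` feeds this argument from `Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE)`, so a user-level configuration flag can still force the downcast on read.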

tests/io/test_pyarrow.py

Lines changed: 69 additions & 6 deletions
@@ -20,7 +20,7 @@
 import tempfile
 import uuid
 import warnings
-from datetime import date
+from datetime import date, datetime, timezone
 from typing import Any, List, Optional
 from unittest.mock import MagicMock, patch
 from uuid import uuid4
@@ -61,6 +61,7 @@
 from pyiceberg.io import S3_RETRY_STRATEGY_IMPL, InputStream, OutputStream, load_file_io
 from pyiceberg.io.pyarrow import (
     ICEBERG_SCHEMA,
+    PYARROW_PARQUET_FIELD_ID_KEY,
     ArrowScan,
     PyArrowFile,
     PyArrowFileIO,
@@ -70,6 +71,7 @@
     _determine_partitions,
     _primitive_to_physical,
     _read_deletes,
+    _task_to_record_batches,
     _to_requested_schema,
     bin_pack_arrow_table,
     compute_statistics_plan,
@@ -85,7 +87,7 @@
 from pyiceberg.table.metadata import TableMetadataV2
 from pyiceberg.table.name_mapping import create_mapping_from_schema
 from pyiceberg.transforms import HourTransform, IdentityTransform
-from pyiceberg.typedef import UTF8, Properties, Record
+from pyiceberg.typedef import UTF8, Properties, Record, TableVersion
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -102,6 +104,7 @@
     PrimitiveType,
     StringType,
     StructType,
+    TimestampNanoType,
     TimestampType,
     TimestamptzType,
     TimeType,
@@ -873,6 +876,18 @@ def _write_table_to_file(filepath: str, schema: pa.Schema, table: pa.Table) -> s
     return filepath
 
 
+def _write_table_to_data_file(filepath: str, schema: pa.Schema, table: pa.Table) -> DataFile:
+    filepath = _write_table_to_file(filepath, schema, table)
+    return DataFile.from_args(
+        content=DataFileContent.DATA,
+        file_path=filepath,
+        file_format=FileFormat.PARQUET,
+        partition={},
+        record_count=len(table),
+        file_size_in_bytes=22,  # This is not relevant for now
+    )
+
+
 @pytest.fixture
 def file_int(schema_int: Schema, tmpdir: str) -> str:
     pyarrow_schema = schema_to_pyarrow(schema_int, metadata={ICEBERG_SCHEMA: bytes(schema_int.model_dump_json(), UTF8)})
@@ -2411,8 +2426,6 @@ def test_partition_for_nested_field() -> None:
 
     spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=HourTransform(), name="ts"))
 
-    from datetime import datetime
-
     t1 = datetime(2025, 7, 11, 9, 30, 0)
     t2 = datetime(2025, 7, 11, 10, 30, 0)
 
@@ -2551,8 +2564,6 @@ def test_initial_value() -> None:
 
 
 def test__to_requested_schema_timestamp_to_timestamptz_projection() -> None:
-    from datetime import datetime, timezone
-
     # file is written with timestamp without timezone
     file_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False))
     batch = pa.record_batch(
@@ -2722,3 +2733,55 @@ def test_retry_strategy_not_found() -> None:
     io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: "pyiceberg.DoesNotExist"})
     with pytest.warns(UserWarning, match="Could not initialize S3 retry strategy: pyiceberg.DoesNotExist"):
         io.new_input("s3://bucket/path/to/file")
+
+
+@pytest.mark.parametrize("format_version", [1, 2, 3])
+def test_task_to_record_batches_nanos(format_version: TableVersion, tmpdir: str) -> None:
+    arrow_table = pa.table(
+        [
+            pa.array(
+                [
+                    datetime(2025, 8, 14, 12, 0, 0),
+                    datetime(2025, 8, 14, 13, 0, 0),
+                ],
+                type=pa.timestamp("ns"),
+            )
+        ],
+        pa.schema((pa.field("ts_field", pa.timestamp("ns"), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}),)),
+    )
+
+    data_file = _write_table_to_data_file(f"{tmpdir}/test_task_to_record_batches_nanos.parquet", arrow_table.schema, arrow_table)
+
+    if format_version <= 2:
+        table_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False))
+    else:
+        table_schema = Schema(NestedField(1, "ts_field", TimestampNanoType(), required=False))
+
+    actual_result = list(
+        _task_to_record_batches(
+            PyArrowFileIO(),
+            FileScanTask(data_file),
+            bound_row_filter=AlwaysTrue(),
+            projected_schema=table_schema,
+            projected_field_ids={1},
+            positional_deletes=None,
+            case_sensitive=True,
+            format_version=format_version,
+        )
+    )[0]
+
+    def _expected_batch(unit: str) -> pa.RecordBatch:
+        return pa.record_batch(
+            [
+                pa.array(
+                    [
+                        datetime(2025, 8, 14, 12, 0, 0),
+                        datetime(2025, 8, 14, 13, 0, 0),
+                    ],
+                    type=pa.timestamp(unit),
+                )
+            ],
+            names=["ts_field"],
+        )
+
+    assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result)
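
One follow-up note on configuration: the new `ArrowScan` field is populated from pyiceberg's `Config`, so the downcast can still be forced globally at read time. A hedged sketch, assuming the usual pyiceberg mapping of the `downcast-ns-timestamp-to-us-on-write` key to a `PYICEBERG_`-prefixed environment variable (the exact variable name is an assumption, not verified against this commit):

```python
import os

# Assumed env-var spelling for the 'downcast-ns-timestamp-to-us-on-write' config key;
# pyiceberg also reads it from ~/.pyiceberg.yaml. Set it before the Config is first read.
os.environ["PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE"] = "true"

from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE
from pyiceberg.utils.config import Config

# ArrowScan passes this value down to _task_to_record_batches as the explicit override.
print(Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE))  # expected: True
```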
