
Commit 4148edb

Partitioned Append on Identity Transform (#555)
* partitioned append on identity transform
* remove unnecessary fixture
* added null/empty table tests; fixed part of PR comments
* tests for unsupported transforms; unit tests for partition slicing algorithm
* add a comprehensive partition unit test
* clean up
* move common fixtures utils to utils.py and conftest
* pull partitioned table fixtures into tests for more real-time feedback of running tests
* fix linting
* license
* save changes for switching codespaces
* part of the comment fixes
* fix one type error
* add support for timetype
* small fix for type hint
1 parent ee4dd92 commit 4148edb

11 files changed: +763 −146 lines

pyiceberg/io/pyarrow.py

Lines changed: 2 additions & 2 deletions

@@ -1772,7 +1772,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
     )

     def write_parquet(task: WriteTask) -> DataFile:
-        file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
+        file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
             with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
@@ -1787,7 +1787,7 @@ def write_parquet(task: WriteTask) -> DataFile:
             content=DataFileContent.DATA,
             file_path=file_path,
             file_format=FileFormat.PARQUET,
-            partition=Record(),
+            partition=task.partition_key.partition if task.partition_key else Record(),
             file_size_in_bytes=len(fo),
             # After this has been fixed:
             # https://github.com/apache/iceberg-python/issues/271
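
In effect, write_parquet now nests each data file under its partition directory and stamps the DataFile with the task's partition Record. A minimal sketch of the resulting path layout, using standalone helpers rather than pyiceberg's actual classes (the table location and partition values below are made up):

import uuid
from typing import Optional

def data_file_filename(task_id: int, write_uuid: uuid.UUID, extension: str) -> str:
    # Same naming scheme as WriteTask.generate_data_file_filename
    return f"00000-{task_id}-{write_uuid}.{extension}"

def data_file_path(location: str, partition_path: Optional[str], task_id: int, write_uuid: uuid.UUID) -> str:
    filename = data_file_filename(task_id, write_uuid, "parquet")
    if partition_path:  # e.g. "year=2022/n_legs=4" for an identity-transformed key
        return f"{location}/data/{partition_path}/{filename}"
    return f"{location}/data/{filename}"

print(data_file_path("s3://bucket/warehouse/animals", "year=2022/n_legs=4", 0, uuid.uuid4()))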

pyiceberg/manifest.py

Lines changed: 1 addition & 26 deletions

@@ -19,7 +19,6 @@
 import math
 from abc import ABC, abstractmethod
 from enum import Enum
-from functools import singledispatch
 from types import TracebackType
 from typing import (
     Any,
@@ -41,8 +40,6 @@
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
-    DateType,
-    IcebergType,
     IntegerType,
     ListType,
     LongType,
@@ -51,9 +48,6 @@
     PrimitiveType,
     StringType,
     StructType,
-    TimestampType,
-    TimestamptzType,
-    TimeType,
 )

 UNASSIGNED_SEQ = -1
@@ -283,31 +277,12 @@ def __repr__(self) -> str:
 }


-@singledispatch
-def partition_field_to_data_file_partition_field(partition_field_type: IcebergType) -> PrimitiveType:
-    raise TypeError(f"Unsupported partition field type: {partition_field_type}")
-
-
-@partition_field_to_data_file_partition_field.register(LongType)
-@partition_field_to_data_file_partition_field.register(DateType)
-@partition_field_to_data_file_partition_field.register(TimeType)
-@partition_field_to_data_file_partition_field.register(TimestampType)
-@partition_field_to_data_file_partition_field.register(TimestamptzType)
-def _(partition_field_type: PrimitiveType) -> IntegerType:
-    return IntegerType()
-
-
-@partition_field_to_data_file_partition_field.register(PrimitiveType)
-def _(partition_field_type: PrimitiveType) -> PrimitiveType:
-    return partition_field_type
-
-
 def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType:
     data_file_partition_type = StructType(*[
         NestedField(
             field_id=field.field_id,
             name=field.name,
-            field_type=partition_field_to_data_file_partition_field(field.field_type),
+            field_type=field.field_type,
             required=field.required,
         )
         for field in partition_type.fields
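
The removed singledispatch used to coerce date/time/timestamp partition fields to IntegerType in the DataFile partition struct; the struct now carries the source field type directly. A small illustration of what the struct looks like after this change, assuming a table partitioned by identity on a date column named "dt" (the field ID and name are hypothetical):

from pyiceberg.types import DateType, NestedField, StructType

# The manifest's partition struct keeps DateType itself instead of IntegerType().
partition_type = StructType(
    NestedField(field_id=1000, name="dt", field_type=DateType(), required=False),
)
print(partition_type)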

pyiceberg/partitioning.py

Lines changed: 8 additions & 2 deletions

@@ -19,7 +19,7 @@
 import uuid
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from datetime import date, datetime
+from datetime import date, datetime, time
 from functools import cached_property, singledispatch
 from typing import (
     Any,
@@ -62,9 +62,10 @@
     StructType,
     TimestampType,
     TimestamptzType,
+    TimeType,
     UUIDType,
 )
-from pyiceberg.utils.datetime import date_to_days, datetime_to_micros
+from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros

 INITIAL_PARTITION_SPEC_ID = 0
 PARTITION_FIELD_ID_START: int = 1000
@@ -431,6 +432,11 @@ def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
     return date_to_days(value) if value is not None else None


+@_to_partition_representation.register(TimeType)
+def _(type: IcebergType, value: Optional[time]) -> Optional[int]:
+    return time_to_micros(value) if value is not None else None
+
+
 @_to_partition_representation.register(UUIDType)
 def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
     return str(value) if value is not None else None
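
The new TimeType registration converts a datetime.time partition value to Iceberg's physical representation, microseconds since midnight. A standalone sketch of that conversion (mirroring what time_to_micros is expected to compute, not an import of it):

from datetime import time

def time_to_micros_sketch(t: time) -> int:
    # Microseconds since midnight.
    return ((t.hour * 60 + t.minute) * 60 + t.second) * 1_000_000 + t.microsecond

print(time_to_micros_sketch(time(12, 30, 45, 500)))  # 45045000500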

pyiceberg/table/__init__.py

Lines changed: 155 additions & 19 deletions

@@ -16,13 +16,13 @@
 # under the License.
 from __future__ import annotations

-import datetime
 import itertools
 import uuid
 import warnings
 from abc import ABC, abstractmethod
 from copy import copy
 from dataclasses import dataclass
+from datetime import datetime
 from enum import Enum
 from functools import cached_property, singledispatch
 from itertools import chain
@@ -79,6 +79,8 @@
     PARTITION_FIELD_ID_START,
     UNPARTITIONED_PARTITION_SPEC,
     PartitionField,
+    PartitionFieldValue,
+    PartitionKey,
     PartitionSpec,
     _PartitionNameGenerator,
     _visit_partition_field,
@@ -373,8 +375,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")

-        if len(self._table.spec().fields) > 0:
-            raise ValueError("Cannot write to partitioned tables")
+        supported_transforms = {IdentityTransform}
+        if not all(type(field.transform) in supported_transforms for field in self.table_metadata.spec().fields):
+            raise ValueError(
+                f"Not all transforms are supported, expected: {supported_transforms}, but got: {[str(field) for field in self.table_metadata.spec().fields if type(field.transform) not in supported_transforms]}."
+            )

         _check_schema_compatible(self._table.schema(), other_schema=df.schema)
         # cast if the two schemas are compatible but not equal
@@ -897,7 +902,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl
     if update.ref_name == MAIN_BRANCH:
         metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
         if "last_updated_ms" not in metadata_updates:
-            metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.datetime.now().astimezone())
+            metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.now().astimezone())

     metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
         SnapshotLogEntry(
@@ -2646,16 +2651,23 @@ def _add_and_move_fields(
 class WriteTask:
     write_uuid: uuid.UUID
     task_id: int
+    schema: Schema
     record_batches: List[pa.RecordBatch]
     sort_order_id: Optional[int] = None
-
-    # Later to be extended with partition information
+    partition_key: Optional[PartitionKey] = None

     def generate_data_file_filename(self, extension: str) -> str:
         # Mimics the behavior in the Java API:
         # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
         return f"00000-{self.task_id}-{self.write_uuid}.{extension}"

+    def generate_data_file_path(self, extension: str) -> str:
+        if self.partition_key:
+            file_path = f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}"
+            return file_path
+        else:
+            return self.generate_data_file_filename(extension)
+

 @dataclass(frozen=True)
 class AddFileTask:
@@ -2683,25 +2695,40 @@ def _dataframe_to_data_files(
     """
    from pyiceberg.io.pyarrow import bin_pack_arrow_table, write_file

-    if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0:
-        raise ValueError("Cannot write to partitioned tables")
-
     counter = itertools.count(0)
     write_uuid = write_uuid or uuid.uuid4()
-
-    target_file_size = PropertyUtil.property_as_int(
+    target_file_size: int = PropertyUtil.property_as_int(  # type: ignore  # The property is set with a non-None value.
         properties=table_metadata.properties,
         property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
         default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
     )

-    # This is an iter, so we don't have to materialize everything every time
-    # This will be more relevant when we start doing partitioned writes
-    yield from write_file(
-        io=io,
-        table_metadata=table_metadata,
-        tasks=iter([WriteTask(write_uuid, next(counter), batches) for batches in bin_pack_arrow_table(df, target_file_size)]),  # type: ignore
-    )
+    if len(table_metadata.spec().fields) > 0:
+        partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
+        yield from write_file(
+            io=io,
+            table_metadata=table_metadata,
+            tasks=iter([
+                WriteTask(
+                    write_uuid=write_uuid,
+                    task_id=next(counter),
+                    record_batches=batches,
+                    partition_key=partition.partition_key,
+                    schema=table_metadata.schema(),
+                )
+                for partition in partitions
+                for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
+            ]),
+        )
+    else:
+        yield from write_file(
+            io=io,
+            table_metadata=table_metadata,
+            tasks=iter([
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
+                for batches in bin_pack_arrow_table(df, target_file_size)
+            ]),
+        )


 def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]:
@@ -3253,7 +3280,7 @@ def snapshots(self) -> "pa.Table":
             additional_properties = None

             snapshots.append({
-                'committed_at': datetime.datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0),
+                'committed_at': datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0),
                 'snapshot_id': snapshot.snapshot_id,
                 'parent_id': snapshot.parent_snapshot_id,
                 'operation': str(operation),
@@ -3388,3 +3415,112 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
         entries,
         schema=entries_schema,
     )
+
+
+@dataclass(frozen=True)
+class TablePartition:
+    partition_key: PartitionKey
+    arrow_table_partition: pa.Table
+
+
+def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
+    order = 'ascending' if not reverse else 'descending'
+    null_placement = 'at_start' if reverse else 'at_end'
+    return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement}
+
+
+def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table:
+    """Given a table, sort it by the current partition scheme."""
+    # only works for identity for now
+    sort_options = _get_partition_sort_order(partition_columns, reverse=False)
+    sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement'])
+    return sorted_arrow_table
+
+
+def get_partition_columns(
+    spec: PartitionSpec,
+    schema: Schema,
+) -> list[str]:
+    partition_cols = []
+    for partition_field in spec.fields:
+        column_name = schema.find_column_name(partition_field.source_id)
+        if not column_name:
+            raise ValueError(f"{partition_field=} could not be found in {schema}.")
+        partition_cols.append(column_name)
+    return partition_cols
+
+
+def _get_table_partitions(
+    arrow_table: pa.Table,
+    partition_spec: PartitionSpec,
+    schema: Schema,
+    slice_instructions: list[dict[str, Any]],
+) -> list[TablePartition]:
+    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x['offset'])
+
+    partition_fields = partition_spec.fields
+
+    offsets = [inst["offset"] for inst in sorted_slice_instructions]
+    projected_and_filtered = {
+        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
+        .take(offsets)
+        .to_pylist()
+        for partition_field in partition_fields
+    }
+
+    table_partitions = []
+    for idx, inst in enumerate(sorted_slice_instructions):
+        partition_slice = arrow_table.slice(**inst)
+        fieldvalues = [
+            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
+            for partition_field in partition_fields
+        ]
+        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
+        table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
+    return table_partitions
+
+
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[TablePartition]:
+    """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
+
+    Example:
+        Input:
+        An arrow table with partition key of ['n_legs', 'year'] and with data of
+        {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
+         'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+         'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"]}.
+        The algorithm:
+        First we group the rows into partitions by sorting with sort order [('n_legs', 'ascending'), ('year', 'ascending')]
+        and null_placement of "at_end".
+        This gives the same table as the raw input.
+        Then we compute sort_indices in the reverse order, [('n_legs', 'descending'), ('year', 'descending')],
+        with null_placement of "at_start".
+        This gives:
+        [8, 7, 4, 5, 6, 3, 1, 2, 0]
+        Based on this we get partition groups of indices:
+        [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
+        We then retrieve the partition keys by offsets.
+        And slice the arrow table by offsets and lengths of each partition.
+    """
+    import pyarrow as pa
+
+    partition_columns = get_partition_columns(spec=spec, schema=schema)
+    arrow_table = group_by_partition_scheme(arrow_table, partition_columns)
+
+    reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True)
+    reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()

+    slice_instructions: list[dict[str, Any]] = []
+    last = len(reversed_indices)
+    reversed_indices_size = len(reversed_indices)
+    ptr = 0
+    while ptr < reversed_indices_size:
+        group_size = last - reversed_indices[ptr]
+        offset = reversed_indices[ptr]
+        slice_instructions.append({"offset": offset, "length": group_size})
+        last = reversed_indices[ptr]
+        ptr = ptr + group_size
+
+    table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+
+    return table_partitions
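
The docstring above describes the grouping-by-sort trick; the following standalone pyarrow sketch replays it on the same example data, assuming identity partitioning on ['n_legs', 'year'] (it re-implements the loop rather than importing anything from pyiceberg), so the offset/length bookkeeping can be verified by running it:

import pyarrow as pa
import pyarrow.compute as pc

arrow_table = pa.table({
    "year": [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
    "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
    "animal": ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"],
})
partition_columns = ["n_legs", "year"]

# 1. Group rows so equal partition values sit next to each other (ascending, nulls last).
grouping_keys = [(col, "ascending") for col in partition_columns]
arrow_table = arrow_table.take(pc.sort_indices(arrow_table, sort_keys=grouping_keys, null_placement="at_end"))

# 2. Sorting the grouped table in the opposite direction places, at each position, the row index
#    where a partition group starts; walking those indices yields offset/length instructions.
reversed_keys = [(col, "descending") for col in partition_columns]
reversed_indices = pc.sort_indices(arrow_table, sort_keys=reversed_keys, null_placement="at_start").to_pylist()

slice_instructions = []
last = len(reversed_indices)
ptr = 0
while ptr < len(reversed_indices):
    offset = reversed_indices[ptr]
    length = last - offset
    slice_instructions.append({"offset": offset, "length": length})
    last = offset
    ptr += length

print(slice_instructions)
# [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3},
#  {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]

# 3. Each instruction becomes one partition slice (and, in pyiceberg, one PartitionKey).
partition_slices = [arrow_table.slice(**inst) for inst in slice_instructions]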

pyiceberg/typedef.py

Lines changed: 4 additions & 0 deletions

@@ -202,5 +202,9 @@ def record_fields(self) -> List[str]:
         """Return values of all the fields of the Record class except those specified in skip_fields."""
         return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name]

+    def __hash__(self) -> int:
+        """Return hash value of the Record class."""
+        return hash(str(self))
+

 TableVersion: TypeAlias = Literal[1, 2]
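
Making Record hashable lets partitioned writes use a partition Record as a dictionary or set key, for example to group data files by partition value. A small sketch of that usage with a stand-in class (not pyiceberg's Record) that hashes its string form the same way:

from typing import Dict, List

class RecordSketch:
    """Stand-in for pyiceberg.typedef.Record, reduced to what the example needs."""

    def __init__(self, **fields: object) -> None:
        self.fields = fields

    def __str__(self) -> str:
        return "Record[" + ", ".join(f"{k}={v}" for k, v in self.fields.items()) + "]"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, RecordSketch) and str(self) == str(other)

    def __hash__(self) -> int:
        # Same strategy as the change above: hash the string representation.
        return hash(str(self))

files_per_partition: Dict[RecordSketch, List[str]] = {}
files_per_partition.setdefault(RecordSketch(n_legs=4, year=2022), []).append("00000-0-example.parquet")
print(files_per_partition)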
