|
16 | 16 | # under the License.
|
17 | 17 | from __future__ import annotations
|
18 | 18 |
|
19 | | -import datetime |
20 | 19 | import itertools
|
21 | 20 | import uuid
|
22 | 21 | import warnings
|
23 | 22 | from abc import ABC, abstractmethod
|
24 | 23 | from copy import copy
|
25 | 24 | from dataclasses import dataclass
|
| 25 | +from datetime import datetime |
26 | 26 | from enum import Enum
|
27 | 27 | from functools import cached_property, singledispatch
|
28 | 28 | from itertools import chain
|
|
79 | 79 | PARTITION_FIELD_ID_START,
|
80 | 80 | UNPARTITIONED_PARTITION_SPEC,
|
81 | 81 | PartitionField,
|
| 82 | + PartitionFieldValue, |
| 83 | + PartitionKey, |
82 | 84 | PartitionSpec,
|
83 | 85 | _PartitionNameGenerator,
|
84 | 86 | _visit_partition_field,
|
@@ -373,8 +375,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
|
373 | 375 | if not isinstance(df, pa.Table):
|
374 | 376 | raise ValueError(f"Expected PyArrow table, got: {df}")
|
375 | 377 |
|
376 | | - if len(self._table.spec().fields) > 0: |
377 | | - raise ValueError("Cannot write to partitioned tables") |
| 378 | + supported_transforms = {IdentityTransform} |
| 379 | + if not all(type(field.transform) in supported_transforms for field in self.table_metadata.spec().fields): |
| 380 | + raise ValueError( |
| 381 | + f"Not all partition transforms are supported, expected: {supported_transforms}, but got: {[str(field) for field in self.table_metadata.spec().fields if type(field.transform) not in supported_transforms]}." |
| 382 | + ) |
378 | 383 |
|
379 | 384 | _check_schema_compatible(self._table.schema(), other_schema=df.schema)
|
380 | 385 | # cast if the two schemas are compatible but not equal
|
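The guard above only admits identity-transformed partition fields when appending. A minimal standalone sketch of the same check, with an illustrative spec (the field ids, names, and the bucket field below are made up, not part of this change):

```python
# Sketch of the transform check performed by append(); the spec is illustrative.
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.transforms import BucketTransform, IdentityTransform

spec = PartitionSpec(
    PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="year"),
    PartitionField(source_id=2, field_id=1001, transform=BucketTransform(num_buckets=16), name="id_bucket"),
)

supported_transforms = {IdentityTransform}
unsupported = [str(f) for f in spec.fields if type(f.transform) not in supported_transforms]
if unsupported:
    # Only the bucket-partitioned field is reported; identity fields pass the check.
    print(f"Unsupported partition transforms: {unsupported}")
```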
@@ -897,7 +902,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl
|
897 | 902 | if update.ref_name == MAIN_BRANCH:
|
898 | 903 | metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
|
899 | 904 | if "last_updated_ms" not in metadata_updates:
|
900 | | - metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.datetime.now().astimezone()) |
| 905 | + metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.now().astimezone()) |
901 | 906 |
|
902 | 907 | metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
|
903 | 908 | SnapshotLogEntry(
|
@@ -2646,16 +2651,23 @@ def _add_and_move_fields(
|
2646 | 2651 | class WriteTask:
|
2647 | 2652 | write_uuid: uuid.UUID
|
2648 | 2653 | task_id: int
|
| 2654 | + schema: Schema |
2649 | 2655 | record_batches: List[pa.RecordBatch]
|
2650 | 2656 | sort_order_id: Optional[int] = None
|
2651 | | - |
2652 | | - # Later to be extended with partition information |
| 2657 | + partition_key: Optional[PartitionKey] = None |
2653 | 2658 |
|
2654 | 2659 | def generate_data_file_filename(self, extension: str) -> str:
|
2655 | 2660 | # Mimics the behavior in the Java API:
|
2656 | 2661 | # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
|
2657 | 2662 | return f"00000-{self.task_id}-{self.write_uuid}.{extension}"
|
2658 | 2663 |
|
| 2664 | + def generate_data_file_path(self, extension: str) -> str: |
| 2665 | + if self.partition_key: |
| 2666 | + file_path = f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}" |
| 2667 | + return file_path |
| 2668 | + else: |
| 2669 | + return self.generate_data_file_filename(extension) |
| 2670 | + |
2659 | 2671 |
|
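A small usage sketch of the extended WriteTask and its new generate_data_file_path, assuming WriteTask from this module is in scope; the schema, record batch, and uuid are toy values, and the partitioned path in the trailing comment assumes the Hive-style field=value segments emitted by PartitionKey.to_path():

```python
# Sketch: file naming for a WriteTask, with and without a partition key.
import uuid

import pyarrow as pa

from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType, NestedField, StringType

schema = Schema(
    NestedField(field_id=1, name="year", field_type=IntegerType(), required=False),
    NestedField(field_id=2, name="animal", field_type=StringType(), required=False),
)
batch = pa.RecordBatch.from_pydict({"year": [2021], "animal": ["Dog"]})

task = WriteTask(write_uuid=uuid.uuid4(), task_id=0, schema=schema, record_batches=[batch])
print(task.generate_data_file_path("parquet"))  # 00000-0-<write_uuid>.parquet
# With partition_key set, the same call prefixes the partition path, e.g.
# year=2021/00000-0-<write_uuid>.parquet
```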
2660 | 2672 | @dataclass(frozen=True)
|
2661 | 2673 | class AddFileTask:
|
@@ -2683,25 +2695,40 @@ def _dataframe_to_data_files(
|
2683 | 2695 | """
|
2684 | 2696 | from pyiceberg.io.pyarrow import bin_pack_arrow_table, write_file
|
2685 | 2697 |
|
2686 | | - if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0: |
2687 | | - raise ValueError("Cannot write to partitioned tables") |
2688 | | - |
2689 | 2698 | counter = itertools.count(0)
|
2690 | 2699 | write_uuid = write_uuid or uuid.uuid4()
|
2691 | | - |
2692 | | - target_file_size = PropertyUtil.property_as_int( |
| 2700 | + target_file_size: int = PropertyUtil.property_as_int( # type: ignore # The default guarantees a non-None value. |
2693 | 2701 | properties=table_metadata.properties,
|
2694 | 2702 | property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
|
2695 | 2703 | default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
|
2696 | 2704 | )
|
2697 | 2705 |
|
2698 | | - # This is an iter, so we don't have to materialize everything every time |
2699 | | - # This will be more relevant when we start doing partitioned writes |
2700 | | - yield from write_file( |
2701 | | - io=io, |
2702 | | - table_metadata=table_metadata, |
2703 | | - tasks=iter([WriteTask(write_uuid, next(counter), batches) for batches in bin_pack_arrow_table(df, target_file_size)]), # type: ignore |
2704 | | - ) |
| 2706 | + if len(table_metadata.spec().fields) > 0: |
| 2707 | + partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df) |
| 2708 | + yield from write_file( |
| 2709 | + io=io, |
| 2710 | + table_metadata=table_metadata, |
| 2711 | + tasks=iter([ |
| 2712 | + WriteTask( |
| 2713 | + write_uuid=write_uuid, |
| 2714 | + task_id=next(counter), |
| 2715 | + record_batches=batches, |
| 2716 | + partition_key=partition.partition_key, |
| 2717 | + schema=table_metadata.schema(), |
| 2718 | + ) |
| 2719 | + for partition in partitions |
| 2720 | + for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size) |
| 2721 | + ]), |
| 2722 | + ) |
| 2723 | + else: |
| 2724 | + yield from write_file( |
| 2725 | + io=io, |
| 2726 | + table_metadata=table_metadata, |
| 2727 | + tasks=iter([ |
| 2728 | + WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema()) |
| 2729 | + for batches in bin_pack_arrow_table(df, target_file_size) |
| 2730 | + ]), |
| 2731 | + ) |
2705 | 2732 |
|
2706 | 2733 |
|
2707 | 2734 | def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]:
|
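With the branching above, appending to an identity-partitioned table writes data per partition value. A usage sketch through the public API; the catalog name and table identifier are placeholders:

```python
# Sketch: appending to an identity-partitioned table; names are placeholders.
import pyarrow as pa

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")
table = catalog.load_table("db.animals")  # assumed to be partitioned by identity("year")

df = pa.table({
    "year": [2020, 2021, 2021],
    "animal": ["Flamingo", "Dog", "Horse"],
})
# Rows are grouped by partition value, and each group is written under a
# year=<value>/ prefix inside the table's data location.
table.append(df)
```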
@@ -3253,7 +3280,7 @@ def snapshots(self) -> "pa.Table":
|
3253 | 3280 | additional_properties = None
|
3254 | 3281 |
|
3255 | 3282 | snapshots.append({
|
3256 | | - 'committed_at': datetime.datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0), |
| 3283 | + 'committed_at': datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0), |
3257 | 3284 | 'snapshot_id': snapshot.snapshot_id,
|
3258 | 3285 | 'parent_id': snapshot.parent_snapshot_id,
|
3259 | 3286 | 'operation': str(operation),
|
@@ -3388,3 +3415,112 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
|
3388 | 3415 | entries,
|
3389 | 3416 | schema=entries_schema,
|
3390 | 3417 | )
|
| 3418 | + |
| 3419 | + |
| 3420 | +@dataclass(frozen=True) |
| 3421 | +class TablePartition: |
| 3422 | + partition_key: PartitionKey |
| 3423 | + arrow_table_partition: pa.Table |
| 3424 | + |
| 3425 | + |
| 3426 | +def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]: |
| 3427 | + order = 'ascending' if not reverse else 'descending' |
| 3428 | + null_placement = 'at_start' if reverse else 'at_end' |
| 3429 | + return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement} |
| 3430 | + |
| 3431 | + |
| 3432 | +def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table: |
| 3433 | + """Given a table, sort it by the current partition scheme.""" |
| 3434 | + # This only works for identity transforms for now. |
| 3435 | + sort_options = _get_partition_sort_order(partition_columns, reverse=False) |
| 3436 | + sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement']) |
| 3437 | + return sorted_arrow_table |
| 3438 | + |
| 3439 | + |
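The two helpers above only build PyArrow sort options and apply them via Table.sort_by. A small sketch of what they produce for two identity partition columns, assuming _get_partition_sort_order is in scope and using toy data:

```python
# Sketch: sort options from _get_partition_sort_order applied with pyarrow's Table.sort_by.
import pyarrow as pa

tbl = pa.table({"n_legs": [4, 2, 4], "year": [2021, 2020, 2022], "animal": ["Dog", "Flamingo", "Horse"]})

opts = _get_partition_sort_order(["n_legs", "year"], reverse=False)
# opts == {'sort_keys': [('n_legs', 'ascending'), ('year', 'ascending')], 'null_placement': 'at_end'}
sorted_tbl = tbl.sort_by(sorting=opts["sort_keys"], null_placement=opts["null_placement"])
# Rows with equal (n_legs, year) values are now adjacent, ready to be sliced into partitions.
```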
| 3440 | +def get_partition_columns( |
| 3441 | + spec: PartitionSpec, |
| 3442 | + schema: Schema, |
| 3443 | +) -> list[str]: |
| 3444 | + partition_cols = [] |
| 3445 | + for partition_field in spec.fields: |
| 3446 | + column_name = schema.find_column_name(partition_field.source_id) |
| 3447 | + if not column_name: |
| 3448 | + raise ValueError(f"{partition_field=} could not be found in {schema}.") |
| 3449 | + partition_cols.append(column_name) |
| 3450 | + return partition_cols |
| 3451 | + |
| 3452 | + |
| 3453 | +def _get_table_partitions( |
| 3454 | + arrow_table: pa.Table, |
| 3455 | + partition_spec: PartitionSpec, |
| 3456 | + schema: Schema, |
| 3457 | + slice_instructions: list[dict[str, Any]], |
| 3458 | +) -> list[TablePartition]: |
| 3459 | + sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x['offset']) |
| 3460 | + |
| 3461 | + partition_fields = partition_spec.fields |
| 3462 | + |
| 3463 | + offsets = [inst["offset"] for inst in sorted_slice_instructions] |
| 3464 | + projected_and_filtered = { |
| 3465 | + partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name] |
| 3466 | + .take(offsets) |
| 3467 | + .to_pylist() |
| 3468 | + for partition_field in partition_fields |
| 3469 | + } |
| 3470 | + |
| 3471 | + table_partitions = [] |
| 3472 | + for idx, inst in enumerate(sorted_slice_instructions): |
| 3473 | + partition_slice = arrow_table.slice(**inst) |
| 3474 | + fieldvalues = [ |
| 3475 | + PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx]) |
| 3476 | + for partition_field in partition_fields |
| 3477 | + ] |
| 3478 | + partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema) |
| 3479 | + table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice)) |
| 3480 | + return table_partitions |
| 3481 | + |
| 3482 | + |
| 3483 | +def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[TablePartition]: |
| 3484 | + """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys. |
| 3485 | + |
| 3486 | + Example: |
| 3487 | + Input: |
| 3488 | + An arrow table with the partition columns ['n_legs', 'year'] and the data |
| 3489 | + {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021], |
| 3490 | + 'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100], |
| 3491 | + 'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"]}. |
| 3492 | + The algorithm: |
| 3493 | + First we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')] |
| 3494 | + and null_placement of "at_end". |
| 3495 | + This gives the same table as the raw input. |
| 3496 | + Then we compute sort_indices using the reverse order [('n_legs', 'descending'), ('year', 'descending')] |
| 3497 | + and null_placement of "at_start". |
| 3498 | + This gives: |
| 3499 | + [8, 7, 4, 5, 6, 3, 1, 2, 0] |
| 3500 | + Based on this we get the partition groups of indices: |
| 3501 | + [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}] |
| 3502 | + We then retrieve the partition keys by the offsets, |
| 3503 | + and slice the arrow table by the offset and length of each partition. |
| 3504 | + """ |
| 3505 | + import pyarrow as pa |
| 3506 | + |
| 3507 | + partition_columns = get_partition_columns(spec=spec, schema=schema) |
| 3508 | + arrow_table = group_by_partition_scheme(arrow_table, partition_columns) |
| 3509 | + |
| 3510 | + reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True) |
| 3511 | + reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist() |
| 3512 | + |
| 3513 | + slice_instructions: list[dict[str, Any]] = [] |
| 3514 | + last = len(reversed_indices) |
| 3515 | + reversed_indices_size = len(reversed_indices) |
| 3516 | + ptr = 0 |
| 3517 | + while ptr < reversed_indices_size: |
| 3518 | + group_size = last - reversed_indices[ptr] |
| 3519 | + offset = reversed_indices[ptr] |
| 3520 | + slice_instructions.append({"offset": offset, "length": group_size}) |
| 3521 | + last = reversed_indices[ptr] |
| 3522 | + ptr = ptr + group_size |
| 3523 | + |
| 3524 | + table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions) |
| 3525 | + |
| 3526 | + return table_partitions |
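A usage sketch of _determine_partitions on the docstring's example data; the schema and spec below are illustrative (identity partitioning on n_legs and year, arbitrary field ids), and the path shape mentioned in the comment assumes the Hive-style rendering of PartitionKey.to_path():

```python
# Sketch: slicing an arrow table into identity partitions; schema and spec are illustrative.
import pyarrow as pa

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import LongType, NestedField, StringType

schema = Schema(
    NestedField(field_id=1, name="year", field_type=LongType(), required=False),
    NestedField(field_id=2, name="n_legs", field_type=LongType(), required=False),
    NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
)
spec = PartitionSpec(
    PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="n_legs"),
    PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year"),
)
arrow_table = pa.table({
    "year": [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
    "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
    "animal": ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"],
})

for partition in _determine_partitions(spec=spec, schema=schema, arrow_table=arrow_table):
    # Each TablePartition pairs a PartitionKey with the matching slice of rows,
    # e.g. the (n_legs=4, year=2022) group holds the three "Horse" rows.
    print(partition.partition_key.to_path(), partition.arrow_table_partition.num_rows)
```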