Merged · Changes from 5 commits
30 changes: 29 additions & 1 deletion pyiceberg/partitioning.py
@@ -21,7 +21,7 @@
from dataclasses import dataclass
from datetime import date, datetime, time
from functools import cached_property, singledispatch
from typing import Annotated, Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import Annotated, Any, Dict, Generic, List, Optional, Set, Tuple, TypeVar, Union
from urllib.parse import quote_plus

from pydantic import (
@@ -249,6 +249,31 @@ def partition_to_path(self, data: Record, schema: Schema) -> str:
UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)


def validate_partition_name(
field_name: str,
partition_transform: Transform[Any, Any],
source_id: int,
schema: Schema,
partition_names: Set[str],
) -> None:
"""Validate that a partition field name doesn't conflict with schema field names."""
    if not field_name:
        raise ValueError("Undefined name")
    if field_name in partition_names:
        raise ValueError(f"Partition name has to be unique: {field_name}")

    try:
        schema_field = schema.find_field(field_name)
    except ValueError:
        return  # No conflict if the name does not exist in the schema

    if isinstance(partition_transform, (IdentityTransform, VoidTransform)):
        # Identity and void transforms may reuse a schema field name only when sourced from that same field
        if schema_field.field_id != source_id:
            raise ValueError(f"Cannot create identity partition sourced from different field in schema: {field_name}")
    else:
        raise ValueError(f"Cannot create partition from name that exists in schema: {field_name}")


def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fresh_schema: Schema) -> PartitionSpec:
partition_fields = []
for pos, field in enumerate(spec.fields):
@@ -258,6 +283,9 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fresh_schema: Schema) -> PartitionSpec:
fresh_field = fresh_schema.find_field(original_column_name)
if fresh_field is None:
raise ValueError(f"Could not find field in fresh schema: {original_column_name}")

validate_partition_name(field.name, field.transform, fresh_field.field_id, fresh_schema, set())

partition_fields.append(
PartitionField(
name=field.name,
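A minimal usage sketch of the new helper on its own, assuming validate_partition_name is importable from pyiceberg.partitioning as added above; the schema and names below are illustrative only:

from pyiceberg.partitioning import validate_partition_name
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform, YearTransform
from pyiceberg.types import LongType, NestedField, TimestampType

schema = Schema(
    NestedField(1, "id", LongType(), required=False),
    NestedField(2, "event_ts", TimestampType(), required=False),
)

# Fine: the partition name does not clash with any schema field
validate_partition_name("event_year", YearTransform(), 2, schema, set())

# Fine: an identity partition reusing the name of its own source column
validate_partition_name("id", IdentityTransform(), 1, schema, set())

# Rejected: a non-identity transform may not reuse an existing schema field name
try:
    validate_partition_name("id", YearTransform(), 2, schema, set())
except ValueError as err:
    print(err)  # Cannot create partition from name that exists in schema: id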
7 changes: 7 additions & 0 deletions pyiceberg/table/update/schema.py
@@ -658,6 +658,13 @@ def _apply(self) -> Schema:

# Check the field-ids
new_schema = Schema(*struct.fields)
from pyiceberg.partitioning import validate_partition_name

for spec in self._transaction.table_metadata.partition_specs:
for partition_field in spec.fields:
validate_partition_name(
partition_field.name, partition_field.transform, partition_field.source_id, new_schema, set()
)
field_ids = set()
for name in self._identifier_field_names:
try:
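The user-visible effect of this loop: every existing partition field name is re-validated against the candidate schema, so a schema change that collides with a partition name fails when the update is applied. A rough sketch, assuming a hypothetical table whose spec already has a partition field named event_year (a year transform over event_ts), mirroring the integration test added below:

from pyiceberg.types import StringType

# Rejected by the re-validation above:
# ValueError: Cannot create partition from name that exists in schema: event_year
table.update_schema().add_column("event_year", StringType()).commit()

# Renaming an existing column to a conflicting name is caught by the same path:
# table.update_schema().rename_column("other_field", "event_year").commit()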
29 changes: 14 additions & 15 deletions pyiceberg/table/update/spec.py
@@ -174,26 +174,18 @@ def _commit(self) -> UpdatesAndRequirements:
return updates, requirements

def _apply(self) -> PartitionSpec:
def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None:
try:
field = schema.find_field(name)
except ValueError:
field = None

if source_id is not None and field is not None and field.field_id != source_id:
raise ValueError(f"Cannot create identity partition from a different field in the schema {name}")
elif field is not None and source_id != field.field_id:
raise ValueError(f"Cannot create partition from name that exists in schema {name}")
if not name:
raise ValueError("Undefined name")
if name in partition_names:
raise ValueError(f"Partition name has to be unique: {name}")
def _check_and_add_partition_name(
schema: Schema, name: str, source_id: int, transform: Transform[Any, Any], partition_names: Set[str]
) -> None:
from pyiceberg.partitioning import validate_partition_name

validate_partition_name(name, transform, source_id, schema, partition_names)
partition_names.add(name)

def _add_new_field(
schema: Schema, source_id: int, field_id: int, name: str, transform: Transform[Any, Any], partition_names: Set[str]
) -> PartitionField:
_check_and_add_partition_name(schema, name, source_id, partition_names)
_check_and_add_partition_name(schema, name, source_id, transform, partition_names)
return PartitionField(source_id, field_id, transform, name)

partition_fields = []
@@ -244,6 +236,13 @@ def _add_new_field(
partition_fields.append(new_field)

for added_field in self._adds:
_check_and_add_partition_name(
self._transaction.table_metadata.schema(),
added_field.name,
added_field.source_id,
added_field.transform,
partition_names,
)
Comment on lines +239 to +245

Contributor:
good catch. just to confirm this covers the newly added partition fields?

Contributor Author:
yes, that's correct

new_field = PartitionField(
source_id=added_field.source_id,
field_id=added_field.field_id,
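At the table API level, the same check now guards UpdateSpec.add_field for newly added partition fields (the point confirmed in the review thread above). A rough sketch against a hypothetical table whose schema contains event_ts and another_ts:

from pyiceberg.transforms import YearTransform

# Rejected while applying the spec update, before anything is committed:
# ValueError: Cannot create partition from name that exists in schema: another_ts
table.update_spec().add_field("event_ts", YearTransform(), "another_ts").commit()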
88 changes: 85 additions & 3 deletions tests/integration/test_partition_evolution.py
@@ -20,7 +20,7 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.transforms import (
@@ -63,13 +63,18 @@ def _table_v2(catalog: Catalog) -> Table:
return _create_table_with_schema(catalog, schema_with_timestamp, "2")


def _create_table_with_schema(catalog: Catalog, schema: Schema, format_version: str) -> Table:
def _create_table_with_schema(
catalog: Catalog, schema: Schema, format_version: str, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC
) -> Table:
tbl_name = "default.test_schema_evolution"
try:
catalog.drop_table(tbl_name)
except NoSuchTableError:
pass
return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version})

return catalog.create_table(
identifier=tbl_name, schema=schema, partition_spec=partition_spec, properties={"format-version": format_version}
)


@pytest.mark.integration
@@ -564,3 +569,80 @@ def _validate_new_partition_fields(
assert len(spec.fields) == len(expected_partition_fields)
for i in range(len(spec.fields)):
assert spec.fields[i] == expected_partition_fields[i]


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_partition_schema_field_name_conflict(catalog: Catalog) -> None:
schema = Schema(
NestedField(1, "id", LongType(), required=False),
NestedField(2, "event_ts", TimestampType(), required=False),
NestedField(3, "another_ts", TimestampType(), required=False),
NestedField(4, "str", StringType(), required=False),
)
table = _create_table_with_schema(catalog, schema, "2")

with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: another_ts"):
table.update_spec().add_field("event_ts", YearTransform(), "another_ts").commit()
with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: id"):
table.update_spec().add_field("event_ts", DayTransform(), "id").commit()

with pytest.raises(ValueError, match="Cannot create identity partition sourced from different field in schema: another_ts"):
table.update_spec().add_field("event_ts", IdentityTransform(), "another_ts").commit()
with pytest.raises(ValueError, match="Cannot create identity partition sourced from different field in schema: str"):
table.update_spec().add_field("id", IdentityTransform(), "str").commit()

table.update_spec().add_field("id", IdentityTransform(), "id").commit()
table.update_spec().add_field("event_ts", YearTransform(), "event_year").commit()


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_partition_validation_during_table_creation(catalog: Catalog) -> None:
schema = Schema(
NestedField(1, "id", LongType(), required=False),
NestedField(2, "event_ts", TimestampType(), required=False),
NestedField(3, "another_ts", TimestampType(), required=False),
NestedField(4, "str", StringType(), required=False),
)

partition_spec = PartitionSpec(
PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="another_ts"), spec_id=1
)
with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: another_ts"):
_create_table_with_schema(catalog, schema, "2", partition_spec)

partition_spec = PartitionSpec(
PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=1
)
_create_table_with_schema(catalog, schema, "2", partition_spec)


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_schema_evolution_partition_conflict(catalog: Catalog) -> None:
schema = Schema(
NestedField(1, "id", LongType(), required=False),
NestedField(2, "event_ts", TimestampType(), required=False),
)
partition_spec = PartitionSpec(
PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="event_year"),
PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="first_name"),
PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="id"),
spec_id=1,
)
table = _create_table_with_schema(catalog, schema, "2", partition_spec)

with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: event_year"):
table.update_schema().add_column("event_year", StringType()).commit()
with pytest.raises(ValueError, match="Cannot create identity partition sourced from different field in schema: first_name"):
table.update_schema().add_column("first_name", StringType()).commit()

table.update_schema().add_column("other_field", StringType()).commit()

with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: event_year"):
table.update_schema().rename_column("other_field", "event_year").commit()
with pytest.raises(ValueError, match="Cannot create identity partition sourced from different field in schema: first_name"):
table.update_schema().rename_column("other_field", "first_name").commit()

table.update_schema().rename_column("other_field", "valid_name").commit()
24 changes: 21 additions & 3 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -980,8 +980,16 @@ def test_append_ymd_transform_partitioned(
# Given
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}"
nested_field = TABLE_SCHEMA.find_field(part_col)

if isinstance(transform, YearTransform):
partition_name = f"{part_col}_year"
elif isinstance(transform, MonthTransform):
partition_name = f"{part_col}_month"
elif isinstance(transform, DayTransform):
partition_name = f"{part_col}_day"

partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col)
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name)
)

# When
@@ -1037,8 +1045,18 @@ def test_append_transform_partition_verify_partitions_count(
part_col = "timestamptz"
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
nested_field = table_date_timestamps_schema.find_field(part_col)

if isinstance(transform, YearTransform):
partition_name = f"{part_col}_year"
elif isinstance(transform, MonthTransform):
partition_name = f"{part_col}_month"
elif isinstance(transform, DayTransform):
partition_name = f"{part_col}_day"
elif isinstance(transform, HourTransform):
partition_name = f"{part_col}_hour"

partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name),
)

# When
@@ -1061,7 +1079,7 @@

partitions_table = tbl.inspect.partitions()
assert partitions_table.num_rows == len(expected_partitions)
assert {part[part_col] for part in partitions_table["partition"].to_pylist()} == expected_partitions
assert {part[partition_name] for part in partitions_table["partition"].to_pylist()} == expected_partitions
files_df = spark.sql(
f"""
SELECT *