Skip to content

Validate partition field names against conflicting schema field names #2305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pyiceberg/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,21 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
fresh_field = fresh_schema.find_field(original_column_name)
if fresh_field is None:
raise ValueError(f"Could not find field in fresh schema: {original_column_name}")

try:
schema_field = fresh_schema.find_field(field.name)
except ValueError:
schema_field = None

if schema_field is not None:
if isinstance(field.transform, (IdentityTransform, VoidTransform)):
# For identity transforms, allow only if sourced from the same field
if schema_field.field_id != fresh_field.field_id:
raise ValueError(f"Cannot create identity partition sourced from different field in schema: {field.name}")
else:
# For non-identity transforms, never allow conflicts
raise ValueError(f"Cannot create partition from name that exists in schema: {field.name}")

partition_fields.append(
PartitionField(
name=field.name,
Expand Down
30 changes: 30 additions & 0 deletions pyiceberg/table/update/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,22 @@ def add_column(
parent_full_path = ".".join(parent)
parent_id: int = TABLE_ROOT_ID

# Check for conflicts with partition field names
if self._transaction is not None:
for spec in self._transaction.table_metadata.partition_specs:
for field in spec.fields:
if field.name == name:
from pyiceberg.transforms import IdentityTransform, VoidTransform

if isinstance(field.transform, (IdentityTransform, VoidTransform)):
# For identity transforms, allow conflict only if partition field sources from a field with same name
source_field = self._schema.find_field(field.source_id)
if source_field is None or source_field.name != name:
raise ValueError(f"Cannot add column with name that conflicts with partition field: {name}")
else:
# For non-identity transforms, never allow conflicts
raise ValueError(f"Cannot add column with name that conflicts with partition field: {name}")

if len(parent) > 0:
parent_field = self._schema.find_field(parent_full_path, self._case_sensitive)
parent_type = parent_field.field_type
Expand Down Expand Up @@ -304,6 +320,20 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -
if field_from.field_id in self._deletes:
raise ValueError(f"Cannot rename a column that will be deleted: {path_from}")

if self._transaction is not None:
for spec in self._transaction.table_metadata.partition_specs:
for field in spec.fields:
if field.name == new_name:
from pyiceberg.transforms import IdentityTransform, VoidTransform

if isinstance(field.transform, (IdentityTransform, VoidTransform)):
# For identity transforms, allow conflict only if partition field sources from the renamed field
if field.source_id != field_from.field_id:
raise ValueError(f"Cannot rename column to name that conflicts with partition field: {new_name}")
else:
# For non-identity transforms, never allow conflicts
raise ValueError(f"Cannot rename column to name that conflicts with partition field: {new_name}")

if updated := self._updates.get(field_from.field_id):
self._updates[field_from.field_id] = NestedField(
field_id=updated.field_id,
Expand Down
24 changes: 18 additions & 6 deletions pyiceberg/table/update/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,21 @@ def _commit(self) -> UpdatesAndRequirements:
return updates, requirements

def _apply(self) -> PartitionSpec:
def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None:
def _check_and_add_partition_name(
schema: Schema, name: str, source_id: int, transform: Transform[Any, Any], partition_names: Set[str]
) -> None:
try:
field = schema.find_field(name)
except ValueError:
field = None

if source_id is not None and field is not None and field.field_id != source_id:
raise ValueError(f"Cannot create identity partition from a different field in the schema {name}")
elif field is not None and source_id != field.field_id:
raise ValueError(f"Cannot create partition from name that exists in schema {name}")
if field is not None:
if isinstance(transform, (IdentityTransform, VoidTransform)):
# For identity transforms allow name conflict only if sourced from the same schema field
if field.field_id != source_id:
raise ValueError(f"Cannot create identity partition from a different field in the schema: {name}")
else:
raise ValueError(f"Cannot create partition from name that exists in schema: {name}")
if not name:
raise ValueError("Undefined name")
if name in partition_names:
Expand All @@ -193,7 +198,7 @@ def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, par
def _add_new_field(
schema: Schema, source_id: int, field_id: int, name: str, transform: Transform[Any, Any], partition_names: Set[str]
) -> PartitionField:
_check_and_add_partition_name(schema, name, source_id, partition_names)
_check_and_add_partition_name(schema, name, source_id, transform, partition_names)
return PartitionField(source_id, field_id, transform, name)

partition_fields = []
Expand Down Expand Up @@ -244,6 +249,13 @@ def _add_new_field(
partition_fields.append(new_field)

for added_field in self._adds:
_check_and_add_partition_name(
self._transaction.table_metadata.schema(),
added_field.name,
added_field.source_id,
added_field.transform,
partition_names,
)
new_field = PartitionField(
source_id=added_field.source_id,
field_id=added_field.field_id,
Expand Down
87 changes: 86 additions & 1 deletion tests/integration/test_partition_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
from typing import Optional

import pytest

Expand Down Expand Up @@ -63,12 +64,19 @@ def _table_v2(catalog: Catalog) -> Table:
return _create_table_with_schema(catalog, schema_with_timestamp, "2")


def _create_table_with_schema(
    catalog: Catalog, schema: Schema, format_version: str, partition_spec: Optional[PartitionSpec] = None
) -> Table:
    """Drop and recreate the shared test table with the given schema and, optionally, a partition spec."""
    identifier = "default.test_schema_evolution"
    # Start from a clean slate; the table may be left over from a previous test run.
    try:
        catalog.drop_table(identifier)
    except NoSuchTableError:
        pass

    create_kwargs = {"identifier": identifier, "schema": schema, "properties": {"format-version": format_version}}
    if partition_spec:
        create_kwargs["partition_spec"] = partition_spec
    return catalog.create_table(**create_kwargs)


Expand Down Expand Up @@ -564,3 +572,80 @@ def _validate_new_partition_fields(
assert len(spec.fields) == len(expected_partition_fields)
for i in range(len(spec.fields)):
assert spec.fields[i] == expected_partition_fields[i]


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_partition_schema_field_name_conflict(catalog: Catalog) -> None:
    """A partition field whose name collides with an existing schema column is rejected,
    unless it is an identity transform sourced from that very column."""
    table = _create_table_with_schema(
        catalog,
        Schema(
            NestedField(1, "id", LongType(), required=False),
            NestedField(2, "event_ts", TimestampType(), required=False),
            NestedField(3, "another_ts", TimestampType(), required=False),
            NestedField(4, "str", StringType(), required=False),
        ),
        "2",
    )

    rejected_cases = [
        ("event_ts", YearTransform(), "another_ts", "Cannot create partition from name that exists in schema: another_ts"),
        ("event_ts", DayTransform(), "id", "Cannot create partition from name that exists in schema: id"),
        ("event_ts", IdentityTransform(), "another_ts", "Cannot create identity partition from a different field in the schema: another_ts"),
        ("id", IdentityTransform(), "str", "Cannot create identity partition from a different field in the schema: str"),
    ]
    for source_name, transform, partition_name, message in rejected_cases:
        with pytest.raises(ValueError, match=message):
            table.update_spec().add_field(source_name, transform, partition_name).commit()

    # Identity sourced from the same column, or a fresh partition name, are both allowed.
    table.update_spec().add_field("id", IdentityTransform(), "id").commit()
    table.update_spec().add_field("event_ts", YearTransform(), "event_year").commit()


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_partition_validation_during_table_creation(catalog: Catalog) -> None:
    """The partition/schema name-conflict rules also apply when the spec is supplied at create-table time."""
    table_schema = Schema(
        NestedField(1, "id", LongType(), required=False),
        NestedField(2, "event_ts", TimestampType(), required=False),
        NestedField(3, "another_ts", TimestampType(), required=False),
        NestedField(4, "str", StringType(), required=False),
    )

    # A non-identity partition named after an existing column is rejected outright.
    conflicting_spec = PartitionSpec(
        PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="another_ts"), spec_id=1
    )
    with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: another_ts"):
        _create_table_with_schema(catalog, table_schema, "2", conflicting_spec)

    # An identity partition that shares its source column's name is allowed.
    identity_spec = PartitionSpec(
        PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=1
    )
    _create_table_with_schema(catalog, table_schema, "2", identity_spec)


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_schema_evolution_partition_conflict(catalog: Catalog) -> None:
    """Schema evolution (add/rename column) must reject names already taken by partition fields,
    except an identity partition sourced from the column that carries the name."""
    table = _create_table_with_schema(
        catalog,
        Schema(
            NestedField(1, "id", LongType(), required=False),
            NestedField(2, "event_ts", TimestampType(), required=False),
        ),
        "2",
        PartitionSpec(
            PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="event_year"),
            PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="first_name"),
            PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="id"),
            spec_id=1,
        ),
    )

    for taken_name in ("event_year", "first_name"):
        with pytest.raises(ValueError, match=f"Cannot add column with name that conflicts with partition field: {taken_name}"):
            table.update_schema().add_column(taken_name, StringType()).commit()

    # A non-conflicting column name is still accepted.
    table.update_schema().add_column("other_field", StringType()).commit()

    for taken_name in ("event_year", "first_name"):
        with pytest.raises(ValueError, match=f"Cannot rename column to name that conflicts with partition field: {taken_name}"):
            table.update_schema().rename_column("other_field", taken_name).commit()

    table.update_schema().rename_column("other_field", "valid_name").commit()
24 changes: 21 additions & 3 deletions tests/integration/test_writes/test_partitioned_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,16 @@ def test_append_ymd_transform_partitioned(
# Given
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}"
nested_field = TABLE_SCHEMA.find_field(part_col)

if isinstance(transform, YearTransform):
partition_name = f"{part_col}_year"
elif isinstance(transform, MonthTransform):
partition_name = f"{part_col}_month"
elif isinstance(transform, DayTransform):
partition_name = f"{part_col}_day"

partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col)
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name)
)

# When
Expand Down Expand Up @@ -1037,8 +1045,18 @@ def test_append_transform_partition_verify_partitions_count(
part_col = "timestamptz"
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
nested_field = table_date_timestamps_schema.find_field(part_col)

if isinstance(transform, YearTransform):
partition_name = f"{part_col}_year"
elif isinstance(transform, MonthTransform):
partition_name = f"{part_col}_month"
elif isinstance(transform, DayTransform):
partition_name = f"{part_col}_day"
elif isinstance(transform, HourTransform):
partition_name = f"{part_col}_hour"

partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name),
)

# When
Expand All @@ -1061,7 +1079,7 @@ def test_append_transform_partition_verify_partitions_count(

partitions_table = tbl.inspect.partitions()
assert partitions_table.num_rows == len(expected_partitions)
assert {part[part_col] for part in partitions_table["partition"].to_pylist()} == expected_partitions
assert {part[partition_name] for part in partitions_table["partition"].to_pylist()} == expected_partitions
files_df = spark.sql(
f"""
SELECT *
Expand Down