Skip to content

Commit b1f6d5d

Browse files
committed
partition-schema name conflict validation during table creation and schema update
1 parent 92a29e8 commit b1f6d5d

File tree

4 files changed

+107
-4
lines changed

4 files changed

+107
-4
lines changed

pyiceberg/partitioning.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,21 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
258258
fresh_field = fresh_schema.find_field(original_column_name)
259259
if fresh_field is None:
260260
raise ValueError(f"Could not find field in fresh schema: {original_column_name}")
261+
262+
try:
263+
schema_field = fresh_schema.find_field(field.name)
264+
except ValueError:
265+
schema_field = None
266+
267+
if schema_field is not None:
268+
if isinstance(field.transform, (IdentityTransform, VoidTransform)):
269+
# For identity transforms, allow only if sourced from the same field
270+
if schema_field.field_id != fresh_field.field_id:
271+
raise ValueError(f"Cannot create identity partition sourced from different field in schema: {field.name}")
272+
else:
273+
# For non-identity transforms, never allow conflicts
274+
raise ValueError(f"Cannot create partition from name that exists in schema: {field.name}")
275+
261276
partition_fields.append(
262277
PartitionField(
263278
name=field.name,

pyiceberg/table/update/schema.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,22 @@ def add_column(
192192
parent_full_path = ".".join(parent)
193193
parent_id: int = TABLE_ROOT_ID
194194

195+
# Check for conflicts with partition field names
196+
if len(parent) == 0 and self._transaction is not None:
197+
for spec in self._transaction.table_metadata.partition_specs:
198+
for field in spec.fields:
199+
if field.name == name:
200+
from pyiceberg.transforms import IdentityTransform, VoidTransform
201+
202+
if isinstance(field.transform, (IdentityTransform, VoidTransform)):
203+
# For identity transforms, allow conflict only if partition field sources from a field with same name
204+
source_field = self._schema.find_field(field.source_id)
205+
if source_field is None or source_field.name != name:
206+
raise ValueError(f"Cannot add column with name that conflicts with partition field: {name}")
207+
else:
208+
# For non-identity transforms, never allow conflicts
209+
raise ValueError(f"Cannot add column with name that conflicts with partition field: {name}")
210+
195211
if len(parent) > 0:
196212
parent_field = self._schema.find_field(parent_full_path, self._case_sensitive)
197213
parent_type = parent_field.field_type

tests/integration/test_partition_evolution.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint:disable=redefined-outer-name
18+
from typing import Optional
1819

1920
import pytest
2021

@@ -63,12 +64,19 @@ def _table_v2(catalog: Catalog) -> Table:
6364
return _create_table_with_schema(catalog, schema_with_timestamp, "2")
6465

6566

66-
def _create_table_with_schema(catalog: Catalog, schema: Schema, format_version: str) -> Table:
67+
def _create_table_with_schema(
68+
catalog: Catalog, schema: Schema, format_version: str, partition_spec: Optional[PartitionSpec] = None
69+
) -> Table:
6770
tbl_name = "default.test_schema_evolution"
6871
try:
6972
catalog.drop_table(tbl_name)
7073
except NoSuchTableError:
7174
pass
75+
76+
if partition_spec:
77+
return catalog.create_table(
78+
identifier=tbl_name, schema=schema, partition_spec=partition_spec, properties={"format-version": format_version}
79+
)
7280
return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version})
7381

7482

@@ -589,3 +597,49 @@ def test_partition_schema_field_name_conflict(catalog: Catalog) -> None:
589597

590598
table.update_spec().add_field("id", IdentityTransform(), "id").commit()
591599
table.update_spec().add_field("event_ts", YearTransform(), "event_year").commit()
600+
601+
602+
@pytest.mark.integration
603+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
604+
def test_partition_validation_during_table_creation(catalog: Catalog) -> None:
605+
schema = Schema(
606+
NestedField(1, "id", LongType(), required=False),
607+
NestedField(2, "event_ts", TimestampType(), required=False),
608+
NestedField(3, "another_ts", TimestampType(), required=False),
609+
NestedField(4, "str", StringType(), required=False),
610+
)
611+
612+
partition_spec = PartitionSpec(
613+
PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="another_ts"), spec_id=1
614+
)
615+
with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: another_ts"):
616+
_create_table_with_schema(catalog, schema, "2", partition_spec)
617+
618+
partition_spec = PartitionSpec(
619+
PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=1
620+
)
621+
_create_table_with_schema(catalog, schema, "2", partition_spec)
622+
623+
624+
@pytest.mark.integration
625+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
626+
def test_schema_evolution_partition_conflict(catalog: Catalog) -> None:
627+
schema = Schema(
628+
NestedField(1, "id", LongType(), required=False),
629+
NestedField(2, "event_ts", TimestampType(), required=False),
630+
)
631+
partition_spec = PartitionSpec(
632+
PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="event_year"),
633+
PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="first_name"),
634+
PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="id"),
635+
spec_id=1,
636+
)
637+
table = _create_table_with_schema(catalog, schema, "2", partition_spec)
638+
639+
with pytest.raises(ValueError, match="Cannot add column with name that conflicts with partition field: event_year"):
640+
table.update_schema().add_column("event_year", StringType()).commit()
641+
642+
with pytest.raises(ValueError, match="Cannot add column with name that conflicts with partition field: first_name"):
643+
table.update_schema().add_column("first_name", StringType()).commit()
644+
645+
table.update_schema().add_column("other_field", StringType()).commit()

tests/integration/test_writes/test_partitioned_writes.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -980,8 +980,16 @@ def test_append_ymd_transform_partitioned(
980980
# Given
981981
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}"
982982
nested_field = TABLE_SCHEMA.find_field(part_col)
983+
984+
if isinstance(transform, YearTransform):
985+
partition_name = f"{part_col}_year"
986+
elif isinstance(transform, MonthTransform):
987+
partition_name = f"{part_col}_month"
988+
elif isinstance(transform, DayTransform):
989+
partition_name = f"{part_col}_day"
990+
983991
partition_spec = PartitionSpec(
984-
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col)
992+
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name)
985993
)
986994

987995
# When
@@ -1037,8 +1045,18 @@ def test_append_transform_partition_verify_partitions_count(
10371045
part_col = "timestamptz"
10381046
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
10391047
nested_field = table_date_timestamps_schema.find_field(part_col)
1048+
1049+
if isinstance(transform, YearTransform):
1050+
partition_name = f"{part_col}_year"
1051+
elif isinstance(transform, MonthTransform):
1052+
partition_name = f"{part_col}_month"
1053+
elif isinstance(transform, DayTransform):
1054+
partition_name = f"{part_col}_day"
1055+
elif isinstance(transform, HourTransform):
1056+
partition_name = f"{part_col}_hour"
1057+
10401058
partition_spec = PartitionSpec(
1041-
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
1059+
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name),
10421060
)
10431061

10441062
# When
@@ -1061,7 +1079,7 @@ def test_append_transform_partition_verify_partitions_count(
10611079

10621080
partitions_table = tbl.inspect.partitions()
10631081
assert partitions_table.num_rows == len(expected_partitions)
1064-
assert {part[part_col] for part in partitions_table["partition"].to_pylist()} == expected_partitions
1082+
assert {part[partition_name] for part in partitions_table["partition"].to_pylist()} == expected_partitions
10651083
files_df = spark.sql(
10661084
f"""
10671085
SELECT *

0 commit comments

Comments
 (0)