Skip to content

Commit 5a781df

Browse files
rutb327kevinjqliuFokko
authored
Validation partition against schema (#2305)
Closes [#2272](#2272) Collaborator: @geruh # Rationale for this change Implements the validation logic described in [#2272](#2272) to match Java and Rust behavior for partition field name conflicts with schema fields. This mirrors the Java method `checkAndAddPartitionName()`: https://github.com/apache/iceberg/blob/4dbc7f578eee7ceb9def35ebfa1a4cc236fb598f/api/src/main/java/org/apache/iceberg/PartitionSpec.java#L392-L416 **Identity transforms** (`sourceColumnID != null`) - Allow schema field name conflicts only when sourced from the same field. **Non-identity** (`sourceColumnID == null`) - Disallow any schema field name conflicts. In this PR `isinstance(transform, (IdentityTransform, VoidTransform))` is used to achieve the same logic as Java’s `sourceColumnID` check. # Are these changes tested? Yes, all existing tests pass, and tests covering the validation scenarios were added. # Are there any user-facing changes? Yes. Non-identity transforms can no longer use schema field names as partition field names. --------- Co-authored-by: Kevin Liu <[email protected]> Co-authored-by: Kevin Liu <[email protected]> Co-authored-by: Fokko Driesprong <[email protected]>
1 parent 183333d commit 5a781df

File tree

6 files changed

+205
-70
lines changed

6 files changed

+205
-70
lines changed

pyiceberg/partitioning.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from dataclasses import dataclass
2222
from datetime import date, datetime, time
2323
from functools import cached_property, singledispatch
24-
from typing import Annotated, Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
24+
from typing import Annotated, Any, Dict, Generic, List, Optional, Set, Tuple, TypeVar, Union
2525
from urllib.parse import quote_plus
2626

2727
from pydantic import (
@@ -249,6 +249,31 @@ def partition_to_path(self, data: Record, schema: Schema) -> str:
249249
UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)
250250

251251

252+
def validate_partition_name(
    field_name: str,
    partition_transform: Transform[Any, Any],
    source_id: int,
    schema: Schema,
    partition_names: Set[str],
) -> None:
    """Validate a partition field name against the schema and the partition names already in use.

    Args:
        field_name: Proposed name of the partition field.
        partition_transform: Transform applied by the partition field.
        source_id: Field ID that is compared with the ID of the schema field carrying the same name.
        schema: Schema whose field names are checked for conflicts.
        partition_names: Partition field names already taken. Only read here, never modified.

    Raises:
        ValueError: If the name matches a schema field and the transform is neither identity nor void,
            if the name matches a schema field whose ID differs from `source_id`,
            if the name is empty, or if the name is already present in `partition_names`.
    """
    try:
        schema_field = schema.find_field(field_name)
    except ValueError:
        # The name is not a schema field, so it cannot shadow one. Do NOT return here:
        # the emptiness and uniqueness checks below apply to every partition name.
        schema_field = None

    if schema_field is not None:
        if not isinstance(partition_transform, (IdentityTransform, VoidTransform)):
            raise ValueError(f"Cannot create partition with a name that exists in schema: {field_name}")
        # For identity and void transforms, allow the conflict only if sourced from the same schema field
        if schema_field.field_id != source_id:
            raise ValueError(f"Cannot create identity partition sourced from different field in schema: {field_name}")

    if not field_name:
        raise ValueError("Undefined name")
    if field_name in partition_names:
        raise ValueError(f"Partition name has to be unique: {field_name}")
275+
276+
252277
def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fresh_schema: Schema) -> PartitionSpec:
253278
partition_fields = []
254279
for pos, field in enumerate(spec.fields):
@@ -258,6 +283,9 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
258283
fresh_field = fresh_schema.find_field(original_column_name)
259284
if fresh_field is None:
260285
raise ValueError(f"Could not find field in fresh schema: {original_column_name}")
286+
287+
validate_partition_name(field.name, field.transform, fresh_field.field_id, fresh_schema, set())
288+
261289
partition_fields.append(
262290
PartitionField(
263291
name=field.name,

pyiceberg/table/update/schema.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,13 @@ def _apply(self) -> Schema:
662662

663663
# Check the field-ids
664664
new_schema = Schema(*struct.fields)
665+
from pyiceberg.partitioning import validate_partition_name
666+
667+
for spec in self._transaction.table_metadata.partition_specs:
668+
for partition_field in spec.fields:
669+
validate_partition_name(
670+
partition_field.name, partition_field.transform, partition_field.source_id, new_schema, set()
671+
)
665672
field_ids = set()
666673
for name in self._identifier_field_names:
667674
try:

pyiceberg/table/update/spec.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -174,26 +174,18 @@ def _commit(self) -> UpdatesAndRequirements:
174174
return updates, requirements
175175

176176
def _apply(self) -> PartitionSpec:
177-
def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None:
178-
try:
179-
field = schema.find_field(name)
180-
except ValueError:
181-
field = None
182-
183-
if source_id is not None and field is not None and field.field_id != source_id:
184-
raise ValueError(f"Cannot create identity partition from a different field in the schema {name}")
185-
elif field is not None and source_id != field.field_id:
186-
raise ValueError(f"Cannot create partition from name that exists in schema {name}")
187-
if not name:
188-
raise ValueError("Undefined name")
189-
if name in partition_names:
190-
raise ValueError(f"Partition name has to be unique: {name}")
177+
def _check_and_add_partition_name(
178+
schema: Schema, name: str, source_id: int, transform: Transform[Any, Any], partition_names: Set[str]
179+
) -> None:
180+
from pyiceberg.partitioning import validate_partition_name
181+
182+
validate_partition_name(name, transform, source_id, schema, partition_names)
191183
partition_names.add(name)
192184

193185
def _add_new_field(
194186
schema: Schema, source_id: int, field_id: int, name: str, transform: Transform[Any, Any], partition_names: Set[str]
195187
) -> PartitionField:
196-
_check_and_add_partition_name(schema, name, source_id, partition_names)
188+
_check_and_add_partition_name(schema, name, source_id, transform, partition_names)
197189
return PartitionField(source_id, field_id, transform, name)
198190

199191
partition_fields = []
@@ -244,6 +236,13 @@ def _add_new_field(
244236
partition_fields.append(new_field)
245237

246238
for added_field in self._adds:
239+
_check_and_add_partition_name(
240+
self._transaction.table_metadata.schema(),
241+
added_field.name,
242+
added_field.source_id,
243+
added_field.transform,
244+
partition_names,
245+
)
247246
new_field = PartitionField(
248247
source_id=added_field.source_id,
249248
field_id=added_field.field_id,

tests/integration/test_partition_evolution.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pyiceberg.catalog import Catalog
2222
from pyiceberg.exceptions import NoSuchTableError
23-
from pyiceberg.partitioning import PartitionField, PartitionSpec
23+
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
2424
from pyiceberg.schema import Schema
2525
from pyiceberg.table import Table
2626
from pyiceberg.transforms import (
@@ -63,13 +63,18 @@ def _table_v2(catalog: Catalog) -> Table:
6363
return _create_table_with_schema(catalog, schema_with_timestamp, "2")
6464

6565

66-
def _create_table_with_schema(catalog: Catalog, schema: Schema, format_version: str) -> Table:
66+
def _create_table_with_schema(
67+
catalog: Catalog, schema: Schema, format_version: str, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC
68+
) -> Table:
6769
tbl_name = "default.test_schema_evolution"
6870
try:
6971
catalog.drop_table(tbl_name)
7072
except NoSuchTableError:
7173
pass
72-
return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version})
74+
75+
return catalog.create_table(
76+
identifier=tbl_name, schema=schema, partition_spec=partition_spec, properties={"format-version": format_version}
77+
)
7378

7479

7580
@pytest.mark.integration
@@ -564,3 +569,80 @@ def _validate_new_partition_fields(
564569
assert len(spec.fields) == len(expected_partition_fields)
565570
for i in range(len(spec.fields)):
566571
assert spec.fields[i] == expected_partition_fields[i]
572+
573+
574+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_partition_schema_field_name_conflict(catalog: Catalog) -> None:
    """Adding a partition field named after a schema column is rejected unless it is an identity on that column."""
    columns = [
        NestedField(1, "id", LongType(), required=False),
        NestedField(2, "event_ts", TimestampType(), required=False),
        NestedField(3, "another_ts", TimestampType(), required=False),
        NestedField(4, "str", StringType(), required=False),
    ]
    table = _create_table_with_schema(catalog, Schema(*columns), "2")

    # (source column, transform, partition name, expected error), checked in this order
    rejected_cases = [
        (
            "event_ts",
            YearTransform(),
            "another_ts",
            "Cannot create partition with a name that exists in schema: another_ts",
        ),
        (
            "event_ts",
            DayTransform(),
            "id",
            "Cannot create partition with a name that exists in schema: id",
        ),
        (
            "event_ts",
            IdentityTransform(),
            "another_ts",
            "Cannot create identity partition sourced from different field in schema: another_ts",
        ),
        (
            "id",
            IdentityTransform(),
            "str",
            "Cannot create identity partition sourced from different field in schema: str",
        ),
    ]
    for source_column, transform, partition_name, expected_message in rejected_cases:
        with pytest.raises(ValueError, match=expected_message):
            table.update_spec().add_field(source_column, transform, partition_name).commit()

    # These two are expected to commit without raising
    accepted_cases = [
        ("id", IdentityTransform(), "id"),
        ("event_ts", YearTransform(), "event_year"),
    ]
    for source_column, transform, partition_name in accepted_cases:
        table.update_spec().add_field(source_column, transform, partition_name).commit()
597+
598+
599+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_partition_validation_during_table_creation(catalog: Catalog) -> None:
    """Partition name validation is also applied to the spec passed when a table is created."""
    schema = Schema(
        NestedField(1, "id", LongType(), required=False),
        NestedField(2, "event_ts", TimestampType(), required=False),
        NestedField(3, "another_ts", TimestampType(), required=False),
        NestedField(4, "str", StringType(), required=False),
    )

    # Year transform on field 2 named after a different schema column: expected to be rejected
    conflicting_field = PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="another_ts")
    conflicting_spec = PartitionSpec(conflicting_field, spec_id=1)
    with pytest.raises(ValueError, match="Cannot create partition with a name that exists in schema: another_ts"):
        _create_table_with_schema(catalog, schema, "2", conflicting_spec)

    # Identity transform on field 1 reusing that column's own name: expected to succeed
    identity_field = PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id")
    identity_spec = PartitionSpec(identity_field, spec_id=1)
    _create_table_with_schema(catalog, schema, "2", identity_spec)
619+
620+
621+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_schema_evolution_partition_conflict(catalog: Catalog) -> None:
    """Adding or renaming a column to the name of an existing partition field is rejected."""
    schema = Schema(
        NestedField(1, "id", LongType(), required=False),
        NestedField(2, "event_ts", TimestampType(), required=False),
    )
    existing_partitions = [
        PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="event_year"),
        PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="first_name"),
        PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="id"),
    ]
    table = _create_table_with_schema(catalog, schema, "2", PartitionSpec(*existing_partitions, spec_id=1))

    # (column name clashing with a partition field, expected error), used for both add and rename
    conflicts = (
        ("event_year", "Cannot create partition with a name that exists in schema: event_year"),
        ("first_name", "Cannot create identity partition sourced from different field in schema: first_name"),
    )

    for column_name, expected_message in conflicts:
        with pytest.raises(ValueError, match=expected_message):
            table.update_schema().add_column(column_name, StringType()).commit()

    # A column name not used by any partition field is accepted
    table.update_schema().add_column("other_field", StringType()).commit()

    for column_name, expected_message in conflicts:
        with pytest.raises(ValueError, match=expected_message):
            table.update_schema().rename_column("other_field", column_name).commit()

    table.update_schema().rename_column("other_field", "valid_name").commit()

tests/integration/test_writes/test_partitioned_writes.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -980,8 +980,16 @@ def test_append_ymd_transform_partitioned(
980980
# Given
981981
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}"
982982
nested_field = TABLE_SCHEMA.find_field(part_col)
983+
984+
if isinstance(transform, YearTransform):
985+
partition_name = f"{part_col}_year"
986+
elif isinstance(transform, MonthTransform):
987+
partition_name = f"{part_col}_month"
988+
elif isinstance(transform, DayTransform):
989+
partition_name = f"{part_col}_day"
990+
983991
partition_spec = PartitionSpec(
984-
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col)
992+
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name)
985993
)
986994

987995
# When
@@ -1037,8 +1045,18 @@ def test_append_transform_partition_verify_partitions_count(
10371045
part_col = "timestamptz"
10381046
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
10391047
nested_field = table_date_timestamps_schema.find_field(part_col)
1048+
1049+
if isinstance(transform, YearTransform):
1050+
partition_name = f"{part_col}_year"
1051+
elif isinstance(transform, MonthTransform):
1052+
partition_name = f"{part_col}_month"
1053+
elif isinstance(transform, DayTransform):
1054+
partition_name = f"{part_col}_day"
1055+
elif isinstance(transform, HourTransform):
1056+
partition_name = f"{part_col}_hour"
1057+
10401058
partition_spec = PartitionSpec(
1041-
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
1059+
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name),
10421060
)
10431061

10441062
# When
@@ -1061,7 +1079,7 @@ def test_append_transform_partition_verify_partitions_count(
10611079

10621080
partitions_table = tbl.inspect.partitions()
10631081
assert partitions_table.num_rows == len(expected_partitions)
1064-
assert {part[part_col] for part in partitions_table["partition"].to_pylist()} == expected_partitions
1082+
assert {part[partition_name] for part in partitions_table["partition"].to_pylist()} == expected_partitions
10651083
files_df = spark.sql(
10661084
f"""
10671085
SELECT *

0 commit comments

Comments
 (0)