diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index dd707cea14..2070789a96 100644
--- a/pyiceberg/partitioning.py
+++ b/pyiceberg/partitioning.py
@@ -258,6 +258,21 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
         fresh_field = fresh_schema.find_field(original_column_name)
         if fresh_field is None:
             raise ValueError(f"Could not find field in fresh schema: {original_column_name}")
+
+        try:
+            schema_field = fresh_schema.find_field(field.name)
+        except ValueError:
+            schema_field = None
+
+        if schema_field is not None:
+            if isinstance(field.transform, (IdentityTransform, VoidTransform)):
+                # For identity transforms, allow only if sourced from the same field
+                if schema_field.field_id != fresh_field.field_id:
+                    raise ValueError(f"Cannot create identity partition sourced from different field in schema: {field.name}")
+            else:
+                # For non-identity transforms, never allow conflicts
+                raise ValueError(f"Cannot create partition from name that exists in schema: {field.name}")
+
         partition_fields.append(
             PartitionField(
                 name=field.name,
diff --git a/pyiceberg/table/update/schema.py b/pyiceberg/table/update/schema.py
index 6ad01e97f2..9fb87820d4 100644
--- a/pyiceberg/table/update/schema.py
+++ b/pyiceberg/table/update/schema.py
@@ -192,6 +192,22 @@ def add_column(
         parent_full_path = ".".join(parent)
         parent_id: int = TABLE_ROOT_ID
 
+        # Check for conflicts with partition field names
+        if self._transaction is not None:
+            for spec in self._transaction.table_metadata.partition_specs:
+                for field in spec.fields:
+                    if field.name == name:
+                        from pyiceberg.transforms import IdentityTransform, VoidTransform
+
+                        if isinstance(field.transform, (IdentityTransform, VoidTransform)):
+                            # For identity transforms, allow conflict only if partition field sources from a field with same name
+                            source_field = self._schema.find_field(field.source_id)
+                            if source_field is None or source_field.name != name:
+                                raise ValueError(f"Cannot add column with name that conflicts with partition field: {name}")
+                        else:
+                            # For non-identity transforms, never allow conflicts
+                            raise ValueError(f"Cannot add column with name that conflicts with partition field: {name}")
+
         if len(parent) > 0:
             parent_field = self._schema.find_field(parent_full_path, self._case_sensitive)
             parent_type = parent_field.field_type
@@ -304,6 +320,20 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -
         if field_from.field_id in self._deletes:
             raise ValueError(f"Cannot rename a column that will be deleted: {path_from}")
 
+        if self._transaction is not None:
+            for spec in self._transaction.table_metadata.partition_specs:
+                for field in spec.fields:
+                    if field.name == new_name:
+                        from pyiceberg.transforms import IdentityTransform, VoidTransform
+
+                        if isinstance(field.transform, (IdentityTransform, VoidTransform)):
+                            # For identity transforms, allow conflict only if partition field sources from the renamed field
+                            if field.source_id != field_from.field_id:
+                                raise ValueError(f"Cannot rename column to name that conflicts with partition field: {new_name}")
+                        else:
+                            # For non-identity transforms, never allow conflicts
+                            raise ValueError(f"Cannot rename column to name that conflicts with partition field: {new_name}")
+
         if updated := self._updates.get(field_from.field_id):
             self._updates[field_from.field_id] = NestedField(
                 field_id=updated.field_id,
diff --git a/pyiceberg/table/update/spec.py b/pyiceberg/table/update/spec.py
index 1f91aa5d17..23651d5eb5 100644
--- a/pyiceberg/table/update/spec.py
+++ b/pyiceberg/table/update/spec.py
@@ -174,16 +174,21 @@ def _commit(self) -> UpdatesAndRequirements:
         return updates, requirements
 
     def _apply(self) -> PartitionSpec:
-        def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None:
+        def _check_and_add_partition_name(
+            schema: Schema, name: str, source_id: int, transform: Transform[Any, Any], partition_names: Set[str]
+        ) -> None:
             try:
                 field = schema.find_field(name)
             except ValueError:
                 field = None
 
-            if source_id is not None and field is not None and field.field_id != source_id:
-                raise ValueError(f"Cannot create identity partition from a different field in the schema {name}")
-            elif field is not None and source_id != field.field_id:
-                raise ValueError(f"Cannot create partition from name that exists in schema {name}")
+            if field is not None:
+                if isinstance(transform, (IdentityTransform, VoidTransform)):
+                    # For identity transforms allow name conflict only if sourced from the same schema field
+                    if field.field_id != source_id:
+                        raise ValueError(f"Cannot create identity partition from a different field in the schema: {name}")
+                else:
+                    raise ValueError(f"Cannot create partition from name that exists in schema: {name}")
             if not name:
                 raise ValueError("Undefined name")
             if name in partition_names:
@@ -193,7 +198,7 @@ def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, par
         def _add_new_field(
             schema: Schema, source_id: int, field_id: int, name: str, transform: Transform[Any, Any], partition_names: Set[str]
         ) -> PartitionField:
-            _check_and_add_partition_name(schema, name, source_id, partition_names)
+            _check_and_add_partition_name(schema, name, source_id, transform, partition_names)
             return PartitionField(source_id, field_id, transform, name)
 
         partition_fields = []
@@ -244,6 +249,13 @@ def _add_new_field(
                 partition_fields.append(new_field)
 
         for added_field in self._adds:
+            _check_and_add_partition_name(
+                self._transaction.table_metadata.schema(),
+                added_field.name,
+                added_field.source_id,
+                added_field.transform,
+                partition_names,
+            )
             new_field = PartitionField(
                 source_id=added_field.source_id,
                 field_id=added_field.field_id,
diff --git a/tests/integration/test_partition_evolution.py b/tests/integration/test_partition_evolution.py
index d489d6a5d0..186f0fbcef 100644
--- a/tests/integration/test_partition_evolution.py
+++ b/tests/integration/test_partition_evolution.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint:disable=redefined-outer-name
+from typing import Optional
 
 import pytest
 
@@ -63,12 +64,19 @@ def _table_v2(catalog: Catalog) -> Table:
     return _create_table_with_schema(catalog, schema_with_timestamp, "2")
 
 
-def _create_table_with_schema(catalog: Catalog, schema: Schema, format_version: str) -> Table:
+def _create_table_with_schema(
+    catalog: Catalog, schema: Schema, format_version: str, partition_spec: Optional[PartitionSpec] = None
+) -> Table:
     tbl_name = "default.test_schema_evolution"
     try:
         catalog.drop_table(tbl_name)
     except NoSuchTableError:
         pass
+
+    if partition_spec:
+        return catalog.create_table(
+            identifier=tbl_name, schema=schema, partition_spec=partition_spec, properties={"format-version": format_version}
+        )
     return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version})
 
 
@@ -564,3 +572,80 @@ def _validate_new_partition_fields(
     assert len(spec.fields) == len(expected_partition_fields)
     for i in range(len(spec.fields)):
         assert spec.fields[i] == expected_partition_fields[i]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_partition_schema_field_name_conflict(catalog: Catalog) -> None:
+    schema = Schema(
+        NestedField(1, "id", LongType(), required=False),
+        NestedField(2, "event_ts", TimestampType(), required=False),
+        NestedField(3, "another_ts", TimestampType(), required=False),
+        NestedField(4, "str", StringType(), required=False),
+    )
+    table = _create_table_with_schema(catalog, schema, "2")
+
+    with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: another_ts"):
+        table.update_spec().add_field("event_ts", YearTransform(), "another_ts").commit()
+    with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: id"):
+        table.update_spec().add_field("event_ts", DayTransform(), "id").commit()
+
+    with pytest.raises(ValueError, match="Cannot create identity partition from a different field in the schema: another_ts"):
+        table.update_spec().add_field("event_ts", IdentityTransform(), "another_ts").commit()
+    with pytest.raises(ValueError, match="Cannot create identity partition from a different field in the schema: str"):
+        table.update_spec().add_field("id", IdentityTransform(), "str").commit()
+
+    table.update_spec().add_field("id", IdentityTransform(), "id").commit()
+    table.update_spec().add_field("event_ts", YearTransform(), "event_year").commit()
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_partition_validation_during_table_creation(catalog: Catalog) -> None:
+    schema = Schema(
+        NestedField(1, "id", LongType(), required=False),
+        NestedField(2, "event_ts", TimestampType(), required=False),
+        NestedField(3, "another_ts", TimestampType(), required=False),
+        NestedField(4, "str", StringType(), required=False),
+    )
+
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="another_ts"), spec_id=1
+    )
+    with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: another_ts"):
+        _create_table_with_schema(catalog, schema, "2", partition_spec)
+
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=1
+    )
+    _create_table_with_schema(catalog, schema, "2", partition_spec)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_schema_evolution_partition_conflict(catalog: Catalog) -> None:
+    schema = Schema(
+        NestedField(1, "id", LongType(), required=False),
+        NestedField(2, "event_ts", TimestampType(), required=False),
+    )
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="event_year"),
+        PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="first_name"),
+        PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="id"),
+        spec_id=1,
+    )
+    table = _create_table_with_schema(catalog, schema, "2", partition_spec)
+
+    with pytest.raises(ValueError, match="Cannot add column with name that conflicts with partition field: event_year"):
+        table.update_schema().add_column("event_year", StringType()).commit()
+    with pytest.raises(ValueError, match="Cannot add column with name that conflicts with partition field: first_name"):
+        table.update_schema().add_column("first_name", StringType()).commit()
+
+    table.update_schema().add_column("other_field", StringType()).commit()
+
+    with pytest.raises(ValueError, match="Cannot rename column to name that conflicts with partition field: event_year"):
+        table.update_schema().rename_column("other_field", "event_year").commit()
+    with pytest.raises(ValueError, match="Cannot rename column to name that conflicts with partition field: first_name"):
+        table.update_schema().rename_column("other_field", "first_name").commit()
+
+    table.update_schema().rename_column("other_field", "valid_name").commit()
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index e9698067c1..4b6c6a4d7b 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -980,8 +980,16 @@ def test_append_ymd_transform_partitioned(
     # Given
     identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}"
     nested_field = TABLE_SCHEMA.find_field(part_col)
+
+    if isinstance(transform, YearTransform):
+        partition_name = f"{part_col}_year"
+    elif isinstance(transform, MonthTransform):
+        partition_name = f"{part_col}_month"
+    elif isinstance(transform, DayTransform):
+        partition_name = f"{part_col}_day"
+
     partition_spec = PartitionSpec(
-        PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col)
+        PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name)
     )
 
     # When
@@ -1037,8 +1045,18 @@ def test_append_transform_partition_verify_partitions_count(
     part_col = "timestamptz"
     identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
     nested_field = table_date_timestamps_schema.find_field(part_col)
+
+    if isinstance(transform, YearTransform):
+        partition_name = f"{part_col}_year"
+    elif isinstance(transform, MonthTransform):
+        partition_name = f"{part_col}_month"
+    elif isinstance(transform, DayTransform):
+        partition_name = f"{part_col}_day"
+    elif isinstance(transform, HourTransform):
+        partition_name = f"{part_col}_hour"
+
     partition_spec = PartitionSpec(
-        PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
+        PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name),
     )
 
     # When
@@ -1061,7 +1079,7 @@ def test_append_transform_partition_verify_partitions_count(
 
     partitions_table = tbl.inspect.partitions()
     assert partitions_table.num_rows == len(expected_partitions)
-    assert {part[part_col] for part in partitions_table["partition"].to_pylist()} == expected_partitions
+    assert {part[partition_name] for part in partitions_table["partition"].to_pylist()} == expected_partitions
     files_df = spark.sql(
         f"""
            SELECT *