Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions pyiceberg/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,26 @@ def partition_to_path(self, data: Record, schema: Schema) -> str:
UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)


def validate_partition_name(
field_name: str,
partition_transform: Transform[Any, Any],
source_id: int,
schema: Schema,
) -> None:
"""Validate that a partition field name doesn't conflict with schema field names."""
try:
schema_field = schema.find_field(field_name)
except ValueError:
return # No conflict if field doesn't exist in schema

if isinstance(partition_transform, (IdentityTransform, VoidTransform)):
# For identity transforms, allow conflict only if sourced from the same schema field
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# For identity transforms, allow conflict only if sourced from the same schema field
# For identity and void transforms, allow conflict only if sourced from the same schema field

if schema_field.field_id != source_id:
raise ValueError(f"Cannot create identity partition from a different source field in the schema: {field_name}")
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Match the Java error message.

Suggested change
raise ValueError(f"Cannot create identity partition from a different source field in the schema: {field_name}")
else:
raise ValueError(f"Cannot create identity partition sourced from different field in schema: {field_name}")
else:

raise ValueError(f"Cannot create partition from name that exists in schema: {field_name}")


def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fresh_schema: Schema) -> PartitionSpec:
partition_fields = []
for pos, field in enumerate(spec.fields):
Expand All @@ -258,6 +278,9 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
fresh_field = fresh_schema.find_field(original_column_name)
if fresh_field is None:
raise ValueError(f"Could not find field in fresh schema: {original_column_name}")

validate_partition_name(field.name, field.transform, fresh_field.field_id, fresh_schema)

partition_fields.append(
PartitionField(
name=field.name,
Expand Down
8 changes: 8 additions & 0 deletions pyiceberg/table/update/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,14 @@ def _apply(self) -> Schema:

# Check the field-ids
new_schema = Schema(*struct.fields)
if self._transaction is not None:
from pyiceberg.partitioning import validate_partition_name

for spec in self._transaction.table_metadata.partition_specs:
for partition_field in spec.fields:
validate_partition_name(
partition_field.name, partition_field.transform, partition_field.source_id, new_schema
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there should always be a self._transaction.

Suggested change
if self._transaction is not None:
from pyiceberg.partitioning import validate_partition_name
for spec in self._transaction.table_metadata.partition_specs:
for partition_field in spec.fields:
validate_partition_name(
partition_field.name, partition_field.transform, partition_field.source_id, new_schema
)
from pyiceberg.partitioning import validate_partition_name
for spec in self._transaction.table_metadata.partition_specs:
for partition_field in spec.fields:
validate_partition_name(
partition_field.name, partition_field.transform, partition_field.source_id, new_schema
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, I'll do the suggested changes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some tests show that the transaction can be None in some cases: after removing the check, tests in test_schema.py fail because they use UpdateSchema(transaction=None, schema=Schema()).
https://github.com/rutb327/iceberg-python/blob/24b12ddd8fdab4a62650786a2c3cdd56a53f8719/tests/test_schema.py#L933

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like everywhere else in the codebase we pass a transaction to UpdateSchema.

Maybe we can update the tests like this

def test_add_top_level_primitives(primitive_fields: List[NestedField], table_v2: Table) -> None:
    """Union each primitive field into an empty schema and expect a schema holding only that field.

    The update is built with a real ``Transaction`` on ``table_v2`` rather than ``transaction=None``.
    """
    for field in primitive_fields:
        expected = Schema(field)
        actual = UpdateSchema(transaction=Transaction(table_v2), schema=Schema()).union_by_name(expected)._apply()  # type: ignore
        assert actual == expected

field_ids = set()
for name in self._identifier_field_names:
try:
Expand Down
25 changes: 14 additions & 11 deletions pyiceberg/table/update/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,12 @@ def _commit(self) -> UpdatesAndRequirements:
return updates, requirements

def _apply(self) -> PartitionSpec:
def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None:
try:
field = schema.find_field(name)
except ValueError:
field = None

if source_id is not None and field is not None and field.field_id != source_id:
raise ValueError(f"Cannot create identity partition from a different field in the schema {name}")
elif field is not None and source_id != field.field_id:
raise ValueError(f"Cannot create partition from name that exists in schema {name}")
def _check_and_add_partition_name(
schema: Schema, name: str, source_id: int, transform: Transform[Any, Any], partition_names: Set[str]
) -> None:
from pyiceberg.partitioning import validate_partition_name

validate_partition_name(name, transform, source_id, schema)
if not name:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do that

raise ValueError("Undefined name")
if name in partition_names:
Expand All @@ -193,7 +189,7 @@ def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, par
def _add_new_field(
schema: Schema, source_id: int, field_id: int, name: str, transform: Transform[Any, Any], partition_names: Set[str]
) -> PartitionField:
_check_and_add_partition_name(schema, name, source_id, partition_names)
_check_and_add_partition_name(schema, name, source_id, transform, partition_names)
return PartitionField(source_id, field_id, transform, name)

partition_fields = []
Expand Down Expand Up @@ -244,6 +240,13 @@ def _add_new_field(
partition_fields.append(new_field)

for added_field in self._adds:
_check_and_add_partition_name(
self._transaction.table_metadata.schema(),
added_field.name,
added_field.source_id,
added_field.transform,
partition_names,
)
Comment on lines +239 to +245
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Just to confirm: this covers the newly added partition fields?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that's correct

new_field = PartitionField(
source_id=added_field.source_id,
field_id=added_field.field_id,
Expand Down
93 changes: 92 additions & 1 deletion tests/integration/test_partition_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
from typing import Optional

import pytest

Expand Down Expand Up @@ -63,12 +64,19 @@ def _table_v2(catalog: Catalog) -> Table:
return _create_table_with_schema(catalog, schema_with_timestamp, "2")


def _create_table_with_schema(catalog: Catalog, schema: Schema, format_version: str) -> Table:
def _create_table_with_schema(
catalog: Catalog, schema: Schema, format_version: str, partition_spec: Optional[PartitionSpec] = None
) -> Table:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This follows the other create-table helpers in the tests, for example:

def _create_table(
    session_catalog: Catalog,
    identifier: str,
    format_version: int,
    location: str,
    partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
    schema: Schema = TABLE_SCHEMA,
) -> Table:
    """Drop any table already registered under ``identifier`` and create it again.

    Args:
        session_catalog: Catalog used for both the drop and the create call.
        identifier: Name of the table to (re)create.
        format_version: Written as a string into the ``format-version`` table property.
        location: Passed through to ``create_table``.
        partition_spec: Partition spec for the new table; defaults to the unpartitioned spec.
        schema: Schema for the new table.

    Returns:
        The table returned by ``session_catalog.create_table``.
    """
    try:
        session_catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        # Nothing to clean up: the table did not exist yet.
        pass

    table_properties = {"format-version": str(format_version)}
    return session_catalog.create_table(
        identifier=identifier,
        schema=schema,
        location=location,
        partition_spec=partition_spec,
        properties=table_properties,
    )

Suggested change
def _create_table_with_schema(
catalog: Catalog, schema: Schema, format_version: str, partition_spec: Optional[PartitionSpec] = None
) -> Table:
def _create_table_with_schema(
catalog: Catalog, schema: Schema, format_version: str, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC
) -> Table:

tbl_name = "default.test_schema_evolution"
try:
catalog.drop_table(tbl_name)
except NoSuchTableError:
pass

if partition_spec:
return catalog.create_table(
identifier=tbl_name, schema=schema, partition_spec=partition_spec, properties={"format-version": format_version}
)
return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and then we can just do this

Suggested change
if partition_spec:
return catalog.create_table(
identifier=tbl_name, schema=schema, partition_spec=partition_spec, properties={"format-version": format_version}
)
return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version})
return catalog.create_table(
identifier=tbl_name, schema=schema, partition_spec=partition_spec, properties={"format-version": format_version}
)



Expand Down Expand Up @@ -564,3 +572,86 @@ def _validate_new_partition_fields(
assert len(spec.fields) == len(expected_partition_fields)
for i in range(len(spec.fields)):
assert spec.fields[i] == expected_partition_fields[i]


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_partition_schema_field_name_conflict(catalog: Catalog) -> None:
    """Check that ``update_spec().add_field`` rejects partition names that clash with schema columns.

    A partition name equal to a schema column name is expected to raise ``ValueError``,
    except for an identity transform whose source is that very column.
    """
    columns = (
        NestedField(1, "id", LongType(), required=False),
        NestedField(2, "event_ts", TimestampType(), required=False),
        NestedField(3, "another_ts", TimestampType(), required=False),
        NestedField(4, "str", StringType(), required=False),
    )
    table = _create_table_with_schema(catalog, Schema(*columns), "2")

    # Each row: (source column, transform class, partition name, expected error message).
    rejected = (
        ("event_ts", YearTransform, "another_ts", "Cannot create partition from name that exists in schema: another_ts"),
        ("event_ts", DayTransform, "id", "Cannot create partition from name that exists in schema: id"),
        (
            "event_ts",
            IdentityTransform,
            "another_ts",
            "Cannot create identity partition from a different source field in the schema: another_ts",
        ),
        (
            "id",
            IdentityTransform,
            "str",
            "Cannot create identity partition from a different source field in the schema: str",
        ),
    )
    for source_column, transform_cls, partition_name, message in rejected:
        with pytest.raises(ValueError, match=message):
            table.update_spec().add_field(source_column, transform_cls(), partition_name).commit()

    # These must commit cleanly: identity on the same column, and a name not present in the schema.
    accepted = (
        ("id", IdentityTransform, "id"),
        ("event_ts", YearTransform, "event_year"),
    )
    for source_column, transform_cls, partition_name in accepted:
        table.update_spec().add_field(source_column, transform_cls(), partition_name).commit()


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_partition_validation_during_table_creation(catalog: Catalog) -> None:
    """Check that partition-name validation also runs when a table is created with a spec.

    A year partition named after a different schema column must fail with ``ValueError``;
    an identity partition named after its own source column must be accepted.
    """
    columns = (
        NestedField(1, "id", LongType(), required=False),
        NestedField(2, "event_ts", TimestampType(), required=False),
        NestedField(3, "another_ts", TimestampType(), required=False),
        NestedField(4, "str", StringType(), required=False),
    )
    schema = Schema(*columns)

    # Year transform on "event_ts" (id 2) but named like the existing column "another_ts".
    conflicting_field = PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="another_ts")
    conflicting_spec = PartitionSpec(conflicting_field, spec_id=1)
    with pytest.raises(ValueError, match="Cannot create partition from name that exists in schema: another_ts"):
        _create_table_with_schema(catalog, schema, "2", conflicting_spec)

    # Identity transform on "id" (id 1) reusing the column's own name.
    identity_field = PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id")
    identity_spec = PartitionSpec(identity_field, spec_id=1)
    _create_table_with_schema(catalog, schema, "2", identity_spec)


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_schema_evolution_partition_conflict(catalog: Catalog) -> None:
    """Check that schema updates cannot introduce a column named like an existing partition field.

    Both adding a column and renaming a column to a conflicting partition name are expected
    to raise ``ValueError``; non-conflicting names must commit.
    """
    schema = Schema(
        NestedField(1, "id", LongType(), required=False),
        NestedField(2, "event_ts", TimestampType(), required=False),
    )
    partition_fields = (
        PartitionField(source_id=2, field_id=1000, transform=YearTransform(), name="event_year"),
        PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="first_name"),
        PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="id"),
    )
    table = _create_table_with_schema(catalog, schema, "2", PartitionSpec(*partition_fields, spec_id=1))

    # Each row: (conflicting column name, expected error message).
    clashes = (
        ("event_year", "Cannot create partition from name that exists in schema: event_year"),
        ("first_name", "Cannot create identity partition from a different source field in the schema: first_name"),
    )

    # Adding a column under a conflicting name is rejected.
    for column_name, message in clashes:
        with pytest.raises(ValueError, match=message):
            table.update_schema().add_column(column_name, StringType()).commit()

    table.update_schema().add_column("other_field", StringType()).commit()

    # Renaming the new column to a conflicting name is rejected as well.
    for column_name, message in clashes:
        with pytest.raises(ValueError, match=message):
            table.update_schema().rename_column("other_field", column_name).commit()

    table.update_schema().rename_column("other_field", "valid_name").commit()
24 changes: 21 additions & 3 deletions tests/integration/test_writes/test_partitioned_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,16 @@ def test_append_ymd_transform_partitioned(
# Given
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}"
nested_field = TABLE_SCHEMA.find_field(part_col)

if isinstance(transform, YearTransform):
partition_name = f"{part_col}_year"
elif isinstance(transform, MonthTransform):
partition_name = f"{part_col}_month"
elif isinstance(transform, DayTransform):
partition_name = f"{part_col}_day"

partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col)
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name)
)

# When
Expand Down Expand Up @@ -1037,8 +1045,18 @@ def test_append_transform_partition_verify_partitions_count(
part_col = "timestamptz"
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
nested_field = table_date_timestamps_schema.find_field(part_col)

if isinstance(transform, YearTransform):
partition_name = f"{part_col}_year"
elif isinstance(transform, MonthTransform):
partition_name = f"{part_col}_month"
elif isinstance(transform, DayTransform):
partition_name = f"{part_col}_day"
elif isinstance(transform, HourTransform):
partition_name = f"{part_col}_hour"

partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=partition_name),
)

# When
Expand All @@ -1061,7 +1079,7 @@ def test_append_transform_partition_verify_partitions_count(

partitions_table = tbl.inspect.partitions()
assert partitions_table.num_rows == len(expected_partitions)
assert {part[part_col] for part in partitions_table["partition"].to_pylist()} == expected_partitions
assert {part[partition_name] for part in partitions_table["partition"].to_pylist()} == expected_partitions
files_df = spark.sql(
f"""
SELECT *
Expand Down