Skip to content

Commit 5902602

Browse files
Fokkokevinjqliu
andauthored
Enable add tests migrated Hive tables (#2295)
<!-- Thanks for opening a pull request! --> <!-- In the case this PR will resolve an issue, please replace ${GITHUB_ISSUE_ID} below with the actual Github issue id. --> <!-- Closes #${GITHUB_ISSUE_ID} --> # Rationale for this change # Are these changes tested? # Are there any user-facing changes? <!-- In the case of user-facing changes, please add the changelog label. --> --------- Co-authored-by: Kevin Liu <[email protected]>
1 parent b67ef2e commit 5902602

File tree

3 files changed

+18
-67
lines changed

3 files changed

+18
-67
lines changed

pyiceberg/expressions/visitors.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -915,17 +915,13 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
915915

916916
# In the order described by the "Column Projection" section of the Iceberg spec:
917917
# https://iceberg.apache.org/spec/#column-projection
918-
# Evaluate column projection first if it exists
919-
if field_id in self.projected_field_values:
920-
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(
921-
Record(self.projected_field_values[field_id])
922-
):
923-
return AlwaysTrue()
924-
925-
# Evaluate initial_default value
918+
# Evaluate column projection first if it exists, otherwise default to the initial-default-value
919+
field_value = (
920+
self.projected_field_values[field_id] if field.field_id in self.projected_field_values else field.initial_default
921+
)
926922
return (
927923
AlwaysTrue()
928-
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field.initial_default))
924+
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field_value))
929925
else AlwaysFalse()
930926
)
931927

@@ -940,7 +936,7 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
940936

941937

942938
def translate_column_names(
943-
expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[int, Any] = EMPTY_DICT
939+
expr: BooleanExpression, file_schema: Schema, case_sensitive: bool = True, projected_field_values: Dict[int, Any] = EMPTY_DICT
944940
) -> BooleanExpression:
945941
return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_field_values))
946942

tests/expressions/test_visitors.py

Lines changed: 8 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,7 +1750,7 @@ def test_translate_column_names_missing_column_match_explicit_null() -> None:
17501750
)
17511751

17521752
# Translate column names
1753-
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True, projected_field_values={2: None})
1753+
translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: None})
17541754

17551755
# Should evaluate to AlwaysTrue because the missing column is treated as null
17561756
# missing_col's default initial_default (None) satisfies the IsNull predicate
@@ -1828,12 +1828,7 @@ def test_translate_column_names_missing_column_with_projected_field_matches() ->
18281828
)
18291829

18301830
# Projected column that is missing in the file schema
1831-
projected_field_values = {2: 42}
1832-
1833-
# Translate column names
1834-
translated_expr = translate_column_names(
1835-
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1836-
)
1831+
translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: 42})
18371832

18381833
# Should evaluate to AlwaysTrue since projected field value matches the expression literal
18391834
# even though the field is missing in the file schema
@@ -1860,18 +1855,13 @@ def test_translate_column_names_missing_column_with_projected_field_mismatch() -
18601855
)
18611856

18621857
# Projected column that is missing in the file schema
1863-
projected_field_values = {2: 1}
1864-
1865-
# Translate column names
1866-
translated_expr = translate_column_names(
1867-
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1868-
)
1858+
translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: 1})
18691859

18701860
# Should evaluate to AlwaysFalse since projected field value does not match the expression literal
18711861
assert translated_expr == AlwaysFalse()
18721862

18731863

1874-
def test_translate_column_names_missing_column_projected_field_fallbacks_to_initial_default() -> None:
1864+
def test_translate_column_names_missing_column_projected_field_ignores_initial_default() -> None:
18751865
"""Test translate_column_names when projected field value doesn't match but initial_default does."""
18761866
# Original schema with a field that has an initial_default
18771867
original_schema = Schema(
@@ -1891,43 +1881,11 @@ def test_translate_column_names_missing_column_projected_field_fallbacks_to_init
18911881
)
18921882

18931883
# Projected field value that differs from both the expression literal and initial_default
1894-
projected_field_values = {2: 10} # This doesn't match expression literal (42)
1895-
1896-
# Translate column names
1897-
translated_expr = translate_column_names(
1898-
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1899-
)
1900-
1901-
# Should evaluate to AlwaysTrue since projected field value doesn't match but initial_default does
1902-
assert translated_expr == AlwaysTrue()
1903-
1904-
1905-
def test_translate_column_names_missing_column_projected_field_matches_initial_default_mismatch() -> None:
1906-
"""Test translate_column_names when both projected field value and initial_default doesn't match."""
1907-
# Original schema with a field that has an initial_default that doesn't match the expression
1908-
original_schema = Schema(
1909-
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1910-
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=10),
1911-
schema_id=1,
1912-
)
1913-
1914-
# Create bound expression for the missing column
1915-
unbound_expr = EqualTo("missing_col", 42)
1916-
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1917-
1918-
# File schema only has the existing column (field_id=1), missing field_id=2
1919-
file_schema = Schema(
1920-
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1921-
schema_id=1,
1922-
)
1923-
1924-
# Projected field value that matches the expression literal
1925-
projected_field_values = {2: 10} # This doesn't match expression literal (42)
1926-
1927-
# Translate column names
19281884
translated_expr = translate_column_names(
1929-
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1885+
bound_expr,
1886+
file_schema,
1887+
projected_field_values={2: 10}, # This doesn't match expression literal (42)
19301888
)
19311889

1932-
# Should evaluate to AlwaysFalse since both projected field value and initial_default does not match
1890+
# Should evaluate to AlwaysFalse since projected field value doesn't match the expression literal
19331891
assert translated_expr == AlwaysFalse()

tests/integration/test_hive_migration.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import time
18+
from datetime import date
1819

1920
import pytest
2021
from pyspark.sql import SparkSession
@@ -75,12 +76,8 @@ def test_migrate_table(
7576
tbl = session_catalog_hive.load_table(dst_table_identifier)
7677
assert tbl.schema().column_names == ["number", "dt"]
7778

78-
# TODO: Returns the primitive type (int), rather than the logical type
79-
# assert set(tbl.scan().to_arrow().column(1).combine_chunks().tolist()) == {'2022-01-01', '2023-01-01'}
80-
79+
assert set(tbl.scan().to_arrow().column(1).combine_chunks().tolist()) == {date(2023, 1, 1), date(2022, 1, 1)}
8180
assert tbl.scan(row_filter="number > 3").to_arrow().column(0).combine_chunks().tolist() == [4, 5, 6]
82-
8381
assert tbl.scan(row_filter="dt == '2023-01-01'").to_arrow().column(0).combine_chunks().tolist() == [4, 5, 6]
84-
85-
# TODO: Issue with filtering the projected column
86-
# assert tbl.scan(row_filter="dt == '2022-01-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3]
82+
assert tbl.scan(row_filter="dt == '2022-01-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3]
83+
assert tbl.scan(row_filter="dt < '2022-02-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3]

0 commit comments

Comments
 (0)