Skip to content

Commit 35d4648

Browse files
Authored (author name not captured in this page extract)
Read: fetch file_schema directly from pyarrow_to_schema (#597)
1 parent: 5039b5d · commit: 35d4648

File tree

5 files changed

+34
-15
lines changed

5 files changed

+34
-15
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@
122122
pre_order_visit,
123123
promote,
124124
prune_columns,
125-
sanitize_column_names,
126125
visit,
127126
visit_with_partner,
128127
)
@@ -966,20 +965,15 @@ def _task_to_table(
966965
with fs.open_input_file(path) as fin:
967966
fragment = arrow_format.make_fragment(fin)
968967
physical_schema = fragment.physical_schema
969-
schema_raw = None
970-
if metadata := physical_schema.metadata:
971-
schema_raw = metadata.get(ICEBERG_SCHEMA)
972-
file_schema = (
973-
Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema, name_mapping)
974-
)
968+
file_schema = pyarrow_to_schema(physical_schema, name_mapping)
975969

976970
pyarrow_filter = None
977971
if bound_row_filter is not AlwaysTrue():
978972
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
979973
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
980974
pyarrow_filter = expression_to_pyarrow(bound_file_filter)
981975

982-
file_project_schema = sanitize_column_names(prune_columns(file_schema, projected_field_ids, select_full_types=False))
976+
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
983977

984978
if file_schema is None:
985979
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")

tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1931,9 +1931,11 @@ def data_file(table_schema_simple: Schema, tmp_path: str) -> str:
19311931
import pyarrow as pa
19321932
from pyarrow import parquet as pq
19331933

1934+
from pyiceberg.io.pyarrow import schema_to_pyarrow
1935+
19341936
table = pa.table(
19351937
{"foo": ["a", "b", "c"], "bar": [1, 2, 3], "baz": [True, False, None]},
1936-
metadata={"iceberg.schema": table_schema_simple.model_dump_json()},
1938+
schema=schema_to_pyarrow(table_schema_simple),
19371939
)
19381940

19391941
file_path = f"{tmp_path}/0000-data.parquet"

tests/integration/test_writes/test_writes.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,28 @@ def get_current_snapshot_id(identifier: str) -> int:
271271
assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore
272272

273273

274+
@pytest.mark.integration
275+
@pytest.mark.parametrize("format_version", [1, 2])
276+
def test_python_writes_special_character_column_with_spark_reads(
277+
spark: SparkSession, session_catalog: Catalog, format_version: int
278+
) -> None:
279+
identifier = "default.python_writes_special_character_column_with_spark_reads"
280+
column_name_with_special_character = "letter/abc"
281+
TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN = {
282+
column_name_with_special_character: ['a', None, 'z'],
283+
}
284+
pa_schema = pa.schema([
285+
(column_name_with_special_character, pa.string()),
286+
])
287+
arrow_table_with_special_character_column = pa.Table.from_pydict(TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN, schema=pa_schema)
288+
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema)
289+
290+
tbl.overwrite(arrow_table_with_special_character_column)
291+
spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas()
292+
pyiceberg_df = tbl.scan().to_pandas()
293+
assert spark_df.equals(pyiceberg_df)
294+
295+
274296
@pytest.mark.integration
275297
def test_write_bin_pack_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
276298
identifier = "default.write_bin_pack_data_files"

tests/integration/test_writes/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint:disable=redefined-outer-name
18-
from typing import List, Optional
18+
from typing import List, Optional, Union
1919

2020
import pyarrow as pa
2121

@@ -65,6 +65,7 @@ def _create_table(
6565
properties: Properties,
6666
data: Optional[List[pa.Table]] = None,
6767
partition_spec: Optional[PartitionSpec] = None,
68+
schema: Union[Schema, "pa.Schema"] = TABLE_SCHEMA,
6869
) -> Table:
6970
try:
7071
session_catalog.drop_table(identifier=identifier)
@@ -73,10 +74,10 @@ def _create_table(
7374

7475
if partition_spec:
7576
tbl = session_catalog.create_table(
76-
identifier=identifier, schema=TABLE_SCHEMA, properties=properties, partition_spec=partition_spec
77+
identifier=identifier, schema=schema, properties=properties, partition_spec=partition_spec
7778
)
7879
else:
79-
tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties=properties)
80+
tbl = session_catalog.create_table(identifier=identifier, schema=schema, properties=properties)
8081

8182
if data:
8283
for d in data:

tests/io/test_pyarrow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,7 +1373,7 @@ def test_delete(deletes_file: str, example_task: FileScanTask, table_schema_simp
13731373
str(with_deletes)
13741374
== """pyarrow.Table
13751375
foo: string
1376-
bar: int64 not null
1376+
bar: int32 not null
13771377
baz: bool
13781378
----
13791379
foo: [["a","c"]]
@@ -1411,7 +1411,7 @@ def test_delete_duplicates(deletes_file: str, example_task: FileScanTask, table_
14111411
str(with_deletes)
14121412
== """pyarrow.Table
14131413
foo: string
1414-
bar: int64 not null
1414+
bar: int32 not null
14151415
baz: bool
14161416
----
14171417
foo: [["a","c"]]
@@ -1442,7 +1442,7 @@ def test_pyarrow_wrap_fsspec(example_task: FileScanTask, table_schema_simple: Sc
14421442
str(projection)
14431443
== """pyarrow.Table
14441444
foo: string
1445-
bar: int64 not null
1445+
bar: int32 not null
14461446
baz: bool
14471447
----
14481448
foo: [["a","b","c"]]

0 commit comments

Comments (0)