Skip to content

Commit 4c1cfdc

Browse files
authored
Cast data to Iceberg Table's pyarrow schema (#523)
* cast to pyarrow schema
* use Schema.as_arrow()
* also for append
* _check_schema_compatible
* comment
* use .as_arrow()
* add test for downcast schema
* cast only when necessary
1 parent 35f6f33 commit 4c1cfdc

File tree

4 files changed

+70
-9
lines changed

4 files changed

+70
-9
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1774,7 +1774,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
17741774

17751775
file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
17761776
schema = table_metadata.schema()
1777-
arrow_file_schema = schema_to_pyarrow(schema)
1777+
arrow_file_schema = schema.as_arrow()
17781778

17791779
fo = io.new_output(file_path)
17801780
row_group_size = PropertyUtil.property_as_int(

pyiceberg/table/__init__.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,15 @@
145145
_JAVA_LONG_MAX = 9223372036854775807
146146

147147

148-
def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
148+
def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
149+
"""
150+
Check if the `table_schema` is compatible with `other_schema`.
151+
152+
Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
153+
154+
Raises:
155+
ValueError: If the schemas are not compatible.
156+
"""
149157
from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
150158

151159
name_mapping = table_schema.name_mapping
@@ -1118,7 +1126,10 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
11181126
if len(self.spec().fields) > 0:
11191127
raise ValueError("Cannot write to partitioned tables")
11201128

1121-
_check_schema(self.schema(), other_schema=df.schema)
1129+
_check_schema_compatible(self.schema(), other_schema=df.schema)
1130+
# cast if the two schemas are compatible but not equal
1131+
if self.schema().as_arrow() != df.schema:
1132+
df = df.cast(self.schema().as_arrow())
11221133

11231134
with self.transaction() as txn:
11241135
with txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
@@ -1156,7 +1167,10 @@ def overwrite(
11561167
if len(self.spec().fields) > 0:
11571168
raise ValueError("Cannot write to partitioned tables")
11581169

1159-
_check_schema(self.schema(), other_schema=df.schema)
1170+
_check_schema_compatible(self.schema(), other_schema=df.schema)
1171+
# cast if the two schemas are compatible but not equal
1172+
if self.schema().as_arrow() != df.schema:
1173+
df = df.cast(self.schema().as_arrow())
11601174

11611175
with self.transaction() as txn:
11621176
with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:

tests/catalog/test_sql.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,39 @@ def test_create_table_with_pyarrow_schema(
193193
catalog.drop_table(random_identifier)
194194

195195

196+
@pytest.mark.parametrize(
197+
'catalog',
198+
[
199+
lazy_fixture('catalog_memory'),
200+
# lazy_fixture('catalog_sqlite'),
201+
],
202+
)
203+
def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None:
204+
import pyarrow as pa
205+
206+
pyarrow_table = pa.Table.from_arrays(
207+
[
208+
pa.array([None, "A", "B", "C"]), # 'foo' column
209+
pa.array([1, 2, 3, 4]), # 'bar' column
210+
pa.array([True, None, False, True]), # 'baz' column
211+
pa.array([None, "A", "B", "C"]), # 'large' column
212+
],
213+
schema=pa.schema([
214+
pa.field('foo', pa.string(), nullable=True),
215+
pa.field('bar', pa.int32(), nullable=False),
216+
pa.field('baz', pa.bool_(), nullable=True),
217+
pa.field('large', pa.large_string(), nullable=True),
218+
]),
219+
)
220+
database_name, _table_name = random_identifier
221+
catalog.create_namespace(database_name)
222+
table = catalog.create_table(random_identifier, pyarrow_table.schema)
223+
print(pyarrow_table.schema)
224+
print(table.schema().as_struct())
225+
print()
226+
table.overwrite(pyarrow_table)
227+
228+
196229
@pytest.mark.parametrize(
197230
'catalog',
198231
[

tests/table/test_init.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
TableIdentifier,
6464
UpdateSchema,
6565
_apply_table_update,
66-
_check_schema,
66+
_check_schema_compatible,
6767
_match_deletes_to_data_file,
6868
_TableMetadataUpdateContext,
6969
update_table_metadata,
@@ -1033,7 +1033,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
10331033
"""
10341034

10351035
with pytest.raises(ValueError, match=expected):
1036-
_check_schema(table_schema_simple, other_schema)
1036+
_check_schema_compatible(table_schema_simple, other_schema)
10371037

10381038

10391039
def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1054,7 +1054,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
10541054
"""
10551055

10561056
with pytest.raises(ValueError, match=expected):
1057-
_check_schema(table_schema_simple, other_schema)
1057+
_check_schema_compatible(table_schema_simple, other_schema)
10581058

10591059

10601060
def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1074,7 +1074,7 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
10741074
"""
10751075

10761076
with pytest.raises(ValueError, match=expected):
1077-
_check_schema(table_schema_simple, other_schema)
1077+
_check_schema_compatible(table_schema_simple, other_schema)
10781078

10791079

10801080
def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
@@ -1088,7 +1088,21 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
10881088
expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
10891089

10901090
with pytest.raises(ValueError, match=expected):
1091-
_check_schema(table_schema_simple, other_schema)
1091+
_check_schema_compatible(table_schema_simple, other_schema)
1092+
1093+
1094+
def test_schema_downcast(table_schema_simple: Schema) -> None:
1095+
# large_string type is compatible with string type
1096+
other_schema = pa.schema((
1097+
pa.field("foo", pa.large_string(), nullable=True),
1098+
pa.field("bar", pa.int32(), nullable=False),
1099+
pa.field("baz", pa.bool_(), nullable=True),
1100+
))
1101+
1102+
try:
1103+
_check_schema_compatible(table_schema_simple, other_schema)
1104+
except Exception:
1105+
pytest.fail("Unexpected Exception raised when calling `_check_schema`")
10921106

10931107

10941108
def test_table_properties(example_table_metadata_v2: Dict[str, Any]) -> None:

0 commit comments

Comments (0)