
Commit d4a4eed

Cast PyArrow schema to large_* types (#807)
* _pyarrow_with
* fix
* fix test
* adopt review feedback
* revert accidental conf change
* adopt-nit
1 parent 2407a3c commit d4a4eed
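For context (not part of the commit): PyArrow's large_* types hold the same values as their plain counterparts but use 64-bit value offsets instead of 32-bit ones, so a single array can address more than 2 GiB of variable-length data. A minimal sketch of the distinction:

import pyarrow as pa

# Plain types use 32-bit offsets; the large_* variants use 64-bit offsets.
small = pa.schema([pa.field("foo", pa.string()), pa.field("payload", pa.binary())])
large = pa.schema([pa.field("foo", pa.large_string()), pa.field("payload", pa.large_binary())])

print(small.field("foo").type)  # string
print(large.field("foo").type)  # large_string

# The values are identical either way; only the offset width of the buffers differs.
arr = pa.array(["a", None, "z"], type=pa.large_string())
print(arr.type)  # large_string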

7 files changed: +197 -70 lines changed


pyiceberg/io/pyarrow.py

Lines changed: 46 additions & 10 deletions
@@ -504,7 +504,7 @@ def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
 
     def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
         element_field = self.field(list_type.element_field, element_result)
-        return pa.list_(value_type=element_field)
+        return pa.large_list(value_type=element_field)
 
     def map(self, map_type: MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
         key_field = self.field(map_type.key_field, key_result)
@@ -548,7 +548,7 @@ def visit_timestamptz(self, _: TimestamptzType) -> pa.DataType:
         return pa.timestamp(unit="us", tz="UTC")
 
     def visit_string(self, _: StringType) -> pa.DataType:
-        return pa.string()
+        return pa.large_string()
 
     def visit_uuid(self, _: UUIDType) -> pa.DataType:
         return pa.binary(16)
@@ -680,6 +680,10 @@ def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> Schema:
     return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
 
 
+def _pyarrow_schema_ensure_large_types(schema: pa.Schema) -> pa.Schema:
+    return visit_pyarrow(schema, _ConvertToLargeTypes())
+
+
 @singledispatch
 def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T:
     """Apply a pyarrow schema visitor to any point within a schema.
@@ -952,6 +956,30 @@ def after_map_value(self, element: pa.Field) -> None:
         self._field_names.pop()
 
 
+class _ConvertToLargeTypes(PyArrowSchemaVisitor[Union[pa.DataType, pa.Schema]]):
+    def schema(self, schema: pa.Schema, struct_result: pa.StructType) -> pa.Schema:
+        return pa.schema(struct_result)
+
+    def struct(self, struct: pa.StructType, field_results: List[pa.Field]) -> pa.StructType:
+        return pa.struct(field_results)
+
+    def field(self, field: pa.Field, field_result: pa.DataType) -> pa.Field:
+        return field.with_type(field_result)
+
+    def list(self, list_type: pa.ListType, element_result: pa.DataType) -> pa.DataType:
+        return pa.large_list(element_result)
+
+    def map(self, map_type: pa.MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
+        return pa.map_(key_result, value_result)
+
+    def primitive(self, primitive: pa.DataType) -> pa.DataType:
+        if primitive == pa.string():
+            return pa.large_string()
+        elif primitive == pa.binary():
+            return pa.large_binary()
+        return primitive
+
+
 class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
     """
     Converts PyArrowSchema to Iceberg Schema with all -1 ids.
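The new _ConvertToLargeTypes visitor, exposed through _pyarrow_schema_ensure_large_types, rewrites string, binary and list types to their large counterparts and leaves everything else untouched. A standalone sketch of the same idea (not part of the diff; the recursive helper ensure_large below is hypothetical and simplified, e.g. it does not preserve list element field names the way the visitor does):

import pyarrow as pa

def ensure_large(dt: pa.DataType) -> pa.DataType:
    # Recursively map string -> large_string, binary -> large_binary, list -> large_list.
    if pa.types.is_string(dt):
        return pa.large_string()
    if pa.types.is_binary(dt):
        return pa.large_binary()
    if pa.types.is_list(dt):
        return pa.large_list(ensure_large(dt.value_type))
    if pa.types.is_struct(dt):
        return pa.struct([dt.field(i).with_type(ensure_large(dt.field(i).type)) for i in range(dt.num_fields)])
    if pa.types.is_map(dt):
        return pa.map_(ensure_large(dt.key_type), ensure_large(dt.item_type))
    return dt

schema = pa.schema([
    pa.field("name", pa.string()),
    pa.field("tags", pa.list_(pa.string())),
])
print(pa.schema([f.with_type(ensure_large(f.type)) for f in schema]))
# name: large_string
# tags: large_list<item: large_string>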
@@ -998,7 +1026,9 @@ def _task_to_table(
 
     fragment_scanner = ds.Scanner.from_fragment(
         fragment=fragment,
-        schema=physical_schema,
+        # We always use large types in memory as it uses larger offsets
+        # That can chunk more row values into the buffers
+        schema=_pyarrow_schema_ensure_large_types(physical_schema),
         # This will push down the query to Arrow.
         # But in case there are positional deletes, we have to apply them first
         filter=pyarrow_filter if not positional_deletes else None,
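Here the Arrow dataset scanner is handed the large-typed schema up front, so record batches are materialized with 64-bit offset buffers directly rather than being cast after the read. For comparison, a minimal sketch (not part of the commit) of casting an already-materialized table to the same large types:

import pyarrow as pa

small_table = pa.table({"foo": ["a", None, "z"]})  # "foo" is inferred as string
large_schema = pa.schema([pa.field("foo", pa.large_string())])

# Same end result for a table that was already read with 32-bit offset types.
large_table = small_table.cast(large_schema)
print(large_table.schema)  # foo: large_string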
@@ -1167,8 +1197,14 @@ def __init__(self, file_schema: Schema):
 
     def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
         file_field = self.file_schema.find_field(field.field_id)
-        if field.field_type.is_primitive and field.field_type != file_field.field_type:
-            return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
+        if field.field_type.is_primitive:
+            if field.field_type != file_field.field_type:
+                return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
+            elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=False)) != values.type:
+                # if file_field and field_type (e.g. String) are the same
+                # but the pyarrow type of the array is different from the expected type
+                # (e.g. string vs larger_string), we want to cast the array to the larger type
+                return values.cast(target_type)
         return values
 
     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
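_cast_if_needed now also casts when the Iceberg types of the field and the file field match but the Arrow representation of the array differs, e.g. string data read as pa.string() that should be returned as pa.large_string(). A minimal sketch of that array-level cast (illustrative only):

import pyarrow as pa

values = pa.array(["a", None, "z"], type=pa.string())
target_type = pa.large_string()

# Cast only when the physical Arrow type differs from the expected one;
# the logical values stay the same, only the offset width changes.
if values.type != target_type:
    values = values.cast(target_type)
print(values.type)  # large_string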
@@ -1207,13 +1243,13 @@ def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional[pa.Array]) -> Optional[pa.Array]:
         return field_array
 
     def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]:
-        if isinstance(list_array, pa.ListArray) and value_array is not None:
+        if isinstance(list_array, (pa.ListArray, pa.LargeListArray, pa.FixedSizeListArray)) and value_array is not None:
             if isinstance(value_array, pa.StructArray):
                 # This can be removed once this has been fixed:
                 # https://github.com/apache/arrow/issues/38809
-                list_array = pa.ListArray.from_arrays(list_array.offsets, value_array)
+                list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)
 
-            arrow_field = pa.list_(self._construct_field(list_type.element_field, value_array.type))
+            arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
             return list_array.cast(arrow_field)
         else:
             return None
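The list handling now accepts plain, large, and fixed-size list arrays as input and rebuilds struct-valued lists as LargeListArray (the rebuild itself is a workaround for https://github.com/apache/arrow/issues/38809). A small sketch (not part of the diff) of constructing a list array with 64-bit offsets from offsets and values:

import pyarrow as pa

values = pa.array(["a", "b", "c"], type=pa.large_string())
offsets = pa.array([0, 2, 3], type=pa.int64())  # large lists use int64 offsets

list_array = pa.LargeListArray.from_arrays(offsets, values)
print(list_array.type)         # large_list<item: large_string>
print(list_array.to_pylist())  # [['a', 'b'], ['c']]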
@@ -1263,7 +1299,7 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str) -> Optional[pa.Array]:
         return None
 
     def list_element_partner(self, partner_list: Optional[pa.Array]) -> Optional[pa.Array]:
-        return partner_list.values if isinstance(partner_list, pa.ListArray) else None
+        return partner_list.values if isinstance(partner_list, (pa.ListArray, pa.LargeListArray, pa.FixedSizeListArray)) else None
 
     def map_key_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]:
         return partner_map.keys if isinstance(partner_map, pa.MapArray) else None
@@ -1800,10 +1836,10 @@ def write_parquet(task: WriteTask) -> DataFile:
         # otherwise use the original schema
         if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
             file_schema = sanitized_schema
-            arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
         else:
             file_schema = table_schema
 
+        arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
         file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:

tests/catalog/test_sql.py

Lines changed: 5 additions & 5 deletions
@@ -288,7 +288,7 @@ def test_write_pyarrow_schema(catalog: SqlCatalog, table_identifier: Identifier)
             pa.array([None, "A", "B", "C"]),  # 'large' column
         ],
         schema=pa.schema([
-            pa.field("foo", pa.string(), nullable=True),
+            pa.field("foo", pa.large_string(), nullable=True),
             pa.field("bar", pa.int32(), nullable=False),
             pa.field("baz", pa.bool_(), nullable=True),
             pa.field("large", pa.large_string(), nullable=True),
@@ -1325,7 +1325,7 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
         {
             "foo": ["a", None, "z"],
         },
-        schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+        schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
     )
 
     tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)})
@@ -1336,7 +1336,7 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
             "bar": [19, None, 25],
         },
         schema=pa.schema([
-            pa.field("foo", pa.string(), nullable=True),
+            pa.field("foo", pa.large_string(), nullable=True),
             pa.field("bar", pa.int32(), nullable=True),
         ]),
     )
@@ -1375,7 +1375,7 @@ def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> None:
         {
             "foo": ["a", None, "z"],
         },
-        schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+        schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
     )
 
     pa_table_with_column = pa.Table.from_pydict(
pa_table_with_column = pa.Table.from_pydict(
@@ -1384,7 +1384,7 @@ def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> None:
             "bar": [19, None, 25],
         },
         schema=pa.schema([
-            pa.field("foo", pa.string(), nullable=True),
+            pa.field("foo", pa.large_string(), nullable=True),
            pa.field("bar", pa.int32(), nullable=True),
        ]),
    )

tests/conftest.py

Lines changed: 2 additions & 2 deletions
@@ -2116,8 +2116,8 @@ def pa_schema() -> "pa.Schema":
 
    return pa.schema([
        ("bool", pa.bool_()),
-        ("string", pa.string()),
-        ("string_long", pa.string()),
+        ("string", pa.large_string()),
+        ("string_long", pa.large_string()),
        ("int", pa.int32()),
        ("long", pa.int64()),
        ("float", pa.float32()),

tests/integration/test_writes/test_writes.py

Lines changed: 54 additions & 0 deletions
@@ -340,6 +340,60 @@ def test_python_writes_dictionary_encoded_column_with_spark_reads(
     assert spark_df.equals(pyiceberg_df)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_python_writes_with_small_and_large_types_spark_reads(
+    spark: SparkSession, session_catalog: Catalog, format_version: int
+) -> None:
+    identifier = "default.python_writes_with_small_and_large_types_spark_reads"
+    TEST_DATA = {
+        "foo": ["a", None, "z"],
+        "id": [1, 2, 3],
+        "name": ["AB", "CD", "EF"],
+        "address": [
+            {"street": "123", "city": "SFO", "zip": 12345, "bar": "a"},
+            {"street": "456", "city": "SW", "zip": 67890, "bar": "b"},
+            {"street": "789", "city": "Random", "zip": 10112, "bar": "c"},
+        ],
+    }
+    pa_schema = pa.schema([
+        pa.field("foo", pa.large_string()),
+        pa.field("id", pa.int32()),
+        pa.field("name", pa.string()),
+        pa.field(
+            "address",
+            pa.struct([
+                pa.field("street", pa.string()),
+                pa.field("city", pa.string()),
+                pa.field("zip", pa.int32()),
+                pa.field("bar", pa.large_string()),
+            ]),
+        ),
+    ])
+    arrow_table = pa.Table.from_pydict(TEST_DATA, schema=pa_schema)
+    tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema)
+
+    tbl.overwrite(arrow_table)
+    spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas()
+    pyiceberg_df = tbl.scan().to_pandas()
+    assert spark_df.equals(pyiceberg_df)
+    arrow_table_on_read = tbl.scan().to_arrow()
+    assert arrow_table_on_read.schema == pa.schema([
+        pa.field("foo", pa.large_string()),
+        pa.field("id", pa.int32()),
+        pa.field("name", pa.large_string()),
+        pa.field(
+            "address",
+            pa.struct([
+                pa.field("street", pa.large_string()),
+                pa.field("city", pa.large_string()),
+                pa.field("zip", pa.int32()),
+                pa.field("bar", pa.large_string()),
+            ]),
+        ),
+    ])
+
+
 @pytest.mark.integration
 def test_write_bin_pack_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
     identifier = "default.write_bin_pack_data_files"
