
Commit bbb1c25

Fokko and dingo4dev authored
Fix UUID support (#2007)
# Rationale for this change

The UUID support is a gift that keeps on giving. PyIceberg's current support is incomplete and problematic, mostly because:

- It is an extension type in Arrow, which means it is not fully supported: apache/arrow#46469, apache/arrow#46468
- It has no native support in Spark, where it is converted into a string. This limits the current tests, which are mostly Spark-based.

I think we have to wait for some upstream Arrow fixes before we can fully support this. In PyIceberg, we're converting the `fixed[16]` to a `UUID`, but Spark errors because the logical type annotation in Parquet is missing:

```
E   py4j.protocol.Py4JJavaError: An error occurred while calling o72.collectToPython.
E   : org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 1 times, most recent failure: Lost task 0.0 in stage 1.0 (TID 1) (localhost executor driver): java.lang.UnsupportedOperationException: Unsupported type: UTF8String
E       at org.apache.iceberg.arrow.vectorized.ArrowVectorAccessor.getUTF8String(ArrowVectorAccessor.java:81)
E       at org.apache.iceberg.spark.data.vectorized.IcebergArrowColumnVector.getUTF8String(IcebergArrowColumnVector.java:143)
E       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
E       at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
E       at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
E       at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388)
E       at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:893)
E       at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:893)
E       at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
E       at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
E       at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
E       at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
E       at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
E       at org.apache.spark.scheduler.Task.run(Task.scala:141)
E       at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
E       at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
E       at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
E       at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
E       at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
E       at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
E       at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
E       at java.base/java.lang.Thread.run(Thread.java:829)
E
E   Driver stacktrace:
E       at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
E       at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
E       at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
E       at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
E       at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
E       at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
E       at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
E       at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
E       at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
E       at scala.Option.foreach(Option.scala:407)
E       at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
E       at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
E       at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
E       at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
E       at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
E       at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
E       at org.apache.spark.SparkContext.runJob(SparkContext.scala:2393)
E       at org.apache.spark.SparkContext.runJob(SparkContext.scala:2414)
E       at org.apache.spark.SparkContext.runJob(SparkContext.scala:2433)
E       at org.apache.spark.SparkContext.runJob(SparkContext.scala:2458)
E       at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1049)
E       at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
E       at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
E       at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
E       at org.apache.spark.rdd.RDD.collect(RDD.scala:1048)
E       at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:448)
E       at org.apache.spark.sql.Dataset.$anonfun$collectToPython$1(Dataset.scala:4149)
E       at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4323)
E       at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
E       at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4321)
E       at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
E       at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
E       at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
E       at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
E       at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
E       at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4321)
E       at org.apache.spark.sql.Dataset.collectToPython(Dataset.scala:4146)
E       at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
E       at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
E       at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
E       at java.base/java.lang.reflect.Method.invoke(Method.java:566)
E       at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
E       at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
E       at py4j.Gateway.invoke(Gateway.java:282)
E       at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
E       at py4j.commands.CallCommand.execute(CallCommand.java:79)
E       at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
E       at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
E       at java.base/java.lang.Thread.run(Thread.java:829)
E   Caused by: java.lang.UnsupportedOperationException: Unsupported type: UTF8String
E       at org.apache.iceberg.arrow.vectorized.ArrowVectorAccessor.getUTF8String(ArrowVectorAccessor.java:81)
E       at org.apache.iceberg.spark.data.vectorized.IcebergArrowColumnVector.getUTF8String(IcebergArrowColumnVector.java:143)
E       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
E       at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
E       at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
E       at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388)
E       at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:893)
E       at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:893)
E       at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
E       at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
E       at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
E       at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
E       at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
E       at org.apache.spark.scheduler.Task.run(Task.scala:141)
E       at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
E       at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
E       at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
E       at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
E       at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
E       at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
E       at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
E   ... 1 more
```

# Are these changes tested?

# Are there any user-facing changes?

Closes #1986
Closes #2002

---------

Co-authored-by: DinGo4DEV <[email protected]>
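For context on the "extension type" point above, a minimal sketch (assumes PyArrow >= 18, where `pa.uuid()` is available); this code is not part of the change itself:

```python
import pyarrow as pa

# pa.uuid() is a canonical *extension* type layered on top of fixed_size_binary(16).
# Readers that understand the extension can surface uuid.UUID values; readers that
# don't just see 16-byte blobs, which is why support upstream is still uneven
# (see the linked apache/arrow issues).
uuid_type = pa.uuid()
print(uuid_type.storage_type)              # fixed_size_binary[16]
print(isinstance(uuid_type, pa.UuidType))  # True
```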
1 parent 2b9f9e2 commit bbb1c25

File tree

8 files changed: +88 −35 lines changed

pyiceberg/avro/writer.py

Lines changed: 6 additions & 2 deletions

@@ -32,6 +32,7 @@
     List,
     Optional,
     Tuple,
+    Union,
 )
 from uuid import UUID

@@ -121,8 +122,11 @@ def write(self, encoder: BinaryEncoder, val: Any) -> None:

 @dataclass(frozen=True)
 class UUIDWriter(Writer):
-    def write(self, encoder: BinaryEncoder, val: UUID) -> None:
-        encoder.write(val.bytes)
+    def write(self, encoder: BinaryEncoder, val: Union[UUID, bytes]) -> None:
+        if isinstance(val, UUID):
+            encoder.write(val.bytes)
+        else:
+            encoder.write(val)


 @dataclass(frozen=True)

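A hypothetical helper, not part of the change, that mirrors the branch added to `UUIDWriter` above: values may now arrive either as `uuid.UUID` objects or already as their 16-byte serialized form, and both must end up as the same raw bytes on the Avro encoder.

```python
from typing import Union
from uuid import UUID


def to_uuid_bytes(val: Union[UUID, bytes]) -> bytes:
    """Illustrative only: normalize a UUID or its 16-byte form to raw bytes,
    the same way UUIDWriter.write() now does before calling encoder.write()."""
    return val.bytes if isinstance(val, UUID) else val


u = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
assert to_uuid_bytes(u) == to_uuid_bytes(u.bytes) == u.bytes
```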
pyiceberg/io/pyarrow.py

Lines changed: 3 additions & 1 deletion

@@ -746,7 +746,7 @@ def visit_string(self, _: StringType) -> pa.DataType:
         return pa.large_string()

     def visit_uuid(self, _: UUIDType) -> pa.DataType:
-        return pa.binary(16)
+        return pa.uuid()

     def visit_unknown(self, _: UnknownType) -> pa.DataType:
         return pa.null()
@@ -1307,6 +1307,8 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
             return FixedType(primitive.byte_width)
         elif pa.types.is_null(primitive):
             return UnknownType()
+        elif isinstance(primitive, pa.UuidType):
+            return UUIDType()

         raise TypeError(f"Unsupported type: {primitive}")


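A rough sketch of what the two-way mapping above implies, assuming `pyiceberg.io.pyarrow.schema_to_pyarrow` and a PyArrow version with `pa.uuid()`; illustrative, not taken from the test suite:

```python
import pyarrow as pa

from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, UUIDType

# With this change, the Iceberg UUID type should surface as the Arrow UUID
# extension type instead of plain fixed_size_binary(16), and an Arrow UuidType
# should convert back to UUIDType rather than raising "Unsupported type".
iceberg_schema = Schema(NestedField(field_id=1, name="uuid", field_type=UUIDType(), required=False))
arrow_schema = schema_to_pyarrow(iceberg_schema)
assert isinstance(arrow_schema.field("uuid").type, pa.UuidType)
```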
pyiceberg/partitioning.py

Lines changed: 11 additions & 2 deletions

@@ -467,8 +467,17 @@ def _(type: IcebergType, value: Optional[time]) -> Optional[int]:


 @_to_partition_representation.register(UUIDType)
-def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
-    return str(value) if value is not None else None
+def _(type: IcebergType, value: Optional[Union[uuid.UUID, int, bytes]]) -> Optional[Union[bytes, int]]:
+    if value is None:
+        return None
+    elif isinstance(value, bytes):
+        return value  # IdentityTransform
+    elif isinstance(value, uuid.UUID):
+        return value.bytes  # IdentityTransform
+    elif isinstance(value, int):
+        return value  # BucketTransform
+    else:
+        raise ValueError(f"Type not recognized: {value}")


 @_to_partition_representation.register(PrimitiveType)

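The dispatch above is easiest to read as a table of accepted inputs. A small sketch of the behavior it defines (`_to_partition_representation` is a private helper, so treat this as illustrative only):

```python
import uuid

from pyiceberg.partitioning import _to_partition_representation
from pyiceberg.types import UUIDType

u = uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")

assert _to_partition_representation(UUIDType(), u) == u.bytes        # IdentityTransform, UUID object
assert _to_partition_representation(UUIDType(), u.bytes) == u.bytes  # IdentityTransform, already bytes
assert _to_partition_representation(UUIDType(), 7) == 7              # BucketTransform output (an int)
assert _to_partition_representation(UUIDType(), None) is None
```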
tests/conftest.py

Lines changed: 5 additions & 2 deletions

@@ -2827,7 +2827,7 @@ def pyarrow_schema_with_promoted_types() -> "pa.Schema":
             pa.field("list", pa.list_(pa.int32()), nullable=False),  # can support upcasting integer to long
             pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False),  # can support upcasting integer to long
             pa.field("double", pa.float32(), nullable=True),  # can support upcasting float to double
-            pa.field("uuid", pa.binary(length=16), nullable=True),  # can support upcasting float to double
+            pa.field("uuid", pa.binary(length=16), nullable=True),  # can support upcasting fixed to uuid
         )
     )

@@ -2843,7 +2843,10 @@ def pyarrow_table_with_promoted_types(pyarrow_schema_with_promoted_types: "pa.Sc
             "list": [[1, 1], [2, 2]],
             "map": [{"a": 1}, {"b": 2}],
             "double": [1.1, 9.2],
-            "uuid": [b"qZx\xefNS@\x89\x9b\xf9:\xd0\xee\x9b\xf5E", b"\x97]\x87T^JDJ\x96\x97\xf4v\xe4\x03\x0c\xde"],
+            "uuid": [
+                uuid.UUID("00000000-0000-0000-0000-000000000000").bytes,
+                uuid.UUID("11111111-1111-1111-1111-111111111111").bytes,
+            ],
         },
         schema=pyarrow_schema_with_promoted_types,
     )

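A quick check of why the fixture swaps the opaque byte literals for `uuid.UUID(...).bytes`: the values stay 16-byte fixed-size binaries, they are just readable now. Illustrative only:

```python
import uuid

# .bytes yields the big-endian 16-byte representation of the UUID,
# i.e. exactly the kind of value pa.binary(length=16) expects.
val = uuid.UUID("00000000-0000-0000-0000-000000000000").bytes
assert val == b"\x00" * 16
assert len(uuid.UUID("11111111-1111-1111-1111-111111111111").bytes) == 16
```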
tests/integration/test_add_files.py

Lines changed: 2 additions & 2 deletions

@@ -737,7 +737,7 @@ def test_add_files_with_valid_upcast(
         with pq.ParquetWriter(fos, schema=pyarrow_schema_with_promoted_types) as writer:
             writer.write_table(pyarrow_table_with_promoted_types)

-    tbl.add_files(file_paths=[file_path])
+    tbl.add_files(file_paths=[file_path], check_duplicate_files=False)
     # table's long field should cast to long on read
     written_arrow_table = tbl.scan().to_arrow()
     assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
@@ -747,7 +747,7 @@ def test_add_files_with_valid_upcast(
                 pa.field("list", pa.list_(pa.int64()), nullable=False),
                 pa.field("map", pa.map_(pa.string(), pa.int64()), nullable=False),
                 pa.field("double", pa.float64(), nullable=True),
-                pa.field("uuid", pa.binary(length=16), nullable=True),  # can UUID is read as fixed length binary of length 16
+                pa.field("uuid", pa.uuid(), nullable=True),
            )
        )
    )

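For reference, a hedged usage sketch of the call being changed: `add_files` registers already-written Parquet files with an Iceberg table without rewriting them, and `check_duplicate_files=False` skips PyIceberg's duplicate-file validation for the given paths. The catalog, table, and file path below are made up for illustration:

```python
from pyiceberg.catalog import load_catalog

# Illustrative only: names and the Parquet path are hypothetical.
catalog = load_catalog("default")
tbl = catalog.load_table("default.test_add_files_with_valid_upcast")
tbl.add_files(
    file_paths=["s3://warehouse/db/tbl/data/file-0.parquet"],
    check_duplicate_files=False,  # default is True, which validates the paths before committing
)
```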
tests/integration/test_partitioning_key.py

Lines changed: 0 additions & 20 deletions

@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint:disable=redefined-outer-name
-import uuid
 from datetime import date, datetime, timedelta, timezone
 from decimal import Decimal
 from typing import Any, List
@@ -308,25 +307,6 @@
             (CAST('2023-01-01' AS DATE), 'Associated string value for date 2023-01-01')
             """,
         ),
-        (
-            [PartitionField(source_id=14, field_id=1001, transform=IdentityTransform(), name="uuid_field")],
-            [uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")],
-            Record("f47ac10b-58cc-4372-a567-0e02b2c3d479"),
-            "uuid_field=f47ac10b-58cc-4372-a567-0e02b2c3d479",
-            f"""CREATE TABLE {identifier} (
-                uuid_field string,
-                string_field string
-            )
-            USING iceberg
-            PARTITIONED BY (
-                identity(uuid_field)
-            )
-            """,
-            f"""INSERT INTO {identifier}
-            VALUES
-            ('f47ac10b-58cc-4372-a567-0e02b2c3d479', 'Associated string value for UUID f47ac10b-58cc-4372-a567-0e02b2c3d479')
-            """,
-        ),
         (
             [PartitionField(source_id=11, field_id=1001, transform=IdentityTransform(), name="binary_field")],
             [b"example"],

tests/integration/test_reads.py

Lines changed: 4 additions & 4 deletions

@@ -589,15 +589,15 @@ def test_partitioned_tables(catalog: Catalog) -> None:
 def test_unpartitioned_uuid_table(catalog: Catalog) -> None:
     unpartitioned_uuid = catalog.load_table("default.test_uuid_and_fixed_unpartitioned")
     arrow_table_eq = unpartitioned_uuid.scan(row_filter="uuid_col == '102cb62f-e6f8-4eb0-9973-d9b012ff0967'").to_arrow()
-    assert arrow_table_eq["uuid_col"].to_pylist() == [uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967").bytes]
+    assert arrow_table_eq["uuid_col"].to_pylist() == [uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967")]

     arrow_table_neq = unpartitioned_uuid.scan(
         row_filter="uuid_col != '102cb62f-e6f8-4eb0-9973-d9b012ff0967' and uuid_col != '639cccce-c9d2-494a-a78c-278ab234f024'"
     ).to_arrow()
     assert arrow_table_neq["uuid_col"].to_pylist() == [
-        uuid.UUID("ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226").bytes,
-        uuid.UUID("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b").bytes,
-        uuid.UUID("923dae77-83d6-47cd-b4b0-d383e64ee57e").bytes,
+        uuid.UUID("ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226"),
+        uuid.UUID("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b"),
+        uuid.UUID("923dae77-83d6-47cd-b4b0-d383e64ee57e"),
     ]


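The assertions above change because scanned UUID columns now carry the Arrow UUID extension type, so `to_pylist()` yields `uuid.UUID` objects rather than raw 16-byte values. A standalone illustration, assuming a PyArrow version whose UUID extension scalar converts to `uuid.UUID`:

```python
import uuid

import pyarrow as pa

# Wrap fixed-size binary storage in the UUID extension type and convert to Python.
storage = pa.array([uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967").bytes], type=pa.binary(16))
uuid_arr = pa.ExtensionArray.from_storage(pa.uuid(), storage)
print(uuid_arr.to_pylist())  # [UUID('102cb62f-e6f8-4eb0-9973-d9b012ff0967')] on recent PyArrow
```
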
tests/integration/test_writes/test_writes.py

Lines changed: 57 additions & 2 deletions

@@ -19,6 +19,7 @@
 import os
 import random
 import time
+import uuid
 from datetime import date, datetime, timedelta
 from decimal import Decimal
 from pathlib import Path
@@ -49,7 +50,7 @@
 from pyiceberg.table import TableProperties
 from pyiceberg.table.refs import MAIN_BRANCH
 from pyiceberg.table.sorting import SortDirection, SortField, SortOrder
-from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform
+from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform, Transform
 from pyiceberg.types import (
     DateType,
     DecimalType,
@@ -59,6 +60,7 @@
     LongType,
     NestedField,
     StringType,
+    UUIDType,
 )
 from utils import _create_table

@@ -1286,7 +1288,7 @@ def test_table_write_schema_with_valid_upcast(
                 pa.field("list", pa.list_(pa.int64()), nullable=False),
                 pa.field("map", pa.map_(pa.string(), pa.int64()), nullable=False),
                 pa.field("double", pa.float64(), nullable=True),  # can support upcasting float to double
-                pa.field("uuid", pa.binary(length=16), nullable=True),  # can UUID is read as fixed length binary of length 16
+                pa.field("uuid", pa.uuid(), nullable=True),
            )
        )
    )
@@ -1858,6 +1860,59 @@ def test_read_write_decimals(session_catalog: Catalog) -> None:
     assert tbl.scan().to_arrow() == arrow_table


+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "transform",
+    [
+        IdentityTransform(),
+        # Bucket is disabled because of an issue in Iceberg Java:
+        # https://github.com/apache/iceberg/pull/13324
+        # BucketTransform(32)
+    ],
+)
+def test_uuid_partitioning(session_catalog: Catalog, spark: SparkSession, transform: Transform) -> None:  # type: ignore
+    identifier = f"default.test_uuid_partitioning_{str(transform).replace('[32]', '')}"
+
+    schema = Schema(NestedField(field_id=1, name="uuid", field_type=UUIDType(), required=True))
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    partition_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=transform, name="uuid_identity"))
+
+    import pyarrow as pa
+
+    arr_table = pa.Table.from_pydict(
+        {
+            "uuid": [
+                uuid.UUID("00000000-0000-0000-0000-000000000000").bytes,
+                uuid.UUID("11111111-1111-1111-1111-111111111111").bytes,
+            ],
+        },
+        schema=pa.schema(
+            [
+                # Uuid not yet supported, so we have to stick with `binary(16)`
+                # https://github.com/apache/arrow/issues/46468
+                pa.field("uuid", pa.binary(16), nullable=False),
+            ]
+        ),
+    )
+
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=schema,
+        partition_spec=partition_spec,
+    )
+
+    tbl.append(arr_table)
+
+    lhs = [r[0] for r in spark.table(identifier).collect()]
+    rhs = [str(u.as_py()) for u in tbl.scan().to_arrow()["uuid"].combine_chunks()]
+    assert lhs == rhs
+
+
 @pytest.mark.integration
 def test_avro_compression_codecs(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
     identifier = "default.test_avro_compression_codecs"

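A note on the final comparison in `test_uuid_partitioning`: as the rationale says, Spark has no native UUID type and surfaces the column as strings, so the Arrow-side `uuid.UUID` values are stringified before comparing. A trivial sketch of the equivalence the test relies on:

```python
import uuid

# str() of a uuid.UUID produces the canonical hyphenated form, which is what the
# Spark side of the comparison is expected to return for the same rows.
u = uuid.UUID("11111111-1111-1111-1111-111111111111")
assert str(u) == "11111111-1111-1111-1111-111111111111"
```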