Skip to content

Commit 361a407

Browse files
authored
Test transform function consistency for all transforms (apache#1573)
I like this test from apache#1562; let's expand it to include all transforms.
1 parent 5e4815a commit 361a407

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

tests/table/test_partitioning.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,15 @@
2323

2424
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
2525
from pyiceberg.schema import Schema
26-
from pyiceberg.transforms import BucketTransform, IdentityTransform, TruncateTransform
26+
from pyiceberg.transforms import (
27+
BucketTransform,
28+
DayTransform,
29+
HourTransform,
30+
IdentityTransform,
31+
MonthTransform,
32+
TruncateTransform,
33+
YearTransform,
34+
)
2735
from pyiceberg.typedef import Record
2836
from pyiceberg.types import (
2937
BinaryType,
@@ -186,11 +194,27 @@ def test_partition_type(table_schema_simple: Schema) -> None:
186194
(BinaryType(), b"\x8e\xd1\x87\x01"),
187195
],
188196
)
189-
def test_bucketing_function(source_type: PrimitiveType, value: Any) -> None:
190-
bucket = BucketTransform(2) # type: ignore
197+
def test_transform_consistency_with_pyarrow_transform(source_type: PrimitiveType, value: Any) -> None:
    """Check that each applicable transform agrees with its PyArrow counterpart.

    For every transform whose ``can_transform(source_type)`` is true, the result
    of the callable returned by ``transform(source_type)`` applied to ``value``
    must equal the first element produced by the callable returned by
    ``pyarrow_transform(source_type)`` on a one-element array holding ``value``.

    A ``ValueError`` whose message contains the "unsupported data type for
    truncate transform" marker is tolerated and the transform is skipped; any
    other ``ValueError`` is re-raised.
    """
    import pyarrow as pa

    unsupported_marker = "FeatureUnsupported => Unsupported data type for truncate transform"
    candidates = [  # type: ignore
        IdentityTransform(),
        BucketTransform(10),
        TruncateTransform(10),
        YearTransform(),
        MonthTransform(),
        DayTransform(),
        HourTransform(),
    ]

    for candidate in candidates:
        # Only compare transforms that declare support for this source type.
        if not candidate.can_transform(source_type):
            continue
        try:
            row_result = candidate.transform(source_type)(value)
            arrow_result = candidate.pyarrow_transform(source_type)(pa.array([value])).to_pylist()[0]
            assert row_result == arrow_result
        except ValueError as err:
            # Skipping unsupported feature; anything else is a genuine failure.
            if unsupported_marker not in str(err):
                raise
194218

195219

196220
def test_deserialize_partition_field_v2() -> None:

0 commit comments

Comments
 (0)