Skip to content

Commit ca3a529

Browse files
fix: Unsigned-type-related bugs (#1095)
## Which issue does this PR close? Closes #1067. ## Rationale for this change Bug fix: a few expressions were failing some unsigned-type-related tests. ## What changes are included in this PR? - For `u8`/`u16`, switched to `generate_cast_to_signed!` so the full i16/i32 width is copied (sign-extended) instead of zero-padding the higher bits. - `u64` is read as `Decimal(20, 0)`, but there was a bug in `round()` (`>` vs `>=`). ## How are these changes tested? Restored the tests for unsigned types.
1 parent 59da6ce commit ca3a529

File tree

6 files changed

+17
-21
lines changed

6 files changed

+17
-21
lines changed

native/core/src/parquet/read/values.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ make_int_variant_impl!(Int32ToDoubleType, copy_i32_to_f64, 8);
476476
make_int_variant_impl!(FloatToDoubleType, copy_f32_to_f64, 8);
477477

478478
// unsigned type require double the width and zeroes are written for the second half
479-
// perhaps because they are implemented as the next size up signed type?
479+
// because they are implemented as the next size up signed type
480480
make_int_variant_impl!(UInt8Type, copy_i32_to_u8, 2);
481481
make_int_variant_impl!(UInt16Type, copy_i32_to_u16, 4);
482482
make_int_variant_impl!(UInt32Type, copy_i32_to_u32, 8);
@@ -586,8 +586,6 @@ macro_rules! generate_cast_to_unsigned {
586586
};
587587
}
588588

589-
generate_cast_to_unsigned!(copy_i32_to_u8, i32, u8, 0_u8);
590-
generate_cast_to_unsigned!(copy_i32_to_u16, i32, u16, 0_u16);
591589
generate_cast_to_unsigned!(copy_i32_to_u32, i32, u32, 0_u32);
592590

593591
macro_rules! generate_cast_to_signed {
@@ -624,6 +622,9 @@ generate_cast_to_signed!(copy_i64_to_i64, i64, i64);
624622
generate_cast_to_signed!(copy_i64_to_i128, i64, i128);
625623
generate_cast_to_signed!(copy_u64_to_u128, u64, u128);
626624
generate_cast_to_signed!(copy_f32_to_f64, f32, f64);
625+
// even for u8/u16, need to copy full i16/i32 width for Spark compatibility
626+
generate_cast_to_signed!(copy_i32_to_u8, i32, i16);
627+
generate_cast_to_signed!(copy_i32_to_u16, i32, i32);
627628

628629
// Shared implementation for variants of Binary type
629630
macro_rules! make_plain_binary_impl {
@@ -1096,7 +1097,7 @@ mod test {
10961097
let source =
10971098
hex::decode("8a000000dbffffff1800000034ffffff300000001d000000abffffff37fffffff1000000")
10981099
.unwrap();
1099-
let expected = hex::decode("8a00db001800340030001d00ab003700f100").unwrap();
1100+
let expected = hex::decode("8a00dbff180034ff30001d00abff37fff100").unwrap();
11001101
let num = source.len() / 4;
11011102
let mut dest: Vec<u8> = vec![b' '; num * 2];
11021103
copy_i32_to_u8(source.as_bytes(), dest.as_mut_slice(), num);

native/spark-expr/src/scalar_funcs.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ pub fn spark_round(
354354
DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32),
355355
DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16),
356356
DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8),
357-
DataType::Decimal128(_, scale) if *scale > 0 => {
357+
DataType::Decimal128(_, scale) if *scale >= 0 => {
358358
let f = decimal_round_f(scale, point);
359359
let (precision, scale) = get_precision_scale(data_type);
360360
make_decimal_array(array, precision, scale, &f)

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -861,10 +861,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
861861
// primitives
862862
checkSparkAnswerAndOperator(
863863
"SELECT CAST(struct(_1, _2, _3, _4, _5, _6, _7, _8) as string) FROM tbl")
864-
// TODO: enable tests for unsigned ints (_9, _10, _11, _12) once
865-
// https://github.com/apache/datafusion-comet/issues/1067 is resolved
866-
// checkSparkAnswerAndOperator(
867-
// "SELECT CAST(struct(_9, _10, _11, _12) as string) FROM tbl")
864+
checkSparkAnswerAndOperator("SELECT CAST(struct(_9, _10, _11, _12) as string) FROM tbl")
868865
// decimals
869866
// TODO add _16 when https://github.com/apache/datafusion-comet/issues/1068 is resolved
870867
checkSparkAnswerAndOperator("SELECT CAST(struct(_15, _17) as string) FROM tbl")

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
119119
val path = new Path(dir.toURI.toString, "test.parquet")
120120
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
121121
withParquetTable(path.toString, "tbl") {
122-
// TODO: enable test for unsigned ints
123-
checkSparkAnswerAndOperator(
124-
"select _1, _2, _3, _4, _5, _6, _7, _8, _13, _14, _15, _16, _17, " +
125-
"_18, _19, _20 FROM tbl WHERE _2 > 100")
122+
checkSparkAnswerAndOperator("select * FROM tbl WHERE _2 > 100")
126123
}
127124
}
128125
}
@@ -1115,7 +1112,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11151112
val path = new Path(dir.toURI.toString, "test.parquet")
11161113
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 100)
11171114
withParquetTable(path.toString, "tbl") {
1118-
Seq(2, 3, 4, 5, 6, 7, 15, 16, 17).foreach { col =>
1115+
Seq(2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 17).foreach { col =>
11191116
checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl")
11201117
}
11211118
}
@@ -1239,9 +1236,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12391236
withParquetTable(path.toString, "tbl") {
12401237
for (s <- Seq(-5, -1, 0, 1, 5, -1000, 1000, -323, -308, 308, -15, 15, -16, 16, null)) {
12411238
// array tests
1242-
// TODO: enable test for unsigned ints (_9, _10, _11, _12)
12431239
// TODO: enable test for floats (_6, _7, _8, _13)
1244-
for (c <- Seq(2, 3, 4, 5, 15, 16, 17)) {
1240+
for (c <- Seq(2, 3, 4, 5, 9, 10, 11, 12, 15, 16, 17)) {
12451241
checkSparkAnswerAndOperator(s"select _${c}, round(_${c}, ${s}) FROM tbl")
12461242
}
12471243
// scalar tests
@@ -1452,9 +1448,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14521448
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
14531449

14541450
withParquetTable(path.toString, "tbl") {
1455-
// _9 and _10 (uint8 and uint16) not supported
14561451
checkSparkAnswerAndOperator(
1457-
"SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_11), hex(_12), hex(_13), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl")
1452+
"SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_9), hex(_10), hex(_11), hex(_12), hex(_13), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl")
14581453
}
14591454
}
14601455
}
@@ -2334,7 +2329,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
23342329
checkSparkAnswerAndOperator(
23352330
spark.sql("SELECT array_append((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
23362331
}
2337-
23382332
}
23392333
}
23402334
}

spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
750750
$"_6",
751751
$"_7",
752752
$"_8",
753+
$"_9",
754+
$"_10",
755+
$"_11",
756+
$"_12",
753757
$"_13",
754758
$"_14",
755759
$"_15",

spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,8 @@ abstract class ParquetReadSuite extends CometTestBase {
434434
i.toFloat,
435435
i.toDouble,
436436
i.toString * 48,
437-
java.lang.Byte.toUnsignedInt((-i).toByte),
438-
java.lang.Short.toUnsignedInt((-i).toShort),
437+
(-i).toByte,
438+
(-i).toShort,
439439
java.lang.Integer.toUnsignedLong(-i),
440440
new BigDecimal(UnsignedLong.fromLongBits((-i).toLong).bigIntegerValue()),
441441
i.toString,

0 commit comments

Comments
 (0)