@@ -1807,6 +1807,74 @@ TEST(TestCaseWhen, Decimal) {
18071807 }
18081808}
18091809
1810+ TEST (TestCaseWhen, DecimalPromotion) {
1811+ auto check_case_when_decimal_promotion =
1812+ [](std::shared_ptr<Scalar> body_true, std::shared_ptr<Scalar> body_false,
1813+ std::shared_ptr<Scalar> promoted_true, std::shared_ptr<Scalar> promoted_false) {
1814+ auto cond_true = ScalarFromJSON (boolean (), " true" );
1815+ auto cond_false = ScalarFromJSON (boolean (), " false" );
1816+ CheckScalar (" case_when" , {MakeStruct ({cond_true}), body_true, body_false},
1817+ promoted_true);
1818+ CheckScalar (" case_when" , {MakeStruct ({cond_false}), body_true, body_false},
1819+ promoted_false);
1820+ };
1821+
1822+ const std::vector<std::pair<int , int >> precisions = {{10 , 20 }, {15 , 15 }, {20 , 10 }};
1823+ const std::vector<std::pair<int , int >> scales = {{3 , 9 }, {6 , 6 }, {9 , 3 }};
1824+ for (auto p : precisions) {
1825+ for (auto s : scales) {
1826+ auto p1 = p.first ;
1827+ auto s1 = s.first ;
1828+ auto p2 = p.second ;
1829+ auto s2 = s.second ;
1830+
1831+ auto max_scale = std::max ({s1, s2});
1832+ auto scale_up_1 = max_scale - s1;
1833+ auto scale_up_2 = max_scale - s2;
1834+ auto max_precision = std::max ({p1 + scale_up_1, p2 + scale_up_2});
1835+
1836+ // Operand string: 444.777...
1837+ std::string str_d1 =
1838+ R"( ")" + std::string (p1 - s1, ' 4' ) + " ." + std::string (s1, ' 7' ) + R"( ")" ;
1839+ std::string str_d2 =
1840+ R"( ")" + std::string (p2 - s2, ' 4' ) + " ." + std::string (s2, ' 7' ) + R"( ")" ;
1841+
1842+ // Promoted string: 444.777...000
1843+ std::string str_d1_promoted = R"( ")" + std::string (p1 - s1, ' 4' ) + " ." +
1844+ std::string (s1, ' 7' ) +
1845+ std::string (max_scale - s1, ' 0' ) + R"( ")" ;
1846+ std::string str_d2_promoted = R"( ")" + std::string (p2 - s2, ' 4' ) + " ." +
1847+ std::string (s2, ' 7' ) +
1848+ std::string (max_scale - s2, ' 0' ) + R"( ")" ;
1849+
1850+ auto d128_1 = decimal128 (p1, s1);
1851+ auto d128_2 = decimal128 (p2, s2);
1852+ auto d256_1 = decimal256 (p1, s1);
1853+ auto d256_2 = decimal256 (p2, s2);
1854+ auto d128_promoted = decimal128 (max_precision, max_scale);
1855+ auto d256_promoted = decimal256 (max_precision, max_scale);
1856+
1857+ auto scalar128_1 = ScalarFromJSON (d128_1, str_d1);
1858+ auto scalar128_2 = ScalarFromJSON (d128_2, str_d2);
1859+ auto scalar256_1 = ScalarFromJSON (d256_1, str_d1);
1860+ auto scalar256_2 = ScalarFromJSON (d256_2, str_d2);
1861+ auto scalar128_d1_promoted = ScalarFromJSON (d128_promoted, str_d1_promoted);
1862+ auto scalar128_d2_promoted = ScalarFromJSON (d128_promoted, str_d2_promoted);
1863+ auto scalar256_d1_promoted = ScalarFromJSON (d256_promoted, str_d1_promoted);
1864+ auto scalar256_d2_promoted = ScalarFromJSON (d256_promoted, str_d2_promoted);
1865+
1866+ check_case_when_decimal_promotion (scalar128_1, scalar128_2, scalar128_d1_promoted,
1867+ scalar128_d2_promoted);
1868+ check_case_when_decimal_promotion (scalar128_1, scalar256_2, scalar256_d1_promoted,
1869+ scalar256_d2_promoted);
1870+ check_case_when_decimal_promotion (scalar256_1, scalar128_2, scalar256_d1_promoted,
1871+ scalar256_d2_promoted);
1872+ check_case_when_decimal_promotion (scalar256_1, scalar256_2, scalar256_d1_promoted,
1873+ scalar256_d2_promoted);
1874+ }
1875+ }
1876+ }
1877+
18101878TEST (TestCaseWhen, FixedSizeBinary) {
18111879 auto type = fixed_size_binary (3 );
18121880 auto cond_true = ScalarFromJSON (boolean (), " true" );
@@ -2509,6 +2577,28 @@ TEST(TestCaseWhen, UnionBoolStringRandom) {
25092577 }
25102578}
25112579
2580+ TEST (TestCaseWhen, DispatchExact) {
2581+ // Decimal types with same (p, s)
2582+ CheckDispatchExact (" case_when" , {struct_ ({field (" " , boolean ())}), decimal128 (20 , 3 ),
2583+ decimal128 (20 , 3 )});
2584+ CheckDispatchExact (" case_when" , {struct_ ({field (" " , boolean ())}), decimal256 (20 , 3 ),
2585+ decimal256 (20 , 3 )});
2586+
2587+ // Decimal types with different (p, s)
2588+ CheckDispatchExactFails (" case_when" , {struct_ ({field (" " , boolean ())}),
2589+ decimal128 (20 , 3 ), decimal128 (21 , 3 )});
2590+ CheckDispatchExactFails (" case_when" , {struct_ ({field (" " , boolean ())}),
2591+ decimal128 (20 , 1 ), decimal128 (20 , 3 )});
2592+ CheckDispatchExactFails (" case_when" , {struct_ ({field (" " , boolean ())}),
2593+ decimal128 (20 , 3 ), decimal256 (20 , 3 )});
2594+ CheckDispatchExactFails (" case_when" , {struct_ ({field (" " , boolean ())}),
2595+ decimal256 (20 , 3 ), decimal128 (21 , 3 )});
2596+ CheckDispatchExactFails (" case_when" , {struct_ ({field (" " , boolean ())}),
2597+ decimal256 (20 , 3 ), decimal256 (21 , 3 )});
2598+ CheckDispatchExactFails (" case_when" , {struct_ ({field (" " , boolean ())}),
2599+ decimal256 (20 , 1 ), decimal256 (20 , 3 )});
2600+ }
2601+
25122602TEST (TestCaseWhen, DispatchBest) {
25132603 CheckDispatchBest (" case_when" , {struct_ ({field (" " , boolean ())}), int64 (), int32 ()},
25142604 {struct_ ({field (" " , boolean ())}), int64 (), int64 ()});
@@ -2559,6 +2649,32 @@ TEST(TestCaseWhen, DispatchBest) {
25592649 CheckDispatchBest (
25602650 " case_when" , {struct_ ({field (" " , boolean ())}), dictionary (int64 (), utf8 ()), utf8 ()},
25612651 {struct_ ({field (" " , boolean ())}), utf8 (), utf8 ()});
2652+
2653+ // Decimal promotion
2654+ CheckDispatchBest (
2655+ " case_when" ,
2656+ {struct_ ({field (" " , boolean ())}), decimal128 (20 , 3 ), decimal128 (21 , 3 )},
2657+ {struct_ ({field (" " , boolean ())}), decimal128 (21 , 3 ), decimal128 (21 , 3 )});
2658+ CheckDispatchBest (
2659+ " case_when" ,
2660+ {struct_ ({field (" " , boolean ())}), decimal128 (20 , 1 ), decimal128 (21 , 3 )},
2661+ {struct_ ({field (" " , boolean ())}), decimal128 (22 , 3 ), decimal128 (22 , 3 )});
2662+ CheckDispatchBest (
2663+ " case_when" ,
2664+ {struct_ ({field (" " , boolean ())}), decimal128 (20 , 3 ), decimal128 (21 , 1 )},
2665+ {struct_ ({field (" " , boolean ())}), decimal128 (23 , 3 ), decimal128 (23 , 3 )});
2666+ CheckDispatchBest (
2667+ " case_when" ,
2668+ {struct_ ({field (" " , boolean ())}), decimal128 (20 , 3 ), decimal256 (21 , 3 )},
2669+ {struct_ ({field (" " , boolean ())}), decimal256 (21 , 3 ), decimal256 (21 , 3 )});
2670+ CheckDispatchBest (
2671+ " case_when" ,
2672+ {struct_ ({field (" " , boolean ())}), decimal256 (20 , 1 ), decimal128 (21 , 3 )},
2673+ {struct_ ({field (" " , boolean ())}), decimal256 (22 , 3 ), decimal256 (22 , 3 )});
2674+ CheckDispatchBest (
2675+ " case_when" ,
2676+ {struct_ ({field (" " , boolean ())}), decimal256 (20 , 3 ), decimal256 (21 , 1 )},
2677+ {struct_ ({field (" " , boolean ())}), decimal256 (23 , 3 ), decimal256 (23 , 3 )});
25622678}
25632679
25642680template <typename Type>
0 commit comments