Skip to content

Commit 8a0ed51

Browse files
authored
Returns null when divide/mod by 0 (#186)
* fix: return null if 1/0 * add mod * fix compile failed
1 parent 7deeef2 commit 8a0ed51

File tree

5 files changed

+100
-67
lines changed

5 files changed

+100
-67
lines changed

cpp/src/gandiva/function_registry_arithmetic.cc

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,15 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
9090
// add/sub/multiply/divide/mod
9191
BINARY_SYMMETRIC_FN(add, {}), BINARY_SYMMETRIC_FN(subtract, {}),
9292
BINARY_SYMMETRIC_FN(multiply, {}),
93-
NUMERIC_TYPES(BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL, divide, {}),
94-
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int8, int8, int8),
95-
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int16, int16, int16),
96-
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int32, int32, int32),
97-
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int64, int64, int64),
93+
NUMERIC_TYPES_WITHOUT_DECIMAL(BINARY_SYMMETRIC_SAFE_INTERNAL_NULL, divide, {}),
94+
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int8),
95+
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int16),
96+
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int32),
97+
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int64),
98+
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, float32),
99+
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, float64),
98100
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, decimal128),
99-
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, float32),
100-
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, float64),
101+
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(divide, {}, decimal128),
101102
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int32),
102103
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int64),
103104
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, float32),

cpp/src/gandiva/function_registry_common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,14 @@ typedef std::unordered_map<const FunctionSignature*, const NativeFunction*, KeyH
267267
INNER(NAME, ALIASES, uint64), INNER(NAME, ALIASES, float32), \
268268
INNER(NAME, ALIASES, float64), INNER(NAME, ALIASES, decimal128)
269269

270+
// Iterate the inner macro over all numeric types without decimal
271+
#define NUMERIC_TYPES_WITHOUT_DECIMAL(INNER, NAME, ALIASES) \
272+
INNER(NAME, ALIASES, int8), INNER(NAME, ALIASES, int16), INNER(NAME, ALIASES, int32), \
273+
INNER(NAME, ALIASES, int64), INNER(NAME, ALIASES, uint8), \
274+
INNER(NAME, ALIASES, uint16), INNER(NAME, ALIASES, uint32), \
275+
INNER(NAME, ALIASES, uint64), INNER(NAME, ALIASES, float32), \
276+
INNER(NAME, ALIASES, float64)
277+
270278
// Iterate the inner macro over numeric and date/time types
271279
#define NUMERIC_DATE_TYPES(INNER, NAME, ALIASES) \
272280
NUMERIC_TYPES(INNER, NAME, ALIASES), DATE_TYPES(INNER, NAME, ALIASES), \

cpp/src/gandiva/precompiled/arithmetic_ops.cc

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,21 @@ extern "C" {
6262
DATE_TYPES(INNER, NAME, OP) \
6363
INNER(NAME, boolean, OP)
6464

65-
#define MOD_OP(NAME, IN_TYPE1, IN_TYPE2, OUT_TYPE) \
66-
FORCE_INLINE \
67-
gdv_##OUT_TYPE NAME##_##IN_TYPE1##_##IN_TYPE2(gdv_##IN_TYPE1 left, \
68-
gdv_##IN_TYPE2 right) { \
69-
return (right == 0 ? static_cast<gdv_##OUT_TYPE>(left) \
70-
: static_cast<gdv_##OUT_TYPE>(left % right)); \
65+
#define MOD_OP(TYPE) \
66+
FORCE_INLINE \
67+
gdv_##TYPE mod_##TYPE##_##TYPE(gdv_##TYPE in1, bool in1_valid, \
68+
gdv_##TYPE in2, bool in2_valid, bool* out_valid) { \
69+
if (!in1_valid || !in2_valid) { \
70+
*out_valid = false; \
71+
return static_cast<gdv_##TYPE>(0); \
72+
} \
73+
if (static_cast<gdv_##TYPE>(0) == in2) { \
74+
*out_valid = false; \
75+
return static_cast<gdv_##TYPE>(0); \
76+
} \
77+
gdv_##TYPE res = static_cast<gdv_##TYPE>(in1 % in2); \
78+
*out_valid = true; \
79+
return res; \
7180
}
7281

7382
// Symmetric binary fns : left, right params and return type are same.
@@ -95,29 +104,32 @@ gdv_boolean isNaN_float32(gdv_float32 val) { return isnan(val) || isinf(val); }
95104
FORCE_INLINE
96105
gdv_boolean isNaN_float64(gdv_float64 val) { return isnan(val) || isinf(val); }
97106

98-
MOD_OP(mod, int32, int32, int32)
99-
MOD_OP(mod, int64, int64, int64)
107+
MOD_OP(int32)
108+
MOD_OP(int64)
100109

101110
#undef MOD_OP
102111

103-
gdv_float32 mod_float32_float32(int64_t context, gdv_float32 x, gdv_float32 y) {
104-
if (y == 0.0) {
105-
// char const* err_msg = "divide by zero error";
106-
// gdv_fn_context_set_error_msg(context, err_msg);
107-
return 0.0;
108-
}
109-
return fmod(x, y);
110-
}
111-
112-
gdv_float64 mod_float64_float64(int64_t context, gdv_float64 x, gdv_float64 y) {
113-
if (y == 0.0) {
114-
// Setting error msg can cause unexpected runtime exception.
115-
// char const* err_msg = "divide by zero error";
116-
// gdv_fn_context_set_error_msg(context, err_msg);
117-
return 0.0;
112+
#define MOD_FLOAT(TYPE) \
113+
FORCE_INLINE \
114+
gdv_##TYPE mod_##TYPE##_##TYPE(gdv_##TYPE in1, bool in1_valid, \
115+
gdv_##TYPE in2, bool in2_valid, bool* out_valid) { \
116+
if (!in1_valid || !in2_valid) { \
117+
*out_valid = false; \
118+
return static_cast<gdv_##TYPE>(0.0); \
119+
} \
120+
if (static_cast<gdv_##TYPE>(0.0) == in2) { \
121+
*out_valid = false; \
122+
return static_cast<gdv_##TYPE>(0.0); \
123+
} \
124+
gdv_##TYPE res = static_cast<gdv_##TYPE>(fmod(in1,in2)); \
125+
*out_valid = true; \
126+
return res; \
118127
}
119-
return fmod(x, y);
120-
}
128+
129+
MOD_FLOAT(float32)
130+
MOD_FLOAT(float64)
131+
132+
#undef MOD_FLOAT
121133

122134
// pmod, return the positive mod.
123135
#define PMOD(IN_TYPE) \
@@ -380,11 +392,19 @@ NUMERIC_BOOL_DATE_FUNCTION(IS_NOT_DISTINCT_FROM)
380392

381393
#define DIVIDE(TYPE) \
382394
FORCE_INLINE \
383-
gdv_##TYPE divide_##TYPE##_##TYPE(gdv_int64 context, gdv_##TYPE in1, gdv_##TYPE in2) { \
384-
if (in2 == 0) { \
385-
return static_cast<gdv_##TYPE>(NULL); \
395+
gdv_##TYPE divide_##TYPE##_##TYPE(gdv_##TYPE in1, bool in1_valid, \
396+
gdv_##TYPE in2, bool in2_valid, bool* out_valid) { \
397+
if (!in1_valid || !in2_valid) { \
398+
*out_valid = false; \
399+
return static_cast<gdv_##TYPE>(0); \
400+
} \
401+
if (static_cast<gdv_##TYPE>(0) == in2) { \
402+
*out_valid = false; \
403+
return static_cast<gdv_##TYPE>(0); \
386404
} \
387-
return static_cast<gdv_##TYPE>(in1 / in2); \
405+
gdv_##TYPE res = static_cast<gdv_##TYPE>(in1 / in2); \
406+
*out_valid = true; \
407+
return res; \
388408
}
389409

390410
NUMERIC_FUNCTION(DIVIDE)

cpp/src/gandiva/precompiled/arithmetic_ops_test.cc

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,30 +35,27 @@ TEST(TestArithmeticOps, TestIsDistinctFrom) {
3535
}
3636

3737
TEST(TestArithmeticOps, TestMod) {
38-
gandiva::ExecutionContext context;
39-
EXPECT_EQ(mod_int32_int32(10, 0), 10);
38+
bool out_valid = false;
39+
EXPECT_EQ(mod_int32_int32(10, true, 0, true, &out_valid), 0);
40+
EXPECT_EQ(out_valid, false);
4041

4142
const double acceptable_abs_error = 0.00000000001; // 1e-10
4243

43-
EXPECT_DOUBLE_EQ(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 2.5, 0.0),
44-
0.0);
45-
// EXPECT_TRUE(context.has_error());
46-
// EXPECT_EQ(context.get_error(), "divide by zero error");
44+
out_valid = false;
45+
EXPECT_EQ(mod_float32_float32(2.5, true, 0, true, &out_valid), 0);
46+
EXPECT_EQ(out_valid, false);
4747

48-
context.Reset();
49-
EXPECT_NEAR(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 2.5, 1.2), 0.1,
50-
acceptable_abs_error);
51-
EXPECT_FALSE(context.has_error());
48+
out_valid = false;
49+
EXPECT_NEAR(mod_float64_float64(2.5, true, 1.2, true, &out_valid), 0.1, acceptable_abs_error);
50+
EXPECT_EQ(out_valid, true);
5251

53-
context.Reset();
54-
EXPECT_DOUBLE_EQ(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 2.5, 2.5),
55-
0.0);
56-
EXPECT_FALSE(context.has_error());
52+
out_valid = false;
53+
EXPECT_DOUBLE_EQ(mod_float64_float64(2.5, true, 2.5, true, &out_valid), 0.0);
54+
EXPECT_EQ(out_valid, true);
5755

58-
context.Reset();
59-
EXPECT_NEAR(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 9.2, 3.7), 1.8,
60-
acceptable_abs_error);
61-
EXPECT_FALSE(context.has_error());
56+
out_valid = false;
57+
EXPECT_NEAR(mod_float64_float64(9.2, true, 3.7, true, &out_valid), 1.8, acceptable_abs_error);
58+
EXPECT_EQ(out_valid, true);
6259
}
6360

6461
TEST(TestArithmeticOps, TestPMod) {
@@ -94,14 +91,21 @@ TEST(TestArithmeticOps, TestCompare) {
9491
}
9592

9693
TEST(TestArithmeticOps, TestDivide) {
97-
gandiva::ExecutionContext context;
98-
EXPECT_EQ(divide_int64_int64(reinterpret_cast<gdv_int64>(&context), 10, 0), 0);
99-
// EXPECT_EQ(context.has_error(), true);
100-
// EXPECT_EQ(context.get_error(), "divide by zero error");
94+
bool out_valid = false;
95+
EXPECT_EQ(divide_int64_int64(10, true, 0, true, &out_valid), 0);
96+
EXPECT_EQ(out_valid, false);
10197

102-
context.Reset();
103-
EXPECT_EQ(divide_int64_int64(reinterpret_cast<gdv_int64>(&context), 10, 2), 5);
104-
EXPECT_EQ(context.has_error(), false);
98+
out_valid = false;
99+
EXPECT_EQ(divide_int64_int64(10, true, 0, false, &out_valid), 0);
100+
EXPECT_EQ(out_valid, false);
101+
102+
out_valid = false;
103+
EXPECT_EQ(divide_int64_int64(10, false, 0, true, &out_valid), 0);
104+
EXPECT_EQ(out_valid, false);
105+
106+
out_valid = false;
107+
EXPECT_EQ(divide_int64_int64(10, true, 2, true, &out_valid), 5);
108+
EXPECT_EQ(out_valid, true);
105109
}
106110

107111
TEST(TestArithmeticOps, TestDiv) {

cpp/src/gandiva/precompiled/types.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,11 @@ double months_between_timestamp_timestamp(gdv_uint64, gdv_uint64);
144144
gdv_int32 mem_compare(const char* left, gdv_int32 left_len, const char* right,
145145
gdv_int32 right_len);
146146

147-
gdv_int32 mod_int8_int8(gdv_int8 left, gdv_int8 right);
148-
gdv_int32 mod_int16_int16(gdv_int16 left, gdv_int16 right);
149-
gdv_int32 mod_int32_int32(gdv_int32 left, gdv_int32 right);
150-
gdv_float32 mod_float32_float32(gdv_int64 context, gdv_float32 left, gdv_float32 right);
151-
gdv_float64 mod_float64_float64(gdv_int64 context, gdv_float64 left, gdv_float64 right);
147+
gdv_int32 mod_int8_int8(gdv_int8 in1, bool in1_valid, gdv_int8 in2, bool in2_valid, bool* out_valid);
148+
gdv_int32 mod_int16_int16(gdv_int16 in1, bool in1_valid, gdv_int16 in2, bool in2_valid, bool* out_valid);
149+
gdv_int32 mod_int32_int32(gdv_int32 in1, bool in1_valid, gdv_int32 in2, bool in2_valid, bool* out_valid);
150+
gdv_float32 mod_float32_float32(gdv_float32 in1, bool in1_valid, gdv_float32 in2, bool in2_valid, bool* out_valid);
151+
gdv_float64 mod_float64_float64(gdv_float64 in1, bool in1_valid, gdv_float64 in2, bool in2_valid, bool* out_valid);
152152

153153
gdv_int8 pmod_int8_int8(gdv_int8 in1, bool in1_valid, gdv_int8 in2, bool in2_valid, bool* out_valid);
154154
gdv_int16 pmod_int16_int16(gdv_int16 in1, bool in1_valid, gdv_int16 in2, bool in2_valid, bool* out_valid);
@@ -160,7 +160,7 @@ gdv_float64 pmod_float64_float64(gdv_float64 in1, bool in1_valid, gdv_float64 in
160160
bool equal_with_nan_float32_float32(gdv_float32 in1, gdv_float32 in2);
161161
bool not_equal_with_nan_float32_float32(gdv_float32 in1, gdv_float32 in2);
162162

163-
gdv_int64 divide_int64_int64(gdv_int64 context, gdv_int64 in1, gdv_int64 in2);
163+
gdv_int64 divide_int64_int64(gdv_int64 in1, bool in1_valid, gdv_int64 in2, bool in2_valid, bool* out_valid);
164164

165165
gdv_int64 div_int64_int64(gdv_int64 context, gdv_int64 in1, gdv_int64 in2);
166166
gdv_float32 div_float32_float32(gdv_int64 context, gdv_float32 in1, gdv_float32 in2);

0 commit comments

Comments
 (0)