// NOTE(review): this file is a unified diff (git patch), not a C++ translation unit.
// Its original line breaks appear to have been collapsed, and angle-bracketed text
// looks to be missing (e.g. `vector &types`, `cast(x)`, `Param inner_extent;`), so the
// function names shown after each `@@` header may not be reliable -- confirm against
// the real sources before relying on them.
//
// What the patch changes, as far as the visible text shows:
//  * src/AssociativeOpsTable.cpp (three hunks, at 219, 233 and 245): in both
//    "Saturating add" table entries the non-saturated branch of the select() changes
//    from `y0` to `x0 + y0`, giving
//        select(x0 > tmax_0 - y0, tmax_0, x0 + y0)
//        select(x0 < -y0, x0 + y0, tmax_0)
//  * src/Associativity.cpp (hunk inside associativity_test): the two matching select()
//    expressions in the list of tested expressions get the same `y` -> `x + y` change.
//  * test/correctness/rfactor.cpp: adds saturating_add_rfactor_test() and registers it
//    in main() under the label "saturating add rfactor test". The test defines a
//    two-element Tuple reduction over RDom r(10, inner_extent) with inner_extent set
//    to 6, splits r.x by 2 into rxo/rxi, calls rfactor(rxo, u) on the update, and
//    checks both realized outputs against std::min(sum, 255), where the sums are of
//    3 * i and 9 * i for i in [10, 16).
diff --git a/src/AssociativeOpsTable.cpp b/src/AssociativeOpsTable.cpp index cac2508335f4..45925ded0ddc 100644 --- a/src/AssociativeOpsTable.cpp +++ b/src/AssociativeOpsTable.cpp @@ -219,8 +219,8 @@ void populate_ops_table_single_uint8_cast(const vector &types, vector &types, vector &table) { declare_vars_single(types); - table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true); // Saturating add - table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true); // Saturating add + table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, x0 + y0), zero_0, true); // Saturating add + table.emplace_back(select(x0 < -y0, x0 + y0, tmax_0), zero_0, true); // Saturating add } void populate_ops_table_single_uint16_cast(const vector &types, vector &table) { @@ -233,8 +233,8 @@ void populate_ops_table_single_uint16_cast(const vector &types, vector &types, vector &table) { declare_vars_single(types); - table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true); // Saturating add - table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true); // Saturating add + table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, x0 + y0), zero_0, true); // Saturating add + table.emplace_back(select(x0 < -y0, x0 + y0, tmax_0), zero_0, true); // Saturating add } void populate_ops_table_single_uint32_cast(const vector &types, vector &table) { @@ -245,8 +245,8 @@ void populate_ops_table_single_uint32_cast(const vector &types, vector &types, vector &table) { declare_vars_single(types); - table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true); // Saturating add - table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true); // Saturating add + table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, x0 + y0), zero_0, true); // Saturating add + table.emplace_back(select(x0 < -y0, x0 + y0, tmax_0), zero_0, true); // Saturating add } void populate_ops_table_single_float_select(const vector &types, vector &table) { diff --git a/src/Associativity.cpp 
b/src/Associativity.cpp index 6a8d8948be85..14eaf8a81481 100644 --- a/src/Associativity.cpp +++ b/src/Associativity.cpp @@ -544,8 +544,8 @@ void associativity_test() { Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0); for (const Expr &e : {cast(min(cast(x) + y, 255)), - select(x > 255 - y, cast(255), y), - select(x < -y, y, cast(255)), + select(x > 255 - y, cast(255), x + y), + select(x < -y, x + y, cast(255)), saturating_add(x, y), saturating_add(y, x), saturating_cast(widening_add(x, y))}) { diff --git a/test/correctness/rfactor.cpp b/test/correctness/rfactor.cpp index 94431b0a3d09..2bcb2bc06f6e 100644 --- a/test/correctness/rfactor.cpp +++ b/test/correctness/rfactor.cpp @@ -825,6 +825,59 @@ int argmin_rfactor_test() { return 0; } +int saturating_add_rfactor_test() { + Func f("f"), g("g"), ref("ref"); + Var x("x"), y("y"), z("z"); + + f(x) = cast(x); + f.compute_root(); + + Param inner_extent; + RDom r(10, inner_extent); + inner_extent.set(6); + uint8_t max_int = 255; + + g() = Tuple(cast(0), cast(0)); + g() = Tuple(select(g()[0] > max_int - 3 * f(r.x), max_int, g()[0] + 3 * f(r.x)), + select(g()[1] > max_int - 9 * f(r.x), max_int, 9 * f(r.x) + g()[1])); + + RVar rxi("rxi"), rxo("rxo"); + g.update(0).split(r.x, rxo, rxi, 2); + + Var u("u"); + Func intm = g.update(0).rfactor(rxo, u); + intm.compute_root(); + intm.update(0).vectorize(u, 2); + + Realization rn = g.realize(); + Buffer im1(rn[0]); + Buffer im2(rn[1]); + + auto func1 = [](int x, int y, int z) { + int ret = 0; + for (int i = 10; i < 16; i++) { + ret += 3 * i; + } + return std::min(ret, 255); + }; + if (check_image(im1, func1)) { + return 1; + } + + auto func2 = [](int x, int y, int z) { + int ret = 0; + for (int i = 10; i < 16; i++) { + ret += 9 * i; + } + return std::min(ret, 255); + }; + if (check_image(im2, func2)) { + return 1; + } + + return 0; +} + int allocation_bound_test_trace(JITUserContext *user_context, const halide_trace_event_t *e) { // The schedule 
implies that f will be stored from 0 to 1 if (e->event == 2 && std::string(e->func) == "f") { @@ -1156,6 +1209,7 @@ int main(int argc, char **argv) { {"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test}, {"complex multiply rfactor test", complex_multiply_rfactor_test}, {"argmin rfactor test", argmin_rfactor_test}, + {"saturating add rfactor test", saturating_add_rfactor_test}, {"inlined rfactor with disappearing rvar test", inlined_rfactor_with_disappearing_rvar_test}, {"rfactor bounds tests", rfactor_precise_bounds_test}, {"isnan max rfactor test (bitwise or)", isnan_max_rfactor_test},