Skip to content

Commit a3b2524

Browse files
authored
Add rfactor patterns for NaN-propagating min/max (#8587)
1 parent c805e54 commit a3b2524

File tree

3 files changed

+91
-1
lines changed

3 files changed

+91
-1
lines changed

src/AssociativeOpsTable.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,32 @@ void populate_ops_table_single_uint32_select(const vector<Type> &types, vector<A
249249
table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true); // Saturating add
250250
}
251251

252+
// This function exists because the Solve module strips strict_float on one side of the pattern matching.
253+
// This leads to failed pattern matches in the nan-propagating min/max patterns.
254+
// TODO: Once strict_float has been reworked, this should be removed.
255+
Expr is_nan_not_strict(Expr x) {
256+
Type t = Bool(x.type().lanes());
257+
if (x.type().element_of() == Float(64)) {
258+
return Call::make(t, "is_nan_f64", {std::move(x)}, Call::PureExtern);
259+
}
260+
if (x.type().element_of() == Float(16)) {
261+
return Call::make(t, "is_nan_f16", {std::move(x)}, Call::PureExtern);
262+
}
263+
internal_assert(x.type().element_of() == Float(32));
264+
return Call::make(t, "is_nan_f32", {std::move(x)}, Call::PureExtern);
265+
}
266+
267+
void populate_ops_table_single_float_select(const vector<Type> &types, vector<AssociativePattern> &table) {
268+
declare_vars_single(types);
269+
// Propagating max operators
270+
table.emplace_back(select(is_nan_not_strict(x0) || x0 > y0, x0, y0), tmin_0, true);
271+
table.emplace_back(select(is_nan_not_strict(x0) || x0 >= y0, x0, y0), tmin_0, true);
272+
273+
// Propagating min operators
274+
table.emplace_back(select(is_nan_not_strict(x0) || x0 < y0, x0, y0), tmax_0, true);
275+
table.emplace_back(select(is_nan_not_strict(x0) || x0 <= y0, x0, y0), tmax_0, true);
276+
}
277+
252278
const map<TableKey, void (*)(const vector<Type> &types, vector<AssociativePattern> &)> val_type_to_populate_luts_fn = {
253279
{TableKey(ValType::All, IRNodeType::Add, 1), &populate_ops_table_single_general_add},
254280
{TableKey(ValType::All, IRNodeType::Mul, 1), &populate_ops_table_single_general_mul},
@@ -275,6 +301,10 @@ const map<TableKey, void (*)(const vector<Type> &types, vector<AssociativePatter
275301

276302
{TableKey(ValType::UInt32, IRNodeType::Cast, 1), &populate_ops_table_single_uint32_cast},
277303
{TableKey(ValType::UInt32, IRNodeType::Select, 1), &populate_ops_table_single_uint32_select},
304+
305+
{TableKey(ValType::Float16, IRNodeType::Select, 1), &populate_ops_table_single_float_select},
306+
{TableKey(ValType::Float32, IRNodeType::Select, 1), &populate_ops_table_single_float_select},
307+
{TableKey(ValType::Float64, IRNodeType::Select, 1), &populate_ops_table_single_float_select},
278308
};
279309

280310
const vector<AssociativePattern> &get_ops_table_helper(const vector<Type> &types, IRNodeType root, size_t dim) {

src/Pipeline.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct PipelineContents;
3232
*
3333
* The 'name' field specifies the type of Autoscheduler
3434
* to be used (e.g. Adams2019, Mullapudi2016). If this is an empty string,
35-
* no autoscheduling will be done; if not, it mustbe the name of a known Autoscheduler.
35+
* no autoscheduling will be done; if not, it must be the name of a known Autoscheduler.
3636
*
3737
* At this time, well-known autoschedulers include:
3838
* "Mullapudi2016" -- heuristics-based; the first working autoscheduler; currently built in to libHalide

test/correctness/rfactor.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,64 @@ int rfactor_precise_bounds_test() {
10601060
return 0;
10611061
}
10621062

1063+
// Selects which "or" operator isnan_max_rfactor_test uses to combine the
// greater-than comparison with the is_nan check in its update definition.
enum MaxRFactorTestVariant {
    BitwiseOr = 0,  // combine the two conditions with operator|
    LogicalOr = 1,  // combine the two conditions with operator||
};
1067+
1068+
template<MaxRFactorTestVariant variant>
1069+
int isnan_max_rfactor_test() {
1070+
RDom r(0, 16);
1071+
RVar ri("ri");
1072+
Var x("x"), y("y"), u("u");
1073+
1074+
ImageParam in(Float(32), 2);
1075+
1076+
const auto make_reduce = [&](const char *name) {
1077+
Func reduce(name);
1078+
reduce(y) = Float(32).min();
1079+
switch (variant) {
1080+
case BitwiseOr:
1081+
reduce(y) = select(reduce(y) > cast(reduce.type(), in(r, y)) | is_nan(reduce(y)), reduce(y), cast(reduce.type(), in(r, y)));
1082+
break;
1083+
case LogicalOr:
1084+
reduce(y) = select(reduce(y) > cast(reduce.type(), in(r, y)) || is_nan(reduce(y)), reduce(y), cast(reduce.type(), in(r, y)));
1085+
break;
1086+
}
1087+
return reduce;
1088+
};
1089+
1090+
Func reference = make_reduce("reference");
1091+
1092+
Func reduce = make_reduce("reduce");
1093+
reduce.update(0).split(r, r, ri, 8).rfactor(ri, u);
1094+
1095+
float tests[][16] = {
1096+
{NAN, 0.29f, 0.19f, 0.68f, 0.44f, 0.40f, 0.39f, 0.53f, 0.23f, 0.21f, 0.85f, 0.19f, 0.37f, 0.03f, 0.14f, 0.64f},
1097+
{0.98f, 0.65f, 0.86f, 0.16f, 0.14f, 0.91f, 0.74f, 0.99f, 0.91f, 0.01f, 0.11f, 0.59f, 0.05f, 0.90f, 0.93f, NAN},
1098+
{0.84f, 0.14f, 0.99f, 0.19f, 0.63f, 0.12f, 0.51f, 0.67f, NAN, 0.34f, 0.89f, 0.93f, 0.72f, 0.69f, 0.58f, 0.63f},
1099+
{0.44f, 0.12f, 0.00f, 0.30f, 0.80f, 0.88f, 0.95f, 0.12f, 0.90f, 0.99f, 0.67f, 0.71f, 0.35f, 0.67f, 0.18f, 0.93f},
1100+
};
1101+
1102+
Buffer<float, 2> buf{tests};
1103+
in.set(buf);
1104+
1105+
Buffer<float, 1> ref_vals = reference.realize({4}, get_jit_target_from_environment().with_feature(Target::StrictFloat));
1106+
Buffer<float, 1> fac_vals = reduce.realize({4}, get_jit_target_from_environment().with_feature(Target::StrictFloat));
1107+
1108+
for (int i = 0; i < 4; i++) {
1109+
if (std::isnan(fac_vals(i)) && std::isnan(ref_vals(i))) {
1110+
continue;
1111+
}
1112+
if (fac_vals(i) != ref_vals(i)) {
1113+
std::cerr << "At index " << i << ", expected: " << ref_vals(i) << ", got: " << fac_vals(i) << "\n";
1114+
return 1;
1115+
}
1116+
}
1117+
1118+
return 0;
1119+
}
1120+
10631121
} // namespace
10641122

10651123
int main(int argc, char **argv) {
@@ -1100,6 +1158,8 @@ int main(int argc, char **argv) {
11001158
{"argmin rfactor test", argmin_rfactor_test},
11011159
{"inlined rfactor with disappearing rvar test", inlined_rfactor_with_disappearing_rvar_test},
11021160
{"rfactor bounds tests", rfactor_precise_bounds_test},
1161+
{"isnan max rfactor test (bitwise or)", isnan_max_rfactor_test<BitwiseOr>},
1162+
{"isnan max rfactor test (logical or)", isnan_max_rfactor_test<LogicalOr>},
11031163
};
11041164

11051165
using Sharder = Halide::Internal::Test::Sharder;

0 commit comments

Comments
 (0)