@@ -1060,6 +1060,64 @@ int rfactor_precise_bounds_test() {
10601060 return 0 ;
10611061}
10621062
// Chooses which "or" operator isnan_max_rfactor_test uses to combine the
// comparison with the is_nan() check inside its reduction predicate.
enum MaxRFactorTestVariant {
    // Predicate is built with operator| .
    BitwiseOr = 0,
    // Predicate is built with operator|| .
    LogicalOr = 1,
};
1067+
1068+ template <MaxRFactorTestVariant variant>
1069+ int isnan_max_rfactor_test () {
1070+ RDom r (0 , 16 );
1071+ RVar ri (" ri" );
1072+ Var x (" x" ), y (" y" ), u (" u" );
1073+
1074+ ImageParam in (Float (32 ), 2 );
1075+
1076+ const auto make_reduce = [&](const char *name) {
1077+ Func reduce (name);
1078+ reduce (y) = Float (32 ).min ();
1079+ switch (variant) {
1080+ case BitwiseOr:
1081+ reduce (y) = select (reduce (y) > cast (reduce.type (), in (r, y)) | is_nan (reduce (y)), reduce (y), cast (reduce.type (), in (r, y)));
1082+ break ;
1083+ case LogicalOr:
1084+ reduce (y) = select (reduce (y) > cast (reduce.type (), in (r, y)) || is_nan (reduce (y)), reduce (y), cast (reduce.type (), in (r, y)));
1085+ break ;
1086+ }
1087+ return reduce;
1088+ };
1089+
1090+ Func reference = make_reduce (" reference" );
1091+
1092+ Func reduce = make_reduce (" reduce" );
1093+ reduce.update (0 ).split (r, r, ri, 8 ).rfactor (ri, u);
1094+
1095+ float tests[][16 ] = {
1096+ {NAN, 0 .29f , 0 .19f , 0 .68f , 0 .44f , 0 .40f , 0 .39f , 0 .53f , 0 .23f , 0 .21f , 0 .85f , 0 .19f , 0 .37f , 0 .03f , 0 .14f , 0 .64f },
1097+ {0 .98f , 0 .65f , 0 .86f , 0 .16f , 0 .14f , 0 .91f , 0 .74f , 0 .99f , 0 .91f , 0 .01f , 0 .11f , 0 .59f , 0 .05f , 0 .90f , 0 .93f , NAN},
1098+ {0 .84f , 0 .14f , 0 .99f , 0 .19f , 0 .63f , 0 .12f , 0 .51f , 0 .67f , NAN, 0 .34f , 0 .89f , 0 .93f , 0 .72f , 0 .69f , 0 .58f , 0 .63f },
1099+ {0 .44f , 0 .12f , 0 .00f , 0 .30f , 0 .80f , 0 .88f , 0 .95f , 0 .12f , 0 .90f , 0 .99f , 0 .67f , 0 .71f , 0 .35f , 0 .67f , 0 .18f , 0 .93f },
1100+ };
1101+
1102+ Buffer<float , 2 > buf{tests};
1103+ in.set (buf);
1104+
1105+ Buffer<float , 1 > ref_vals = reference.realize ({4 }, get_jit_target_from_environment ().with_feature (Target::StrictFloat));
1106+ Buffer<float , 1 > fac_vals = reduce.realize ({4 }, get_jit_target_from_environment ().with_feature (Target::StrictFloat));
1107+
1108+ for (int i = 0 ; i < 4 ; i++) {
1109+ if (std::isnan (fac_vals (i)) && std::isnan (ref_vals (i))) {
1110+ continue ;
1111+ }
1112+ if (fac_vals (i) != ref_vals (i)) {
1113+ std::cerr << " At index " << i << " , expected: " << ref_vals (i) << " , got: " << fac_vals (i) << " \n " ;
1114+ return 1 ;
1115+ }
1116+ }
1117+
1118+ return 0 ;
1119+ }
1120+
10631121} // namespace
10641122
10651123int main (int argc, char **argv) {
@@ -1100,6 +1158,8 @@ int main(int argc, char **argv) {
11001158 {" argmin rfactor test" , argmin_rfactor_test},
11011159 {" inlined rfactor with disappearing rvar test" , inlined_rfactor_with_disappearing_rvar_test},
11021160 {" rfactor bounds tests" , rfactor_precise_bounds_test},
1161+ {" isnan max rfactor test (bitwise or)" , isnan_max_rfactor_test<BitwiseOr>},
1162+ {" isnan max rfactor test (logical or)" , isnan_max_rfactor_test<LogicalOr>},
11031163 };
11041164
11051165 using Sharder = Halide::Internal::Test::Sharder;
0 commit comments