Skip to content

Commit d3ca7b2

Browse files
authored
Avoid throwing from a destructor in PartitionLoops.cpp (#8767)
Also limit correctness_bounds_of_pure_intrinsics stack usage on ASAN
1 parent adca0ce commit d3ca7b2

File tree

4 files changed

+52
-27
lines changed

src/PartitionLoops.cpp

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -547,40 +547,31 @@ class PartitionLoops : public IRMutator {
547547
(op->partition_policy == Partition::Auto && in_tail)) {
548548
return IRMutator::visit(op);
549549
}
550+
const auto [loop, mutated] = visit_for(op);
551+
user_assert(op->partition_policy != Partition::Always || mutated)
552+
<< "Loop Partition Policy is set to " << op->partition_policy
553+
<< " for " << op->name << ", but no loop partitioning was performed.";
554+
return loop;
555+
}
550556

557+
std::tuple<Stmt, bool> visit_for(const For *op) {
558+
bool mutated = false;
551559
Stmt body = op->body;
552560

553-
// A struct that upon destruction will check if the current For was partitioned
554-
// and error out if it wasn't when the schedule demanded it.
555-
struct ErrorIfNotMutated {
556-
const For *op;
557-
bool must_mutate;
558-
bool mutated{false};
559-
ErrorIfNotMutated(const For *op, bool must_mutate)
560-
: op(op), must_mutate(must_mutate) {
561-
}
562-
~ErrorIfNotMutated() {
563-
if (must_mutate && !mutated) {
564-
user_error << "Loop Partition Policy is set to " << op->partition_policy
565-
<< " for " << op->name << ", but no loop partitioning was performed.";
566-
}
567-
}
568-
} mutation_checker{op, op->partition_policy == Partition::Always};
569-
570561
ScopedValue<bool> old_in_gpu_loop(in_gpu_loop, in_gpu_loop || is_gpu(op->for_type));
571562

572563
// If we're inside GPU kernel, and the body contains thread
573564
// barriers or warp shuffles, it's not safe to partition loops.
574565
if (in_gpu_loop && contains_warp_synchronous_logic(op)) {
575-
return IRMutator::visit(op);
566+
return {IRMutator::visit(op), mutated};
576567
}
577568

578569
// Find simplifications in this loop body
579570
FindSimplifications finder(op->name);
580571
body.accept(&finder);
581572

582573
if (finder.simplifications.empty()) {
583-
return IRMutator::visit(op);
574+
return {IRMutator::visit(op), mutated};
584575
}
585576

586577
debug(3) << "\n\n**** Partitioning loop over " << op->name << "\n";
@@ -776,13 +767,13 @@ class PartitionLoops : public IRMutator {
776767
prologue = For::make(op->name, op->min, min_steady - op->min,
777768
op->for_type, op->partition_policy, op->device_api, prologue);
778769
stmt = Block::make(prologue, stmt);
779-
mutation_checker.mutated = true;
770+
mutated = true;
780771
}
781772
if (make_epilogue) {
782773
epilogue = For::make(op->name, max_steady, op->min + op->extent - max_steady,
783774
op->for_type, op->partition_policy, op->device_api, epilogue);
784775
stmt = Block::make(stmt, epilogue);
785-
mutation_checker.mutated = true;
776+
mutated = true;
786777
}
787778
} else {
788779
// For parallel for loops we could use a Fork node here,
@@ -801,15 +792,15 @@ class PartitionLoops : public IRMutator {
801792
stmt = simpler_body;
802793
if (make_epilogue && make_prologue && equal(prologue, epilogue)) {
803794
stmt = IfThenElse::make(min_steady <= loop_var && loop_var < max_steady, stmt, prologue);
804-
mutation_checker.mutated = true;
795+
mutated = true;
805796
} else {
806797
if (make_epilogue) {
807798
stmt = IfThenElse::make(loop_var < max_steady, stmt, epilogue);
808-
mutation_checker.mutated = true;
799+
mutated = true;
809800
}
810801
if (make_prologue) {
811802
stmt = IfThenElse::make(loop_var < min_steady, prologue, stmt);
812-
mutation_checker.mutated = true;
803+
mutated = true;
813804
}
814805
}
815806
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, stmt);
@@ -833,14 +824,14 @@ class PartitionLoops : public IRMutator {
833824
if (can_prove(epilogue_val <= prologue_val)) {
834825
// The steady state is empty. I've made a huge
835826
// mistake. Try to partition a loop further in.
836-
return IRMutator::visit(op);
827+
return {IRMutator::visit(op), mutated};
837828
}
838829

839830
debug(3) << "Partition loop.\n"
840831
<< "Old: " << Stmt(op) << "\n"
841832
<< "New: " << stmt << "\n";
842833

843-
return stmt;
834+
return {stmt, mutated};
844835
}
845836
};
846837

test/correctness/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ tests(GROUPS correctness
1111
async_order.cpp
1212
autodiff.cpp
1313
bad_likely.cpp
14+
bad_partition_always_throws.cpp
1415
bit_counting.cpp
1516
bits_known.cpp
1617
bitwise_ops.cpp
test/correctness/bad_partition_always_throws.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include "Halide.h"
2+
using namespace Halide;
3+
4+
int main(int argc, char **argv) {
5+
#ifndef HALIDE_WITH_EXCEPTIONS
6+
printf("[SKIP] bad_partition_always_throws requires exceptions\n");
7+
return 0;
8+
#else
9+
try {
10+
Func f("f");
11+
Var x("x");
12+
f(x) = 0;
13+
f.partition(x, Partition::Always);
14+
f.realize({10});
15+
} catch (const CompileError &e) {
16+
const std::string_view msg = e.what();
17+
constexpr std::string_view expected_msg =
18+
"Loop Partition Policy is set to Always for f.s0.x, "
19+
"but no loop partitioning was performed.";
20+
if (msg.find(expected_msg) == std::string_view::npos) {
21+
std::cerr << "Expected error containing (" << expected_msg << "), but got (" << msg << ")\n";
22+
return 1;
23+
}
24+
printf("Success!\n");
25+
return 0;
26+
}
27+
28+
printf("Did not see any exception!\n");
29+
return 1;
30+
#endif
31+
}

test/correctness/bounds_of_pure_intrinsics.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ int main(int argc, char **argv) {
1313
Scope<Interval> scope;
1414
scope.push(p2.name(), Interval{p2_min, p2_max});
1515

16-
for (int limit = 1; limit < 500; limit++) {
16+
// This test uses a lot of stack space, especially on ASAN, where we don't
17+
// do any stack switching (see Util.cpp). Don't push this number too far.
18+
for (int limit = 1; limit < 100; limit++) {
1719
Expr e1 = p1, e2 = p2;
1820
for (int i = 0; i < limit; i++) {
1921
e1 = e1 * p1 + (i + 1);

0 commit comments

Comments
 (0)