Skip to content

Commit e3b0ccd

Browse files
authored
Fix up memoize; bind to Python (#8778)
* Bind memoize and EvictionKey * Don't visit the same Function twice in Memoization.cpp * Make it illegal to memoize an output Func * Add tests
1 parent d3ca7b2 commit e3b0ccd

File tree

9 files changed

+67
-8
lines changed

9 files changed

+67
-8
lines changed

python_bindings/src/halide/halide_/PyExpr.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ void define_expr(py::module &m) {
1717
") cannot be converted to a bool. "
1818
"If this error occurs using the 'and'/'or' keywords, "
1919
"consider using the '&'/'|' operators instead.");
20-
return false;
2120
};
2221

2322
auto expr_class =
@@ -78,6 +77,12 @@ void define_expr(py::module &m) {
7877
py::implicitly_convertible<RVar, Expr>();
7978
py::implicitly_convertible<Var, Expr>();
8079

80+
auto eviction_key_class =
81+
py::class_<EvictionKey>(m, "EvictionKey")
82+
.def(py::init<Expr>());
83+
84+
py::implicitly_convertible<Expr, EvictionKey>();
85+
8186
auto range_class =
8287
py::class_<Range>(m, "Range")
8388
.def(py::init<>())

python_bindings/src/halide/halide_/PyFunc.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ void define_func(py::module &m) {
215215
.def("async_", &Func::async)
216216
.def("ring_buffer", &Func::ring_buffer)
217217
.def("bound_storage", &Func::bound_storage)
218-
.def("memoize", &Func::memoize)
218+
.def("memoize", &Func::memoize, py::arg("eviction_key") = EvictionKey())
219219
.def("compute_inline", &Func::compute_inline)
220220
.def("compute_root", &Func::compute_root)
221221
.def("store_root", &Func::store_root)
@@ -404,12 +404,12 @@ void define_func(py::module &m) {
404404
},
405405
py::arg("dst"), py::arg("target") = Target())
406406

407-
.def("in_", (Func(Func::*)(const Func &))&Func::in, py::arg("f"))
408-
.def("in_", (Func(Func::*)(const std::vector<Func> &fs))&Func::in, py::arg("fs"))
409-
.def("in_", (Func(Func::*)())&Func::in)
407+
.def("in_", static_cast<Func (Func::*)(const Func &)>(&Func::in), py::arg("f"))
408+
.def("in_", static_cast<Func (Func::*)(const std::vector<Func> &fs)>(&Func::in), py::arg("fs"))
409+
.def("in_", static_cast<Func (Func::*)()>(&Func::in))
410410

411-
.def("clone_in", (Func(Func::*)(const Func &))&Func::clone_in, py::arg("f"))
412-
.def("clone_in", (Func(Func::*)(const std::vector<Func> &fs))&Func::clone_in, py::arg("fs"))
411+
.def("clone_in", static_cast<Func (Func::*)(const Func &)>(&Func::clone_in), py::arg("f"))
412+
.def("clone_in", static_cast<Func (Func::*)(const std::vector<Func> &fs)>(&Func::clone_in), py::arg("fs"))
413413

414414
.def("copy_to_device", &Func::copy_to_device, py::arg("device_api") = DeviceAPI::Default_GPU)
415415
.def("copy_to_host", &Func::copy_to_host)

python_bindings/test/correctness/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ set(tests
2020
extern.py
2121
float_precision_test.py
2222
iroperator.py
23+
memoize.py
2324
multi_method_module_test.py
2425
multipass_constraints.py
2526
pystub.py
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from halide import Func, Var
2+
3+
4+
def test_memoize():
5+
x = Var("x")
6+
7+
f = Func("f")
8+
f[x] = 0.0
9+
f[x] += 1
10+
f.compute_root().memoize()
11+
12+
output = Func("output")
13+
output[x] = f[x]
14+
15+
result = output.realize([3])
16+
assert list(result) == [1., 1., 1.]
17+
18+
19+
def main():
20+
test_memoize()
21+
22+
23+
if __name__ == "__main__":
24+
main()

src/Func.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,6 +2259,11 @@ class Func {
22592259
* to remove memoized entries using this eviction key from the
22602260
* cache. Memoized computations that do not provide an eviction
22612261
* key will never be evicted by this mechanism.
2262+
*
2263+
* It is invalid to memoize the output of a Pipeline; attempting
2264+
* to do so will issue an error. To cache an entire pipeline,
2265+
* either implement a caching mechanism outside of Halide or
2266+
* explicitly copy out of the cache with another output Func.
22622267
*/
22632268
Func &memoize(const EvictionKey &eviction_key = EvictionKey());
22642269

src/Memoization.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,16 @@ namespace Internal {
1717
namespace {
1818

1919
class FindParameterDependencies : public IRGraphVisitor {
20+
std::set<Function, Function::Compare> visited_functions;
21+
2022
public:
2123
FindParameterDependencies() = default;
2224
~FindParameterDependencies() override = default;
2325

2426
void visit_function(const Function &function) {
27+
if (const auto [_, inserted] = visited_functions.insert(function); !inserted) {
28+
return;
29+
}
2530
function.accept(this);
2631

2732
if (function.has_extern_definition()) {

src/Pipeline.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,11 @@ Module Pipeline::compile_to_module(const vector<Argument> &args,
499499

500500
for (const Function &f : contents->outputs) {
501501
user_assert(f.has_pure_definition() || f.has_extern_definition())
502-
<< "Can't compile Pipeline with undefined output Func: " << f.name() << ".\n";
502+
<< "Can't compile Pipeline with undefined output Func: " << f.name() << ".";
503+
user_assert(!f.schedule().memoized())
504+
<< "Can't compile Pipeline with memoized output Func: " << f.name() << ". "
505+
<< "Memoization is valid only on intermediate Funcs because it takes "
506+
<< "control of buffer allocation.";
503507
}
504508

505509
string new_fn_name(fn_name);

test/error/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ tests(GROUPS error
7575
lerp_mismatch.cpp
7676
lerp_signed_weight.cpp
7777
memoize_different_compute_store.cpp
78+
memoize_output_invalid.cpp
7879
memoize_redefine_eviction_key.cpp
7980
metal_threads_too_large.cpp
8081
metal_vector_too_large.cpp
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include <Halide.h>
2+
using namespace Halide;
3+
4+
int main(int argc, char **argv) {
5+
Var x{"x"};
6+
Func f{"f"};
7+
f(x) = 0.0f;
8+
f(x) += 1;
9+
f.memoize();
10+
11+
f.realize({3});
12+
13+
printf("Success!\n");
14+
}

0 commit comments

Comments
 (0)