Skip to content

Commit 8003b51

Browse files
authored
Add Python bindings for serialization (#8718)
1 parent 6709900 commit 8003b51

File tree

8 files changed

+363
-4
lines changed

8 files changed

+363
-4
lines changed

python_bindings/src/halide/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ target_sources(
3535
halide_/PyParameter.cpp
3636
halide_/PyPipeline.cpp
3737
halide_/PyRDom.cpp
38+
halide_/PySerialization.cpp
3839
halide_/PyStage.cpp
3940
halide_/PyTarget.cpp
4041
halide_/PyTuple.cpp

python_bindings/src/halide/halide_/PyHalide.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "PyParameter.h"
2222
#include "PyPipeline.h"
2323
#include "PyRDom.h"
24+
#include "PySerialization.h"
2425
#include "PyTarget.h"
2526
#include "PyTuple.h"
2627
#include "PyType.h"
@@ -70,6 +71,7 @@ PYBIND11_MODULE(HALIDE_PYBIND_MODULE_NAME, m) {
7071
define_type(m);
7172
define_derivative(m);
7273
define_generator(m);
74+
define_serialization(m);
7375

7476
// There is no PyUtil yet, so just put this here
7577
m.def("load_plugin", &Halide::load_plugin, py::arg("lib_name"));

python_bindings/src/halide/halide_/PyHalide.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,58 @@ Expr double_to_expr_check(double v);
3737
Target to_jit_target(const Target &target);
3838
Target to_aot_target(const Target &target);
3939

40+
// TODO: when out base toolchains are modern enough, we can just
41+
// use std::filesystem::path, since pybind11 has a built-in type
42+
// caster in <pybind11/stl/filesystem.h>.
43+
// See: https://github.com/halide/Halide/issues/8723
44+
class PathLike {
45+
std::string path;
46+
47+
public:
48+
PathLike() = default;
49+
PathLike(const py::bytes &path)
50+
: path(path) {
51+
}
52+
53+
operator const std::string &() const {
54+
return path;
55+
}
56+
57+
PyObject *decode() const {
58+
return PyUnicode_DecodeFSDefaultAndSize(path.c_str(), static_cast<Py_ssize_t>(path.size()));
59+
}
60+
};
61+
4062
} // namespace PythonBindings
4163
} // namespace Halide
4264

65+
template<>
66+
class pybind11::detail::type_caster<Halide::PythonBindings::PathLike> {
67+
public:
68+
PYBIND11_TYPE_CASTER(Halide::PythonBindings::PathLike, const_name("os.PathLike"));
69+
70+
bool load(handle src, bool) {
71+
try {
72+
PyObject *path = nullptr;
73+
if (!PyUnicode_FSConverter(src.ptr(), &path)) {
74+
throw error_already_set();
75+
}
76+
value = reinterpret_steal<bytes>(path);
77+
return true;
78+
} catch (error_already_set &) {
79+
return false;
80+
}
81+
}
82+
83+
static handle cast(const Halide::PythonBindings::PathLike &path,
84+
return_value_policy, handle) {
85+
if (auto *py_str = path.decode()) {
86+
return module_::import("pathlib")
87+
.attr("Path")(reinterpret_steal<object>(py_str))
88+
.release();
89+
}
90+
return nullptr;
91+
}
92+
};
93+
4394
#endif // HALIDE_PYTHON_BINDINGS_PYHALIDE_H
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#include "PySerialization.h"
2+
3+
#include <pybind11/stl/filesystem.h>
4+
5+
namespace Halide {
6+
namespace PythonBindings {
7+
8+
void define_serialization(py::module &m) {
9+
// Serialize pipeline functions
10+
m.def("serialize_pipeline", //
11+
[](const Pipeline &pipeline, const PathLike &filename, std::optional<bool> get_params) -> std::optional<std::map<std::string, Parameter>> {
12+
if (get_params.value_or(false)) {
13+
std::map<std::string, Parameter> params;
14+
serialize_pipeline(pipeline, filename, params);
15+
return params;
16+
}
17+
serialize_pipeline(pipeline, filename);
18+
return {}; //
19+
},
20+
py::arg("pipeline"), //
21+
py::arg("filename"), //
22+
py::kw_only(), //
23+
py::arg("get_params") = std::nullopt, //
24+
"Serialize a Halide pipeline to a file. Accepts string or path-like objects. Optionally returns external parameters.");
25+
26+
m.def("serialize_pipeline", //
27+
[](const Pipeline &pipeline, std::optional<bool> get_params) -> std::variant<std::tuple<py::bytes, std::map<std::string, Parameter>>, py::bytes> {
28+
std::vector<uint8_t> data;
29+
if (get_params.value_or(false)) {
30+
std::map<std::string, Parameter> params;
31+
serialize_pipeline(pipeline, data, params);
32+
py::bytes bytes_data = py::bytes(reinterpret_cast<const char *>(data.data()), data.size());
33+
return std::make_tuple(bytes_data, params);
34+
}
35+
serialize_pipeline(pipeline, data);
36+
return py::bytes(reinterpret_cast<const char *>(data.data()), data.size()); //
37+
},
38+
py::arg("pipeline"), //
39+
py::kw_only(), //
40+
py::arg("get_params") = std::nullopt, //
41+
"Serialize a Halide pipeline to bytes, optionally returning external parameters as a tuple.");
42+
43+
// Deserialize pipeline functions
44+
m.def("deserialize_pipeline", //
45+
[](const py::bytes &data, const std::map<std::string, Parameter> &user_params) -> Pipeline {
46+
// TODO: rework API in serialize_pipeline to take a std::span<> in C++20
47+
// https://github.com/halide/Halide/issues/8722
48+
std::string_view view{data};
49+
std::vector<uint8_t> span{view.begin(), view.end()};
50+
return deserialize_pipeline(span, user_params); //
51+
},
52+
py::arg("data"), //
53+
py::arg("user_params") = std::map<std::string, Parameter>{}, //
54+
"Deserialize a Halide pipeline from bytes.");
55+
56+
m.def("deserialize_pipeline", //
57+
[](const PathLike &filename, const std::map<std::string, Parameter> &user_params) -> Pipeline {
58+
return deserialize_pipeline(filename, user_params); //
59+
},
60+
py::arg("filename"), //
61+
py::arg("user_params") = std::map<std::string, Parameter>{}, //
62+
"Deserialize a Halide pipeline from a file. Accepts string or path-like objects.");
63+
64+
// Deserialize parameters functions
65+
m.def("deserialize_parameters", //
66+
[](const py::bytes &data) -> std::map<std::string, Parameter> {
67+
// TODO: rework API in serialize_pipeline to take a std::span<> in C++20
68+
// https://github.com/halide/Halide/issues/8722
69+
std::string_view view{data};
70+
std::vector<uint8_t> span{view.begin(), view.end()};
71+
return deserialize_parameters(span); //
72+
},
73+
py::arg("data"), //
74+
"Deserialize external parameters from serialized pipeline bytes.");
75+
76+
m.def("deserialize_parameters", //
77+
[](const PathLike &filename) -> std::map<std::string, Parameter> {
78+
return deserialize_parameters(filename); //
79+
},
80+
py::arg("filename"), //
81+
"Deserialize external parameters from a serialized pipeline file. Accepts string or path-like objects.");
82+
}
83+
84+
} // namespace PythonBindings
85+
} // namespace Halide
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#ifndef HALIDE_PYTHON_BINDINGS_PYSERIALIZATION_H
2+
#define HALIDE_PYTHON_BINDINGS_PYSERIALIZATION_H
3+
4+
#include "PyHalide.h"
5+
6+
namespace Halide {
7+
namespace PythonBindings {
8+
9+
void define_serialization(py::module &m);
10+
11+
} // namespace PythonBindings
12+
} // namespace Halide
13+
14+
#endif // HALIDE_PYTHON_BINDINGS_PYSERIALIZATION_H

python_bindings/test/correctness/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ set(tests
2525
pystub.py
2626
rdom.py
2727
realize_warnings.py
28+
serialization.py
2829
target.py
2930
tuple_select.py
3031
type.py
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
import tempfile
2+
from pathlib import Path
3+
4+
import halide as hl
5+
import numpy as np
6+
7+
8+
def test_serialize_deserialize_pipeline_file():
9+
"""Test serializing and deserializing a pipeline to/from a file."""
10+
x, y = hl.vars("x y")
11+
f = hl.Func("f")
12+
f[x, y] = x + y
13+
14+
pipeline = hl.Pipeline(f)
15+
16+
with tempfile.NamedTemporaryFile(suffix=".hlpipe", delete=False) as tmp:
17+
filename = Path(tmp.name)
18+
19+
try:
20+
hl.serialize_pipeline(pipeline, filename)
21+
22+
assert filename.exists()
23+
assert filename.stat().st_size > 0
24+
25+
deserialized_pipeline = hl.deserialize_pipeline(filename)
26+
27+
result = deserialized_pipeline.realize([10, 10])
28+
assert result.dim(0).extent() == 10
29+
assert result.dim(1).extent() == 10
30+
31+
expected = np.add.outer(np.arange(10), np.arange(10))
32+
assert np.array_equal(np.array(result), expected)
33+
34+
finally:
35+
filename.unlink(missing_ok=True)
36+
37+
38+
def test_serialize_deserialize_pipeline_bytes():
39+
"""Test serializing and deserializing a pipeline to/from bytes."""
40+
x, y = hl.Var("x"), hl.Var("y")
41+
f = hl.Func("f")
42+
f[x, y] = x * 2 + y
43+
44+
pipeline = hl.Pipeline(f)
45+
46+
data = hl.serialize_pipeline(pipeline)
47+
assert isinstance(data, bytes)
48+
assert len(data) > 0
49+
50+
deserialized_pipeline = hl.deserialize_pipeline(data)
51+
52+
result = deserialized_pipeline.realize([5, 5])
53+
expected = np.fromfunction(lambda x, y: x * 2 + y, (5, 5), dtype=int).transpose()
54+
assert np.array_equal(np.array(result), expected)
55+
56+
57+
def test_serialize_deserialize_with_parameters():
58+
"""Test serializing and deserializing a pipeline with external parameters."""
59+
x = hl.Var("x")
60+
p = hl.Param(hl.Int(32), "multiplier", 1)
61+
f = hl.Func("f")
62+
f[x] = x * p
63+
64+
pipeline = hl.Pipeline(f)
65+
66+
with tempfile.NamedTemporaryFile(suffix=".hlpipe", delete=False) as tmp:
67+
filename = Path(tmp.name)
68+
69+
try:
70+
params = hl.serialize_pipeline(pipeline, filename, get_params=True)
71+
72+
assert "multiplier" in params
73+
assert params["multiplier"].name() == "multiplier"
74+
75+
user_params = {"multiplier": hl.Param(hl.Int(32), "multiplier", 5).parameter()}
76+
deserialized_pipeline = hl.deserialize_pipeline(filename, user_params)
77+
78+
result = deserialized_pipeline.realize([3])
79+
assert list(result) == [0, 5, 10]
80+
81+
finally:
82+
filename.unlink(missing_ok=True)
83+
84+
85+
def test_serialize_deserialize_with_parameters_bytes():
86+
"""Test serializing and deserializing a pipeline with parameters using bytes."""
87+
x = hl.Var("x")
88+
p = hl.Param(hl.Int(32), "offset", 0)
89+
f = hl.Func("f")
90+
f[x] = x + p
91+
92+
pipeline = hl.Pipeline(f)
93+
94+
data, params = hl.serialize_pipeline(pipeline, get_params=True)
95+
96+
assert "offset" in params
97+
assert params["offset"].name() == "offset"
98+
99+
user_params = {"offset": hl.Param(hl.Int(32), "offset", 100).parameter()}
100+
deserialized_pipeline = hl.deserialize_pipeline(data, user_params)
101+
102+
result = deserialized_pipeline.realize([3])
103+
assert list(result) == [100, 101, 102]
104+
105+
106+
def test_deserialize_parameters_file():
107+
"""Test deserializing just the parameters from a file."""
108+
x = hl.Var("x")
109+
p1 = hl.Param(hl.Int(32), "param1", 1)
110+
p2 = hl.Param(hl.Float(32), "param2", 2.0)
111+
f = hl.Func("f")
112+
f[x] = hl.cast(hl.Int(32), x * p1 + p2)
113+
114+
pipeline = hl.Pipeline(f)
115+
116+
with tempfile.NamedTemporaryFile(suffix=".hlpipe", delete=False) as tmp:
117+
filename = Path(tmp.name)
118+
119+
try:
120+
hl.serialize_pipeline(pipeline, filename)
121+
params = hl.deserialize_parameters(filename)
122+
123+
assert "param1" in params
124+
assert "param2" in params
125+
assert params["param1"].name() == "param1"
126+
assert params["param2"].name() == "param2"
127+
assert params["param1"].type() == hl.Int(32)
128+
assert params["param2"].type() == hl.Float(32)
129+
130+
finally:
131+
filename.unlink(missing_ok=True)
132+
133+
134+
def test_deserialize_parameters_bytes():
135+
"""Test deserializing just the parameters from bytes."""
136+
x = hl.Var("x")
137+
p1 = hl.Param(hl.UInt(16), "width", 64)
138+
p2 = hl.Param(hl.UInt(16), "height", 64)
139+
f = hl.Func("f")
140+
f[x] = hl.select(x < p1, p2, 0)
141+
142+
pipeline = hl.Pipeline(f)
143+
144+
data = hl.serialize_pipeline(pipeline)
145+
params = hl.deserialize_parameters(data)
146+
147+
assert "width" in params
148+
assert "height" in params
149+
assert params["width"].type() == hl.UInt(16)
150+
assert params["height"].type() == hl.UInt(16)
151+
152+
153+
def test_pipeline_with_multiple_outputs():
154+
"""Test serializing/deserializing a pipeline with multiple outputs."""
155+
x, y = hl.vars("x y")
156+
157+
f1 = hl.Func("f1")
158+
f1[x, y] = x + y
159+
160+
f2 = hl.Func("f2")
161+
f2[x, y] = x - y
162+
163+
pipeline = hl.Pipeline([f1, f2])
164+
165+
data = hl.serialize_pipeline(pipeline)
166+
deserialized_pipeline = hl.deserialize_pipeline(data)
167+
168+
results = deserialized_pipeline.realize([5, 5])
169+
assert len(results) == 2
170+
171+
expected = np.add.outer(np.arange(5), np.arange(5))
172+
assert np.array_equal(np.array(results[0]), expected)
173+
174+
expected = np.subtract.outer(np.arange(5), np.arange(5)).transpose()
175+
assert np.array_equal(np.array(results[1]), expected)
176+
177+
178+
def test_empty_user_params():
179+
"""Test that empty user_params works correctly."""
180+
x = hl.Var("x")
181+
f = hl.Func("f")
182+
f[x] = x * x
183+
184+
pipeline = hl.Pipeline(f)
185+
data = hl.serialize_pipeline(pipeline)
186+
187+
deserialized1 = hl.deserialize_pipeline(data, {})
188+
result1 = deserialized1.realize([3])
189+
assert list(result1) == [0, 1, 4]
190+
191+
deserialized2 = hl.deserialize_pipeline(data)
192+
result2 = deserialized2.realize([3])
193+
assert list(result2) == [0, 1, 4]
194+
195+
196+
if __name__ == "__main__":
197+
test_serialize_deserialize_pipeline_file()
198+
test_serialize_deserialize_pipeline_bytes()
199+
test_serialize_deserialize_with_parameters()
200+
test_serialize_deserialize_with_parameters_bytes()
201+
test_deserialize_parameters_file()
202+
test_deserialize_parameters_bytes()
203+
test_pipeline_with_multiple_outputs()
204+
test_empty_user_params()
205+
print("Success!")

0 commit comments

Comments
 (0)