Skip to content

Commit 75ddeaf

Browse files
authored
Add all remaining IROperator ops to Python bindings (#8771)
* Fix typo in IROperator.h * Add missing mux overloads * Delete errant atan overload to atan2 Neither Python standard `math` nor `np.atan` support this hypothetical overload * Add fast_sin and fast_cos Fixes #8751 * Use static_cast for consistent formatting Different clang-format versions want to format the C-style cast differently. This is annoying. Add some blank lines for visual flow. * Fix bindings for random_{float,uint,int} static_cast<> didn't work for default-argument version, so call it through a lambda. * Bind scatter and gather * Bind extract_bits and concat_bits * Bind widening arithmetic operators Fixes #8769 * Bind remaining saturating and rounding operators
1 parent 2d1f036 commit 75ddeaf

File tree

2 files changed

+51
-19
lines changed

2 files changed

+51
-19
lines changed

python_bindings/src/halide/halide_/PyIROperator.cpp

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,16 @@ void define_operators(py::module &m) {
128128
return py::cast(false_expr_value);
129129
});
130130

131-
m.def("mux", (Expr(*)(const Expr &, const std::vector<Expr> &))&mux);
131+
m.def("mux", static_cast<Expr (*)(const Expr &, const std::vector<Expr> &)>(&mux));
132+
m.def("mux", static_cast<Expr (*)(const Expr &, const Tuple &)>(&mux));
133+
m.def("mux", static_cast<Tuple (*)(const Expr &, const std::vector<Tuple> &)>(&mux));
132134

133135
m.def("sin", &sin);
134136
m.def("asin", &asin);
135137
m.def("cos", &cos);
136138
m.def("acos", &acos);
137139
m.def("tan", &tan);
138140
m.def("atan", &atan);
139-
m.def("atan", &atan2);
140141
m.def("atan2", &atan2);
141142
m.def("sinh", &sinh);
142143
m.def("asinh", &asinh);
@@ -150,6 +151,8 @@ void define_operators(py::module &m) {
150151
m.def("log", &log);
151152
m.def("pow", &pow);
152153
m.def("erf", &erf);
154+
m.def("fast_sin", &fast_sin);
155+
m.def("fast_cos", &fast_cos);
153156
m.def("fast_log", &fast_log);
154157
m.def("fast_exp", &fast_exp);
155158
m.def("fast_pow", &fast_pow);
@@ -163,55 +166,84 @@ void define_operators(py::module &m) {
163166
m.def("is_nan", &is_nan);
164167
m.def("is_inf", &is_inf);
165168
m.def("is_finite", &is_finite);
166-
m.def("reinterpret", (Expr(*)(Type, Expr))&reinterpret);
167-
m.def("cast", (Expr(*)(Type, Expr))&cast);
169+
m.def("reinterpret", static_cast<Expr (*)(Type, Expr)>(&reinterpret));
170+
m.def("cast", static_cast<Expr (*)(Type, Expr)>(&cast));
171+
168172
m.def("print", [](const py::args &args) -> Expr {
169173
return print(collect_print_args(args));
170174
});
175+
171176
m.def(
172177
"print_when", [](const Expr &condition, const py::args &args) -> Expr {
173178
return print_when(condition, collect_print_args(args));
174179
},
175180
py::arg("condition"));
181+
176182
m.def(
177183
"require", [](const Expr &condition, const Expr &value, const py::args &args) -> Expr {
178184
auto v = args_to_vector<Expr>(args);
179185
v.insert(v.begin(), value);
180186
return require(condition, v);
181187
},
182188
py::arg("condition"), py::arg("value"));
189+
183190
m.def("lerp", &lerp);
184191
m.def("popcount", &popcount);
185192
m.def("count_leading_zeros", &count_leading_zeros);
186193
m.def("count_trailing_zeros", &count_trailing_zeros);
187194
m.def("div_round_to_zero", &div_round_to_zero);
188195
m.def("mod_round_to_zero", &mod_round_to_zero);
189-
m.def("random_float", (Expr(*)())&random_float);
190-
m.def("random_uint", (Expr(*)())&random_uint);
191-
m.def("random_int", (Expr(*)())&random_int);
192-
m.def("random_float", (Expr(*)(Expr))&random_float, py::arg("seed"));
193-
m.def("random_uint", (Expr(*)(Expr))&random_uint, py::arg("seed"));
194-
m.def("random_int", (Expr(*)(Expr))&random_int, py::arg("seed"));
195-
m.def("undef", (Expr(*)(Type))&undef);
196+
m.def("random_float", [] { return random_float(); });
197+
m.def("random_float", &random_float, py::arg("seed"));
198+
m.def("random_uint", [] { return random_uint(); });
199+
m.def("random_uint", &random_uint, py::arg("seed"));
200+
m.def("random_int", [] { return random_int(); });
201+
m.def("random_int", &random_int, py::arg("seed"));
202+
m.def("undef", static_cast<Expr (*)(Type)>(&undef));
203+
196204
m.def(
197205
"memoize_tag", [](const Expr &result, const py::args &cache_key_values) -> Expr {
198206
return Internal::memoize_tag_helper(result, args_to_vector<Expr>(cache_key_values));
199207
},
200208
py::arg("result"));
209+
201210
m.def("likely", &likely);
202211
m.def("likely_if_innermost", &likely_if_innermost);
203-
m.def("saturating_cast", (Expr(*)(Type, Expr))&saturating_cast);
212+
m.def("saturating_cast", static_cast<Expr (*)(Type, Expr)>(&saturating_cast));
204213
m.def("strict_float", &strict_float);
214+
m.def("scatter", static_cast<Expr (*)(const std::vector<Expr> &)>(&scatter));
215+
m.def("gather", static_cast<Expr (*)(const std::vector<Expr> &)>(&gather));
216+
m.def("extract_bits", static_cast<Expr (*)(Type, const Expr &, const Expr &)>(&extract_bits));
217+
m.def("concat_bits", &concat_bits);
218+
m.def("widen_right_add", &widen_right_add);
219+
m.def("widen_right_mul", &widen_right_mul);
220+
m.def("widen_right_sub", &widen_right_sub);
221+
m.def("widening_add", &widening_add);
222+
m.def("widening_mul", &widening_mul);
223+
m.def("widening_sub", &widening_sub);
224+
m.def("widening_shift_left", static_cast<Expr (*)(Expr, int)>(&widening_shift_left));
225+
m.def("widening_shift_left", static_cast<Expr (*)(Expr, Expr)>(&widening_shift_left));
226+
m.def("widening_shift_right", static_cast<Expr (*)(Expr, int)>(&widening_shift_right));
227+
m.def("widening_shift_right", static_cast<Expr (*)(Expr, Expr)>(&widening_shift_right));
228+
m.def("rounding_shift_left", static_cast<Expr (*)(Expr, int)>(&rounding_shift_left));
229+
m.def("rounding_shift_left", static_cast<Expr (*)(Expr, Expr)>(&rounding_shift_left));
230+
m.def("rounding_shift_right", static_cast<Expr (*)(Expr, int)>(&rounding_shift_right));
231+
m.def("rounding_shift_right", static_cast<Expr (*)(Expr, Expr)>(&rounding_shift_right));
232+
m.def("saturating_add", &saturating_add);
233+
m.def("saturating_sub", &saturating_sub);
234+
m.def("halving_add", &halving_add);
235+
m.def("rounding_halving_add", &rounding_halving_add);
236+
m.def("halving_sub", &halving_sub);
237+
m.def("mul_shift_right", static_cast<Expr (*)(Expr, Expr, int)>(&mul_shift_right));
238+
m.def("mul_shift_right", static_cast<Expr (*)(Expr, Expr, Expr)>(&mul_shift_right));
239+
m.def("rounding_mul_shift_right", static_cast<Expr (*)(Expr, Expr, int)>(&rounding_mul_shift_right));
240+
m.def("rounding_mul_shift_right", static_cast<Expr (*)(Expr, Expr, Expr)>(&rounding_mul_shift_right));
205241
m.def("target_arch_is", &target_arch_is);
206242
m.def("target_bits", &target_bits);
207243
m.def("target_has_feature", &target_has_feature);
208-
m.def("target_natural_vector_size", [](const Type &t) -> Expr {
209-
return target_natural_vector_size(t);
210-
});
244+
m.def("target_natural_vector_size", static_cast<Expr (*)(Type)>(&target_natural_vector_size));
211245
m.def("target_os_is", &target_os_is);
212-
m.def("logical_not", [](const Expr &expr) -> Expr {
213-
return !expr;
214-
});
246+
m.def("logical_not", [](const Expr &expr) -> Expr { return !expr; });
215247
}
216248

217249
} // namespace PythonBindings

src/IROperator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1555,7 +1555,7 @@ f(scatter(3, 5)) = f(select(p, gather(5, 3), gather(3, 5)));
15551555
f(select(p, scatter(3, 5, 5), scatter(1, 2, 3))) = f(select(p, gather(5, 3, 3), gather(2, 3, 1)));
15561556
\endcode
15571557
*
1558-
* Note that in the p == true case, we redudantly load from 3 and write
1558+
* Note that in the p == true case, we redundantly load from 3 and write
15591559
* to 5 twice.
15601560
*/
15611561
//@{

0 commit comments

Comments
 (0)