Skip to content

Commit 3577f88

Browse files
abadams authored and steven-johnson committed
Fix type error in VectorizeLoops (#8055)
1 parent 2111594 commit 3577f88

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed

src/VectorizeLoops.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -134,7 +134,7 @@ Interval bounds_of_lanes(const Expr &e) {
134134
Interval ia = bounds_of_lanes(not_->a);
135135
return {!ia.max, !ia.min};
136136
} else if (const Ramp *r = e.as<Ramp>()) {
137-
Expr last_lane_idx = make_const(r->base.type(), r->lanes - 1);
137+
Expr last_lane_idx = make_const(r->base.type().element_of(), r->lanes - 1);
138138
Interval ib = bounds_of_lanes(r->base);
139139
const Broadcast *b = as_scalar_broadcast(r->stride);
140140
Expr stride = b ? b->value : r->stride;
@@ -875,6 +875,7 @@ class VectorSubs : public IRMutator {
875875
// generating a scalar condition that checks if
876876
// the least-true lane is true.
877877
Expr all_true = bounds_of_lanes(likely->args[0]).min;
878+
internal_assert(all_true.type() == Bool());
878879
// Wrap it in the same flavor of likely
879880
all_true = Call::make(Bool(), likely->name,
880881
{all_true}, Call::PureIntrinsic);

test/correctness/fuzz_schedule.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,74 @@ int main(int argc, char **argv) {
202202
check_blur_output(buf, correct);
203203
}
204204

205+
// https://github.com/halide/Halide/issues/8054
206+
{
207+
ImageParam input(Float(32), 2, "input");
208+
const float r_sigma = 0.1;
209+
const int s_sigma = 8;
210+
Func bilateral_grid{"bilateral_grid"};
211+
212+
Var x("x"), y("y"), z("z"), c("c");
213+
214+
// Add a boundary condition
215+
Func clamped = Halide::BoundaryConditions::repeat_edge(input);
216+
217+
// Construct the bilateral grid
218+
RDom r(0, s_sigma, 0, s_sigma);
219+
Expr val = clamped(x * s_sigma + r.x - s_sigma / 2, y * s_sigma + r.y - s_sigma / 2);
220+
val = clamp(val, 0.0f, 1.0f);
221+
222+
Expr zi = cast<int>(val * (1.0f / r_sigma) + 0.5f);
223+
224+
Func histogram("histogram");
225+
histogram(x, y, z, c) = 0.0f;
226+
histogram(x, y, zi, c) += mux(c, {val, 1.0f});
227+
228+
// Blur the grid using a five-tap filter
229+
Func blurx("blurx"), blury("blury"), blurz("blurz");
230+
blurz(x, y, z, c) = (histogram(x, y, z - 2, c) +
231+
histogram(x, y, z - 1, c) * 4 +
232+
histogram(x, y, z, c) * 6 +
233+
histogram(x, y, z + 1, c) * 4 +
234+
histogram(x, y, z + 2, c));
235+
blurx(x, y, z, c) = (blurz(x - 2, y, z, c) +
236+
blurz(x - 1, y, z, c) * 4 +
237+
blurz(x, y, z, c) * 6 +
238+
blurz(x + 1, y, z, c) * 4 +
239+
blurz(x + 2, y, z, c));
240+
blury(x, y, z, c) = (blurx(x, y - 2, z, c) +
241+
blurx(x, y - 1, z, c) * 4 +
242+
blurx(x, y, z, c) * 6 +
243+
blurx(x, y + 1, z, c) * 4 +
244+
blurx(x, y + 2, z, c));
245+
246+
// Take trilinear samples to compute the output
247+
val = clamp(input(x, y), 0.0f, 1.0f);
248+
Expr zv = val * (1.0f / r_sigma);
249+
zi = cast<int>(zv);
250+
Expr zf = zv - zi;
251+
Expr xf = cast<float>(x % s_sigma) / s_sigma;
252+
Expr yf = cast<float>(y % s_sigma) / s_sigma;
253+
Expr xi = x / s_sigma;
254+
Expr yi = y / s_sigma;
255+
Func interpolated("interpolated");
256+
interpolated(x, y, c) =
257+
lerp(lerp(lerp(blury(xi, yi, zi, c), blury(xi + 1, yi, zi, c), xf),
258+
lerp(blury(xi, yi + 1, zi, c), blury(xi + 1, yi + 1, zi, c), xf), yf),
259+
lerp(lerp(blury(xi, yi, zi + 1, c), blury(xi + 1, yi, zi + 1, c), xf),
260+
lerp(blury(xi, yi + 1, zi + 1, c), blury(xi + 1, yi + 1, zi + 1, c), xf), yf),
261+
zf);
262+
263+
// Normalize
264+
bilateral_grid(x, y) = interpolated(x, y, 0) / interpolated(x, y, 1);
265+
Pipeline p({bilateral_grid});
266+
267+
Var v6, zo, vzi;
268+
269+
blury.compute_root().split(x, x, v6, 6, TailStrategy::GuardWithIf).split(z, zo, vzi, 8, TailStrategy::GuardWithIf).reorder(y, x, c, vzi, zo, v6).vectorize(vzi).vectorize(v6);
270+
p.compile_to_module({input}, "bilateral_grid", {Target("host")});
271+
}
272+
205273
printf("Success!\n");
206274
return 0;
207275
}

0 commit comments

Comments
 (0)