diff --git a/rolling_ball.py b/rolling_ball.py index 786fb15..dbccee1 100644 --- a/rolling_ball.py +++ b/rolling_ball.py @@ -113,16 +113,19 @@ def roll_ball(ball, array): either on or below the patch during this process is considered part of the background. """ - height, width = array.shape + height, width = array.shape pixels = numpy.float32(array.flatten()) z_ball = ball.data ball_width = ball.width radius = ball_width / 2 cache = numpy.zeros(width * ball_width) + z_ball_arr = numpy.array(z_ball) + for y in range(-radius, height + radius): next_line_to_write_in_cache = (y + radius) % ball_width next_line_to_read = y + radius + if next_line_to_read < height: src = next_line_to_read * width dest = next_line_to_write_in_cache * width @@ -138,6 +141,14 @@ def roll_ball(ball, array): y_end = y + radius if y_end >= height: y_end = height - 1 + + k_yp = numpy.array(range(y_0, y_end + 1)) + y_ball = y_ball_0 + y_balls = numpy.array(range(y_ball, y_ball + (y_end + 1 - y_0))) + + c_cache_pointers = (k_yp % ball_width) * width + c_bps = y_balls * ball_width + for x in range(-radius, width + radius): z = float('inf') x_0 = x - radius @@ -147,33 +158,27 @@ def roll_ball(ball, array): x_end = x + radius if x_end >= width: x_end = width - 1 - y_ball = y_ball_0 - for yp in range(y_0, y_end + 1): - cache_pointer = (yp % ball_width) * width + x_0 - bp = x_ball_0 + y_ball * ball_width - for xp in range(x_0, x_end + 1): - z_reduced = cache[cache_pointer] - z_ball[bp] - if z > z_reduced: - z = z_reduced - cache_pointer += 1 - bp += 1 - y_ball += 1 - - y_ball = y_ball_0 - for yp in range(y_0, y_end + 1): - p = x_0 + yp * width - bp = x_ball_0 + y_ball * ball_width - for xp in range(x_0, x_end + 1): - z_min = z + z_ball[bp] - if pixels[p] < z_min: - pixels[p] = z_min - p += 1 - bp += 1 - y_ball += 1 + + cache_pointers = c_cache_pointers + x_0 + bps = x_ball_0 + c_bps + + for c in range(0, x_end + 1 - x_0): + z_reduced = cache[cache_pointers + c] - z_ball_arr[bps + c] + m_val = numpy.min(z_reduced) + if m_val < z: + z = m_val + + ps = x_0 + k_yp * width + + for c in range(0, x_end + 1 - x_0): + z_mins = z + z_ball_arr[bps + c] + b_arr = pixels[ps+c] < z_mins + pixels[(ps+c)[b_arr]] = z_mins[b_arr] return numpy.reshape(pixels, array.shape) + class RollingBall(object): """ A rolling ball (or actually a square part thereof).