Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 29 additions & 24 deletions rolling_ball.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,19 @@ def roll_ball(ball, array):
either on or below the patch during this process is considered part of the
background.
"""
height, width = array.shape
height, width = array.shape
pixels = numpy.float32(array.flatten())
z_ball = ball.data
ball_width = ball.width
radius = ball_width / 2
cache = numpy.zeros(width * ball_width)

z_ball_arr = numpy.array(z_ball)

for y in range(-radius, height + radius):
next_line_to_write_in_cache = (y + radius) % ball_width
next_line_to_read = y + radius

if next_line_to_read < height:
src = next_line_to_read * width
dest = next_line_to_write_in_cache * width
Expand All @@ -138,6 +141,14 @@ def roll_ball(ball, array):
y_end = y + radius
if y_end >= height:
y_end = height - 1

k_yp = numpy.array(range(y_0, y_end + 1))
y_ball = y_ball_0
y_balls = numpy.array(range(y_ball, y_ball + (y_end + 1 - y_0)))

c_cache_pointers = (k_yp % ball_width) * width
c_bps = y_balls * ball_width

for x in range(-radius, width + radius):
z = float('inf')
x_0 = x - radius
Expand All @@ -147,33 +158,27 @@ def roll_ball(ball, array):
x_end = x + radius
if x_end >= width:
x_end = width - 1
y_ball = y_ball_0
for yp in range(y_0, y_end + 1):
cache_pointer = (yp % ball_width) * width + x_0
bp = x_ball_0 + y_ball * ball_width
for xp in range(x_0, x_end + 1):
z_reduced = cache[cache_pointer] - z_ball[bp]
if z > z_reduced:
z = z_reduced
cache_pointer += 1
bp += 1
y_ball += 1

y_ball = y_ball_0
for yp in range(y_0, y_end + 1):
p = x_0 + yp * width
bp = x_ball_0 + y_ball * ball_width
for xp in range(x_0, x_end + 1):
z_min = z + z_ball[bp]
if pixels[p] < z_min:
pixels[p] = z_min
p += 1
bp += 1
y_ball += 1

cache_pointers = c_cache_pointers + x_0
bps = x_ball_0 + c_bps

for c in range(0, x_end + 1 - x_0):
z_reduced = cache[cache_pointers + c] - z_ball_arr[bps + c]
m_val = numpy.min(z_reduced)
if m_val < z:
z = m_val

ps = x_0 + k_yp * width

for c in range(0, x_end + 1 - x_0):
z_mins = z + z_ball_arr[bps + c]
b_arr = pixels[ps+c] < z_mins
pixels[(ps+c)[b_arr]] = z_mins[b_arr]

return numpy.reshape(pixels, array.shape)



class RollingBall(object):
"""
A rolling ball (or actually a square part thereof).
Expand Down