Skip to content

Commit d32a050

Browse files
authored
Merge pull request #27 from wrightky/master
Optimize weights computation
2 parents 1e52705 + 96bd9b8 commit d32a050

File tree

5 files changed

+294
-127
lines changed

5 files changed

+294
-127
lines changed

dorado/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.4.2"
1+
__version__ = "2.5.0"
22

33

44
from . import lagrangian_walker

dorado/lagrangian_walker.py

Lines changed: 161 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -39,56 +39,185 @@ def random_pick_seed(choices, probs=None):
3939
return choices[idx]
4040

4141

42-
def make_weight(Particles, ind):
43-
"""Update weighting array with weights at this index"""
44-
# get stage values for neighboring cells
45-
stage_ind = Particles.stage[ind[0]-1:ind[0]+2, ind[1]-1:ind[1]+2]
42+
def big_sliding_window(raster):
43+
"""Creates 3D array organizing local neighbors at every index
44+
45+
Returns a raster of shape (L,W,9) which organizes (along the third
46+
dimension) all of the neighbors in the original raster at a given
47+
index, in the order [NW, N, NE, W, 0, E, SW, S, SE]. Outputs are
48+
ordered to match np.ravel(), so it functions similarly to a loop
49+
applying ravel to the elements around each index.
50+
For example, the neighboring values in raster indexed at (i,j) are
51+
raster(i-1:i+2, j-1:j+2).ravel(). These 9 values have been mapped to
52+
big_ravel(i,j,:) for ease of computations. Helper function for make_weight.
4653
47-
# calculate surface slope weights
48-
weight_sfc = maximum(0,
49-
(Particles.stage[ind] - stage_ind) /
50-
Particles.distances)
54+
**Inputs** :
55+
56+
raster : `ndarray`
57+
2D array of values (e.g. stage, qx)
58+
59+
**Outputs** :
60+
61+
big_ravel : `ndarray`
62+
3D array which sorts the D8 neighbors at index (i,j) in
63+
raster into the 3rd dimension at (i,j,:)
64+
65+
"""
66+
big_ravel = np.zeros((raster.shape[0],raster.shape[1],9))
67+
big_ravel[1:-1,1:-1,0] = raster[0:-2,0:-2]
68+
big_ravel[1:-1,1:-1,1] = raster[0:-2,1:-1]
69+
big_ravel[1:-1,1:-1,2] = raster[0:-2,2:]
70+
big_ravel[1:-1,1:-1,3] = raster[1:-1,0:-2]
71+
big_ravel[1:-1,1:-1,4] = raster[1:-1,1:-1]
72+
big_ravel[1:-1,1:-1,5] = raster[1:-1,2:]
73+
big_ravel[1:-1,1:-1,6] = raster[2:,0:-2]
74+
big_ravel[1:-1,1:-1,7] = raster[2:,1:-1]
75+
big_ravel[1:-1,1:-1,8] = raster[2:,2:]
76+
77+
return big_ravel
78+
79+
80+
def tile_local_array(local_array, L, W):
81+
"""Take a local array [[NW, N, NE], [W, 0, E], [SW, S, SE]]
82+
and repeat it into an array of shape (L,W,9), where L, W are
83+
the shape of the domain, and the original elements are ordered
84+
along the third axis. Helper function for make_weight.
85+
86+
**Inputs** :
87+
88+
local_array : `ndarray`
89+
2D array of represnting the D8 neighbors around
90+
some index (e.g. ivec, jvec)
91+
92+
L : `int`
93+
Length of the domain
94+
95+
W : `int`
96+
Width of the domain
97+
98+
**Outputs** :
99+
100+
tiled_array : `ndarray`
101+
3D array repeating local_array.ravel() LxW times
102+
103+
"""
104+
return np.tile(local_array.ravel(), (L, W, 1))
105+
106+
107+
def tile_domain_array(raster):
108+
"""Repeat a large 2D array 9 times along the third axis.
109+
Helper function for make_weight.
110+
111+
**Inputs** :
112+
113+
raster : `ndarray`
114+
2D array of values (e.g. stage, qx)
115+
116+
**Outputs** :
117+
118+
tiled_array : `ndarray`
119+
3D array repeating raster 9 times
120+
121+
"""
122+
return np.repeat(raster[:, :, np.newaxis], 9, axis=2)
51123

52-
# calculate inertial component weights
53-
weight_int = maximum(0, ((Particles.qx[ind] * Particles.jvec +
54-
Particles.qy[ind] * Particles.ivec) /
55-
Particles.distances))
56124

125+
def clear_borders(tiled_array):
126+
"""Set to zero all the edge elements of a vertical stack
127+
of 2D arrays. Helper function for make_weight.
128+
129+
**Inputs** :
130+
131+
tiled_array : `ndarray`
132+
3D array repeating raster (e.g. stage, qx) 9 times
133+
along the third axis
134+
135+
**Outputs** :
136+
137+
tiled_array : `ndarray`
138+
The same 3D array as the input, but with the borders
139+
in the first and second dimension set to 0.
140+
141+
"""
142+
tiled_array[0,:,:] = 0.
143+
tiled_array[:,0,:] = 0.
144+
tiled_array[-1,:,:] = 0.
145+
tiled_array[:,-1,:] = 0.
146+
return
147+
148+
149+
def make_weight(Particles):
150+
"""Create the weighting array for particle routing
151+
152+
Function to compute the routing weights at each index, which gets
153+
stored inside the :obj:`dorado.particle_track.Particles` object
154+
for use when routing. Called when the object gets instantiated.
155+
156+
**Inputs** :
157+
158+
Particles : :obj:`dorado.particle_track.Particles`
159+
A :obj:`dorado.particle_track.Particles` object
160+
161+
**Outputs** :
162+
163+
Updates the weights array inside the
164+
:obj:`dorado.particle_track.Particles` object
165+
166+
"""
167+
L, W = Particles.stage.shape
168+
169+
# calculate surface slope weights
170+
weight_sfc = (tile_domain_array(Particles.stage) \
171+
- big_sliding_window(Particles.stage))
172+
weight_sfc /= tile_local_array(Particles.distances, L, W)
173+
weight_sfc[weight_sfc <= 0] = 0
174+
clear_borders(weight_sfc)
175+
176+
# calculate inertial component weights
177+
weight_int = (tile_domain_array(Particles.qx)*tile_local_array(Particles.jvec, L, W)) \
178+
+ (tile_domain_array(Particles.qy)*tile_local_array(Particles.ivec, L, W))
179+
weight_int /= tile_local_array(Particles.distances, L, W)
180+
weight_int[weight_int <= 0] = 0
181+
clear_borders(weight_int)
182+
57183
# get depth and cell types for neighboring cells
58-
depth_ind = Particles.depth[ind[0]-1:ind[0]+2, ind[1]-1:ind[1]+2]
59-
ct_ind = Particles.cell_type[ind[0]-1:ind[0]+2, ind[1]-1:ind[1]+2]
184+
depth_ind = big_sliding_window(Particles.depth)
185+
ct_ind = big_sliding_window(Particles.cell_type)
60186

61187
# set weights for cells that are too shallow, or invalid 0
62188
weight_sfc[(depth_ind <= Particles.dry_depth) | (ct_ind == 2)] = 0
63189
weight_int[(depth_ind <= Particles.dry_depth) | (ct_ind == 2)] = 0
64-
190+
65191
# if sum of weights is above 0 normalize by sum of weights
66-
if nansum(weight_sfc) > 0:
67-
weight_sfc = weight_sfc / nansum(weight_sfc)
68-
69-
# if sum of weight is above 0 normalize by sum of weights
70-
if nansum(weight_int) > 0:
71-
weight_int = weight_int / nansum(weight_int)
192+
norm_sfc = np.nansum(weight_sfc, axis=2)
193+
idx_sfc = tile_domain_array((norm_sfc > 0))
194+
weight_sfc[idx_sfc] /= tile_domain_array(norm_sfc)[idx_sfc]
195+
196+
norm_int = np.nansum(weight_int, axis=2)
197+
idx_int = tile_domain_array((norm_int > 0))
198+
weight_int[idx_int] /= tile_domain_array(norm_int)[idx_int]
72199

73200
# define actual weight by using gamma, and weight components
74201
weight = Particles.gamma * weight_sfc + \
75-
(1 - Particles.gamma) * weight_int
76-
202+
(1 - Particles.gamma) * weight_int
203+
77204
# modify the weight by the depth and theta weighting parameter
78205
weight = depth_ind ** Particles.theta * weight
79-
80-
# if the depth is below the minimum depth then location is not
81-
# considered therefore set the associated weight to nan
82-
weight[(depth_ind <= Particles.dry_depth) | (ct_ind == 2)] \
83-
= np.nan
206+
207+
# if the depth is below the minimum depth then set weight to 0
208+
weight[(depth_ind <= Particles.dry_depth) | (ct_ind == 2)] = 0
84209

85210
# if it's a dead end with only nans and 0's, choose deepest cell
86-
if nansum(weight) <= 0:
87-
weight = np.zeros_like(weight)
88-
weight[depth_ind == np.max(depth_ind)] = 1.0
211+
zero_weights = tile_domain_array((np.nansum(weight, axis=2) <= 0))
212+
deepest_cells = (depth_ind == tile_domain_array(np.max(depth_ind,axis=2)))
213+
choose_deep_cells = (zero_weights & deepest_cells)
214+
weight[choose_deep_cells] = 1.0
215+
216+
# Final checks, eliminate invalid choices
217+
clear_borders(weight)
89218

90219
# set weight in the true weight array
91-
Particles.weight[ind[0], ind[1], :] = weight.ravel()
220+
Particles.weight = weight
92221

93222

94223
def get_weight(Particles, ind):
@@ -111,9 +240,6 @@ def get_weight(Particles, ind):
111240
New location given as a value between 1 and 8 (inclusive)
112241
113242
"""
114-
# Check if weights have been computed for this location:
115-
if nansum(Particles.weight[ind[0], ind[1], :]) <= 0:
116-
make_weight(Particles, ind)
117243
# randomly pick the new cell for the particle to move to using the
118244
# random_pick function and the set of weights
119245
if Particles.steepest_descent is not True:

dorado/particle_track.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def __init__(self, params):
402402
self.walk_data = None
403403

404404
# initialize routing weights array
405-
self.weight = np.zeros((self.stage.shape[0], self.stage.shape[1], 9))
405+
lw.make_weight(self)
406406

407407

408408
# function to clear walk data if you've made a mistake while generating it

dorado/routines.py

Lines changed: 67 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -74,25 +74,29 @@ def steady_plots(particle, num_iter,
7474
# Do particle iterations
7575
walk_data = particle.run_iteration()
7676
if save_output:
77-
x0, y0, t0 = get_state(walk_data, 0)
77+
# Get current location
7878
xi, yi, ti = get_state(walk_data)
79-
80-
fig = plt.figure(dpi=200)
81-
ax = fig.add_subplot(111)
82-
im = ax.imshow(particle.depth)
83-
plt.title('Depth - Particle Iteration ' + str(i))
84-
cax = fig.add_axes([ax.get_position().x1+0.01,
85-
ax.get_position().y0,
86-
0.02,
87-
ax.get_position().height])
88-
cbar = plt.colorbar(im, cax=cax)
89-
cbar.set_label('Water Depth [m]')
90-
ax.scatter(y0, x0, c='b', s=0.75)
91-
ax.scatter(yi, xi, c='r', s=0.75)
79+
if i == 0:
80+
# Initialize figure
81+
x0, y0, t0 = get_state(walk_data, 0)
82+
fig = plt.figure(dpi=200)
83+
ax = fig.add_subplot(111)
84+
im = ax.imshow(particle.depth)
85+
cax = fig.add_axes([ax.get_position().x1+0.01,
86+
ax.get_position().y0,
87+
0.02,
88+
ax.get_position().height])
89+
cbar = plt.colorbar(im, cax=cax)
90+
cbar.set_label('Water Depth [m]')
91+
orig = ax.scatter(y0, x0, c='b', s=0.75)
92+
newloc = ax.scatter(yi, xi, c='r', s=0.75)
93+
else:
94+
# Update figure with new locations
95+
newloc.set_offsets(np.array([yi,xi]).T)
96+
ax.set_title('Depth - Particle Iteration ' + str(i))
9297
plt.savefig(folder_name+os.sep +
9398
'figs'+os.sep+'output'+str(i)+'.png',
9499
bbox_inches='tight')
95-
plt.close()
96100

97101
if save_output:
98102
# save data as json text file - technically human readable
@@ -165,6 +169,7 @@ def unsteady_plots(dx, Np_tracer, seed_xloc, seed_yloc, num_steps, timestep,
165169
# init params
166170
params = modelParams()
167171
params.dx = dx
172+
params.verbose = False
168173
# make directory to save the data
169174
if folder_name is None:
170175
folder_name = os.getcwd()
@@ -237,25 +242,30 @@ def unsteady_plots(dx, Np_tracer, seed_xloc, seed_yloc, num_steps, timestep,
237242

238243
walk_data = particle.run_iteration(target_time=target_times[i])
239244

240-
x0, y0, t0 = get_state(walk_data, 0)
241245
xi, yi, ti = get_state(walk_data)
242-
243-
# make and save plots and data
244-
fig = plt.figure(dpi=200)
245-
ax = fig.add_subplot(111)
246-
ax.scatter(y0, x0, c='b', s=0.75)
247-
ax.scatter(yi, xi, c='r', s=0.75)
248-
im = ax.imshow(params.depth)
249-
plt.title('Depth at Time ' + str(target_times[i]))
250-
cax = fig.add_axes([ax.get_position().x1+0.01,
251-
ax.get_position().y0,
252-
0.02,
253-
ax.get_position().height])
254-
cbar = plt.colorbar(im, cax=cax)
255-
cbar.set_label('Water Depth [m]')
246+
if i == 0:
247+
x0, y0, t0 = get_state(walk_data, 0)
248+
# Initialize figure
249+
fig = plt.figure(dpi=200)
250+
ax = fig.add_subplot(111)
251+
im = ax.imshow(params.depth)
252+
cax = fig.add_axes([ax.get_position().x1+0.01,
253+
ax.get_position().y0,
254+
0.02,
255+
ax.get_position().height])
256+
cbar = plt.colorbar(im, cax=cax)
257+
cbar.set_label('Water Depth [m]')
258+
orig = ax.scatter(y0, x0, c='b', s=0.75)
259+
newloc = ax.scatter(yi, xi, c='r', s=0.75)
260+
else:
261+
# Update figure with new locations
262+
im.set_data(params.depth)
263+
im.set_clim(np.min(params.depth), np.max(params.depth))
264+
newloc.set_offsets(np.array([yi,xi]).T)
265+
plt.draw()
266+
ax.set_title('Depth at Time ' + str(target_times[i]))
256267
plt.savefig(folder_name+os.sep+'figs'+os.sep+'output'+str(i)+'.png',
257268
bbox_inches='tight')
258-
plt.close()
259269

260270
# save data as a json text file - technically human readable
261271
fpath = folder_name+os.sep+'data'+os.sep+'data.txt'
@@ -312,30 +322,36 @@ def time_plots(particle, num_iter, folder_name=None):
312322
walk_data = particle.run_iteration()
313323

314324
# collect latest travel times
315-
x0, y0, t0 = get_state(walk_data, 0)
316325
xi, yi, temptimes = get_state(walk_data)
317326

318-
# set colorbar using 10th and 90th percentile values
319-
cm = matplotlib.cm.colors.Normalize(vmax=np.percentile(temptimes, 90),
320-
vmin=np.percentile(temptimes, 10))
321-
322-
fig = plt.figure(dpi=200)
323-
ax = plt.gca()
324-
plt.title('Depth - Particle Iteration ' + str(i))
325-
ax.scatter(y0, x0, c='b', s=0.75)
326-
sc = ax.scatter(yi, xi, c=temptimes, s=0.75, cmap='coolwarm', norm=cm)
327-
divider = make_axes_locatable(ax)
328-
cax = divider.append_axes("right", size="5%", pad=0.05)
329-
cbar = plt.colorbar(sc, cax=cax)
330-
cbar.set_label('Particle Travel Times [s]')
331-
im = ax.imshow(particle.depth)
332-
divider = make_axes_locatable(ax)
333-
cax = divider.append_axes("bottom", size="5%", pad=0.5)
334-
cbar2 = plt.colorbar(im, cax=cax, orientation='horizontal')
335-
cbar2.set_label('Water Depth [m]')
327+
if i == 0:
328+
x0, y0, t0 = get_state(walk_data, 0)
329+
# Initialize figure
330+
fig = plt.figure(dpi=200)
331+
ax = plt.gca()
332+
im = ax.imshow(particle.depth)
333+
orig = ax.scatter(y0, x0, c='b', s=0.75)
334+
sc = ax.scatter(yi, xi, c=temptimes, s=0.75, cmap='coolwarm')
335+
sc.set_clim(np.percentile(temptimes,90),
336+
np.percentile(temptimes,10))
337+
divider = make_axes_locatable(ax)
338+
cax = divider.append_axes("right", size="5%", pad=0.05)
339+
cbar = plt.colorbar(sc, cax=cax)
340+
cbar.set_label('Particle Travel Times [s]')
341+
divider = make_axes_locatable(ax)
342+
cax = divider.append_axes("bottom", size="5%", pad=0.5)
343+
cbar2 = plt.colorbar(im, cax=cax, orientation='horizontal')
344+
cbar2.set_label('Water Depth [m]')
345+
else:
346+
# Update figure with new locations
347+
sc.set_offsets(np.array([yi,xi]).T) # Location
348+
sc.set_array(np.array(temptimes)) # Color values
349+
sc.set_clim(np.percentile(temptimes,90),
350+
np.percentile(temptimes,10)) # Color limits
351+
plt.draw()
352+
ax.set_title('Depth - Particle Iteration ' + str(i))
336353
plt.savefig(folder_name+os.sep+'figs'+os.sep+'output'+str(i)+'.png',
337354
bbox_inches='tight')
338-
plt.close()
339355

340356
# save data as a json text file - technically human readable
341357
fpath = folder_name+os.sep+'data'+os.sep+'data.txt'

0 commit comments

Comments
 (0)