Skip to content
This repository was archived by the owner on Sep 4, 2021. It is now read-only.

Commit d15d031

Browse files
committed
speed up finding valid point pairs
1 parent 2f6d052 commit d15d031

File tree

1 file changed

+36
-7
lines changed

1 file changed

+36
-7
lines changed

circuitscape/compute_base.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(self, configFile, ext_log_handler):
1818
#gc.set_debug(gc.DEBUG_STATS | gc.DEBUG_UNCOLLECTABLE | gc.DEBUG_SAVEALL)
1919
np.seterr(invalid='ignore')
2020
np.seterr(divide='ignore')
21+
np.set_printoptions(linewidth=150)
2122

2223
self.state = CSState()
2324
self.options = CSConfig(configFile)
@@ -234,17 +235,35 @@ def has_pair(self, r, c):
234235
if not any(has_rows):
235236
return 0
236237
valid_data = self.mat.data[has_rows]
237-
return sum(valid_data[has_cols])
238+
return np.sum(valid_data[has_cols])
239+
240+
def get_possible_pair(self, r):
241+
pt2list = np.array([])
242+
243+
if (r >= self.max_id):
244+
return pt2list
245+
246+
has_ids = np.where(self.mat.row == r)
247+
if np.any(has_ids):
248+
pt2list = np.append(pt2list, self.mat.col[has_ids])
249+
250+
has_ids = np.where(self.mat.col == r)
251+
if np.any(has_ids):
252+
pt2list = np.append(pt2list, self.mat.row[has_ids])
253+
254+
pt2list = np.unique(pt2list)
255+
return pt2list[np.where(pt2list > r)]
256+
238257

239258
def has(self, r):
240259
if not (r < self.max_id):
241260
return 0
242261
has_rows = (self.mat.row == r)
243262
has_cols = (self.mat.col == r)
244263
has_any = has_rows | has_cols
245-
if not any(has_any):
264+
if not np.any(has_any):
246265
return 0
247-
return sum(self.mat.data[has_any])
266+
return np.sum(self.mat.data[has_any])
248267

249268
def is_included_pair(self, point_id1, point_id2):
250269
return ((self.has_pair(point_id1, point_id2) + self.has_pair(point_id2, point_id1)) > 0) == self.is_include
@@ -467,9 +486,17 @@ def point_pair_idxs_in_component(self, comp, habitat):
467486
dst = self.get_graph_node_idx(pt1_idx, node_map)
468487
if (dst < 0 or components[dst] != comp):
469488
continue
470-
for pt2_idx in range(pt1_idx+1, numpoints):
471-
if (None != self.incl_pairs) and not self.incl_pairs.is_included_pair(self.points_rc[pt1_idx, 0], self.points_rc[pt2_idx, 0]):
472-
continue
489+
if (None != self.incl_pairs):
490+
ccv = self.points_rc[:, 0]
491+
inc = np.array([])
492+
for pt in self.incl_pairs.get_possible_pair(self.points_rc[pt1_idx, 0]):
493+
cc = np.where(ccv == pt)
494+
if len(cc) > 0:
495+
inc = np.append(inc, cc[0])
496+
else:
497+
inc = range(pt1_idx+1, numpoints)
498+
499+
for pt2_idx in inc:
473500
src = self.get_graph_node_idx(pt2_idx, node_map)
474501
if (src >= 0 and components[src] == comp):
475502
yield (pt1_idx, pt2_idx)
@@ -620,13 +647,15 @@ def _make_sparse_csr(node1, node2, conductances, numnodes):
620647

621648
@staticmethod
622649
def _neighbors_horiz(g_map):
623-
"""Returns values of horizontal neighbors in conductance map."""
650+
"""Returns values of horizontal neighbors in conductance map."""
651+
624652
m = g_map.shape[0]
625653
n = g_map.shape[1]
626654

627655
g_map_l = g_map[:, 0:(n-1)]
628656
g_map_r = g_map[:, 1:n]
629657
g_map_lr = np.double(np.logical_and(g_map_l, g_map_r))
658+
630659
s_horiz = np.where(np.c_[g_map_lr, np.zeros((m,1), dtype='int32')].flatten())
631660
t_horiz = np.where(np.c_[np.zeros((m,1), dtype='int32'), g_map_lr].flatten())
632661

0 commit comments

Comments
 (0)