@@ -18,6 +18,7 @@ def __init__(self, configFile, ext_log_handler):
18
18
#gc.set_debug(gc.DEBUG_STATS | gc.DEBUG_UNCOLLECTABLE | gc.DEBUG_SAVEALL)
19
19
np .seterr (invalid = 'ignore' )
20
20
np .seterr (divide = 'ignore' )
21
+ np .set_printoptions (linewidth = 150 )
21
22
22
23
self .state = CSState ()
23
24
self .options = CSConfig (configFile )
@@ -234,17 +235,35 @@ def has_pair(self, r, c):
234
235
if not any (has_rows ):
235
236
return 0
236
237
valid_data = self .mat .data [has_rows ]
237
- return sum (valid_data [has_cols ])
238
+ return np .sum (valid_data [has_cols ])
239
+
240
+ def get_possible_pair (self , r ):
241
+ pt2list = np .array ([])
242
+
243
+ if (r >= self .max_id ):
244
+ return pt2list
245
+
246
+ has_ids = np .where (self .mat .row == r )
247
+ if np .any (has_ids ):
248
+ pt2list = np .append (pt2list , self .mat .col [has_ids ])
249
+
250
+ has_ids = np .where (self .mat .col == r )
251
+ if np .any (has_ids ):
252
+ pt2list = np .append (pt2list , self .mat .row [has_ids ])
253
+
254
+ pt2list = np .unique (pt2list )
255
+ return pt2list [np .where (pt2list > r )]
256
+
238
257
239
258
def has (self , r ):
240
259
if not (r < self .max_id ):
241
260
return 0
242
261
has_rows = (self .mat .row == r )
243
262
has_cols = (self .mat .col == r )
244
263
has_any = has_rows | has_cols
245
- if not any (has_any ):
264
+ if not np . any (has_any ):
246
265
return 0
247
- return sum (self .mat .data [has_any ])
266
+ return np . sum (self .mat .data [has_any ])
248
267
249
268
def is_included_pair (self , point_id1 , point_id2 ):
250
269
return ((self .has_pair (point_id1 , point_id2 ) + self .has_pair (point_id2 , point_id1 )) > 0 ) == self .is_include
@@ -467,9 +486,17 @@ def point_pair_idxs_in_component(self, comp, habitat):
467
486
dst = self .get_graph_node_idx (pt1_idx , node_map )
468
487
if (dst < 0 or components [dst ] != comp ):
469
488
continue
470
- for pt2_idx in range (pt1_idx + 1 , numpoints ):
471
- if (None != self .incl_pairs ) and not self .incl_pairs .is_included_pair (self .points_rc [pt1_idx , 0 ], self .points_rc [pt2_idx , 0 ]):
472
- continue
489
+ if (None != self .incl_pairs ):
490
+ ccv = self .points_rc [:, 0 ]
491
+ inc = np .array ([])
492
+ for pt in self .incl_pairs .get_possible_pair (self .points_rc [pt1_idx , 0 ]):
493
+ cc = np .where (ccv == pt )
494
+ if len (cc ) > 0 :
495
+ inc = np .append (inc , cc [0 ])
496
+ else :
497
+ inc = range (pt1_idx + 1 , numpoints )
498
+
499
+ for pt2_idx in inc :
473
500
src = self .get_graph_node_idx (pt2_idx , node_map )
474
501
if (src >= 0 and components [src ] == comp ):
475
502
yield (pt1_idx , pt2_idx )
@@ -620,13 +647,15 @@ def _make_sparse_csr(node1, node2, conductances, numnodes):
620
647
621
648
@staticmethod
622
649
def _neighbors_horiz (g_map ):
623
- """Returns values of horizontal neighbors in conductance map."""
650
+ """Returns values of horizontal neighbors in conductance map."""
651
+
624
652
m = g_map .shape [0 ]
625
653
n = g_map .shape [1 ]
626
654
627
655
g_map_l = g_map [:, 0 :(n - 1 )]
628
656
g_map_r = g_map [:, 1 :n ]
629
657
g_map_lr = np .double (np .logical_and (g_map_l , g_map_r ))
658
+
630
659
s_horiz = np .where (np .c_ [g_map_lr , np .zeros ((m ,1 ), dtype = 'int32' )].flatten ())
631
660
t_horiz = np .where (np .c_ [np .zeros ((m ,1 ), dtype = 'int32' ), g_map_lr ].flatten ())
632
661
0 commit comments