11import networkx as nx
22import math
33import argparse
4+ from heapq import heappop , heappush
5+ from itertools import count
46from pathlib import Path
57
8+ # From networkx, adapted to use multiple targets
9+ def dijkstra_multisource_multitarget (
10+ G , sources , weight , pred = None , paths = None , cutoff = None , targets : list | None = None
11+ ):
12+ """Uses Dijkstra's algorithm to find shortest weighted paths
13+
14+ Parameters
15+ ----------
16+ G : NetworkX graph
17+
18+ sources : non-empty iterable of nodes
19+ Starting nodes for paths. If this is just an iterable containing
20+ a single node, then all paths computed by this function will
21+ start from that node. If there are two or more nodes in this
22+ iterable, the computed paths may begin from any one of the start
23+ nodes.
24+
25+ weight: function
26+ Function with (u, v, data) input that returns that edge's weight
27+ or None to indicate a hidden edge
28+
29+ pred: dict of lists, optional(default=None)
30+ dict to store a list of predecessors keyed by that node
31+ If None, predecessors are not stored.
32+
33+ paths: dict, optional (default=None)
34+ dict to store the path list from source to each node, keyed by node.
35+ If None, paths are not stored.
36+
37+ targets : list of node labels, optional
38+ Ending node for path. Search is halted when all targets are found.
39+
40+ cutoff : integer or float, optional
41+ Length (sum of edge weights) at which the search is stopped.
42+ If cutoff is provided, only return paths with summed weight <= cutoff.
43+
44+ Returns
45+ -------
46+ distance : dictionary
47+ A mapping from node to shortest distance to that node from one
48+ of the source nodes.
49+
50+ Raises
51+ ------
52+ NodeNotFound
53+ If any of `sources` is not in `G`.
54+
55+ Notes
56+ -----
57+ The optional predecessor and path dictionaries can be accessed by
58+ the caller through the original pred and paths objects passed
59+ as arguments. No need to explicitly return pred or paths.
60+
61+ """
62+ G_succ = G ._adj # For speed-up (and works for both directed and undirected graphs)
63+
64+ dist = {} # dictionary of final distances
65+ seen = {}
66+ # fringe is heapq with 3-tuples (distance,c,node)
67+ # use the count c to avoid comparing nodes (may not be able to)
68+ c = count ()
69+ fringe = []
70+ for source in sources :
71+ seen [source ] = 0
72+ heappush (fringe , (0 , next (c ), source ))
73+ while fringe :
74+ (d , _ , v ) = heappop (fringe )
75+ if v in dist :
76+ continue # already searched this node.
77+ dist [v ] = d
78+ if targets and v in targets :
79+ targets .remove (v )
80+ if len (targets ) == 0 :
81+ break
82+ for u , e in G_succ [v ].items ():
83+ cost = weight (v , u , e )
84+ if cost is None :
85+ continue
86+ vu_dist = dist [v ] + cost
87+ if cutoff is not None :
88+ if vu_dist > cutoff :
89+ continue
90+ if u in dist :
91+ u_dist = dist [u ]
92+ if vu_dist < u_dist :
93+ raise ValueError ("Contradictory paths found:" , "negative weights?" )
94+ elif pred is not None and vu_dist == u_dist :
95+ pred [u ].append (v )
96+ elif u not in seen or vu_dist < seen [u ]:
97+ seen [u ] = vu_dist
98+ heappush (fringe , (vu_dist , next (c ), u ))
99+ if paths is not None :
100+ paths [u ] = paths [v ] + [u ]
101+ if pred is not None :
102+ pred [u ] = [v ]
103+ elif vu_dist == seen [u ]:
104+ if pred is not None :
105+ pred [u ].append (v )
106+
107+ # The optional predecessor and path dictionaries can be accessed
108+ # by the caller via the pred and paths objects passed as arguments.
109+ return dist
6110
7111def parse_arguments ():
8112 """
@@ -44,23 +148,22 @@ def read_edges(network_file: Path) -> list:
44148 return network
45149
46150
47- def read_source_target (source_file : Path , target_file : Path ) -> tuple :
48- source = []
49- target = []
151+ def read_source_target (source_file : Path , target_file : Path ) -> tuple [ list [ str ], list [ str ]] :
152+ sources : list [ str ] = []
153+ targets : list [ str ] = []
50154 with open (source_file , "r" ) as f :
51155 for line in f :
52156 line = line .strip ()
53- source .append (line )
157+ sources .append (line )
54158 with open (target_file , "r" ) as f :
55159 for line in f :
56160 line = line .strip ()
57- target .append (line )
58- return source , target
161+ targets .append (line )
162+ return sources , targets
59163
60164
61165# functions for constructing the network
62- def construct_network (network : list , source : list , target : list ) -> nx .DiGraph :
63- print (network )
166+ def construct_network (network : list , source : list [str ], target : list [str ]) -> nx .DiGraph :
64167 Network = nx .DiGraph ()
65168 Network .add_weighted_edges_from (network )
66169 Network .add_nodes_from (source )
@@ -71,15 +174,15 @@ def construct_network(network: list, source: list, target: list) -> nx.DiGraph:
71174def update_D (network : nx .DiGraph , i : str , j : str , D : dict ) -> None :
72175 # check if there is a path between i and j
73176 if nx .has_path (network , i , j ):
177+ (length , path ) = nx .single_source_dijkstra (network , i , j )
74178 D [(i , j )] = [
75- nx . dijkstra_path_length ( network , i , j ) ,
76- nx . dijkstra_path ( network , i , j ) ,
179+ length ,
180+ path ,
77181 ]
78182 else :
79183 D [(i , j )] = [float ("inf" ), []]
80184 # print(f"There is no path between {i} and {j}")
81185
82-
83186def add_path_to_P (path : list , P : nx .DiGraph ) -> None :
84187 for i in range (len (path ) - 1 ):
85188 P .add_edge (path [i ], path [i + 1 ])
@@ -142,24 +245,29 @@ def check_not_visited_not_visited(not_visited: list, D: dict) -> tuple:
142245 current_t = not_visited [i ]
143246 return current_path , current_s , current_t , min_value
144247
145-
146- def BTB_main (Network : nx .DiGraph , source : list , target : list ) -> nx .DiGraph :
248+ def BTB_main (network : nx .DiGraph , source : list , target : list ) -> nx .DiGraph :
249+ # We do this to do avoid re-implementing a reverse multi-target dijkstra. TODO: This is more
250+ # expensive on memory. Also see an issue on why we needed to implement a multi-target dijkstra:
251+ # https://github.com/networkx/networkx/issues/703.
252+ network_reverse = network .reverse ()
253+
147254 # P is the returned pathway
148255 P = nx .DiGraph ()
256+
149257 P .add_nodes_from (source )
150258 P .add_nodes_from (target )
151259
152260 weights = {}
153- if not nx .is_weighted (Network ):
261+ if not nx .is_weighted (network ):
154262 # Set all weights to 1 if the network is unweighted
155- nx .set_edge_attributes (Network , values = 1 , name = "weight" )
263+ nx .set_edge_attributes (network , values = 1 , name = "weight" )
156264 print ("Original Network is unweighted. All weights set to 1." )
157- elif nx .is_weighted (Network , weight = 1 ):
158- weights = nx .get_edge_attributes (Network , "weight" )
159- nx .set_edge_attributes (Network , values = weights , name = "weight" )
265+ elif nx .is_weighted (network , weight = 1 ):
266+ weights = nx .get_edge_attributes (network , "weight" )
267+ nx .set_edge_attributes (network , values = weights , name = "weight" )
160268 print ("Original Network is unweighted" )
161269 else :
162- weights = nx .get_edge_attributes (Network , "weight" )
270+ weights = nx .get_edge_attributes (network , "weight" )
163271
164272 # Apply negative log transformation to each weight
165273 updated_weights = {
@@ -168,7 +276,7 @@ def BTB_main(Network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
168276 }
169277
170278 # Update the graph with the transformed weights
171- nx .set_edge_attributes (Network , values = updated_weights , name = "weight" )
279+ nx .set_edge_attributes (network , values = updated_weights , name = "weight" )
172280 # print(f'Original Weights: {weights}')
173281 # print(f'Transformed Weights: {updated_weights}')
174282
@@ -189,7 +297,7 @@ def BTB_main(Network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
189297 # run a single_source_dijsktra to find the shortest path from source to every other nodes
190298 # val is the shortest distance from source to every other nodes
191299 # path is the shortest path from source to every other nodes
192- val , path = nx .single_source_dijkstra (Network , i )
300+ val , path = nx .single_source_dijkstra (network , i )
193301 for j in target :
194302 # if there is a path between i and j, then add the distance and the path to D
195303 if j in val :
@@ -258,13 +366,15 @@ def BTB_main(Network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
258366 break
259367
260368 # If we successfully extract the path, then update the distance matrix (step 5)
369+
370+ # TODO: this is the slow part
261371 for i in current_path :
262372 if i not in source_target :
263373 # Since D is a matrix from Source to Target, we need to update the distance from source to i and from i to target
264374 for s in source :
265- update_D (Network , s , i , D )
375+ update_D (network , s , i , D )
266376 for t in target :
267- update_D (Network , i , t , D )
377+ update_D (network , i , t , D )
268378 # Update the distance from i to i
269379 D [(i , i )] = [float ("inf" ), []]
270380
@@ -292,7 +402,7 @@ def write_output(output_file, P):
292402 f .write (edge [0 ] + "\t " + edge [1 ] + "\n " )
293403
294404
295- def btb_wrapper (edges : Path , sources : Path , targets : Path , output_file : Path ):
405+ def btb_wrapper (edges : Path , sources_path : Path , targets_path : Path , output_file : Path ):
296406 """
297407 Run BowTieBuilder pathway reconstruction.
298408 @param edges: Path to the edge file
@@ -302,10 +412,10 @@ def btb_wrapper(edges: Path, sources: Path, targets: Path, output_file: Path):
302412 """
303413 if not edges .exists ():
304414 raise OSError (f"Edges file { str (edges )} does not exist" )
305- if not sources .exists ():
306- raise OSError (f"Sources file { str (sources )} does not exist" )
307- if not targets .exists ():
308- raise OSError (f"Targets file { str (targets )} does not exist" )
415+ if not sources_path .exists ():
416+ raise OSError (f"Sources file { str (sources_path )} does not exist" )
417+ if not targets_path .exists ():
418+ raise OSError (f"Targets file { str (targets_path )} does not exist" )
309419
310420 if output_file .exists ():
311421 print (f"Output files { str (output_file )} (nodes) will be overwritten" )
@@ -314,10 +424,12 @@ def btb_wrapper(edges: Path, sources: Path, targets: Path, output_file: Path):
314424 output_file .parent .mkdir (parents = True , exist_ok = True )
315425
316426 edge_list = read_edges (edges )
317- source , target = read_source_target (sources , targets )
318- network = construct_network (edge_list , source , target )
427+ sources , targets = read_source_target (sources_path , targets_path )
428+ network = construct_network (edge_list , sources , targets )
429+
430+ output_graph = BTB_main (network , sources , targets )
319431
320- write_output (output_file , BTB_main ( network , source , target ) )
432+ write_output (output_file , output_graph )
321433
322434
323435def main ():
0 commit comments