@@ -183,6 +183,27 @@ def update_D(network: nx.DiGraph, i: str, j: str, D: dict) -> None:
183183 D [(i , j )] = [float ("inf" ), []]
184184 # print(f"There is no path between {i} and {j}")
185185
186+ def update_D_multitarget (network : nx .DiGraph , source : str , targets : list [str ], D : dict , reverse = False ) -> None :
187+ # adapted from multi_source_dijkstra
188+ paths = {source : [source ]}
189+ weight = lambda u , v , data : data .get (weight , 1 )
190+ dist = dijkstra_multisource_multitarget (network , {source }, weight , paths = paths , targets = targets )
191+
192+ for target in targets :
193+ if target in dist :
194+ path = paths [target ]
195+ if reverse :
196+ path .reverse ()
197+ D [(target , source )] = [dist [target ], paths [target ]]
198+ else :
199+ D [(source , target )] = [dist [target ], paths [target ]]
200+ else :
201+ if reverse :
202+ D [(target , source )] = [float ("inf" ), []]
203+ else :
204+ D [(source , target )] = [float ("inf" ), []]
205+ # print(f"There is no path between {i} and {j}")
206+
186207def add_path_to_P (path : list , P : nx .DiGraph ) -> None :
187208 for i in range (len (path ) - 1 ):
188209 P .add_edge (path [i ], path [i + 1 ])
@@ -245,7 +266,7 @@ def check_not_visited_not_visited(not_visited: list, D: dict) -> tuple:
245266 current_t = not_visited [i ]
246267 return current_path , current_s , current_t , min_value
247268
248- def BTB_main (network : nx .DiGraph , source : list , target : list ) -> nx .DiGraph :
269+ def BTB_main (network : nx .DiGraph , sources : list , targets : list ) -> nx .DiGraph :
249270 # We do this to do avoid re-implementing a reverse multi-target dijkstra. TODO: This is more
250271 # expensive on memory. Also see an issue on why we needed to implement a multi-target dijkstra:
251272 # https://github.com/networkx/networkx/issues/703.
@@ -254,8 +275,8 @@ def BTB_main(network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
254275 # P is the returned pathway
255276 P = nx .DiGraph ()
256277
257- P .add_nodes_from (source )
258- P .add_nodes_from (target )
278+ P .add_nodes_from (sources )
279+ P .add_nodes_from (targets )
259280
260281 weights = {}
261282 if not nx .is_weighted (network ):
@@ -285,20 +306,20 @@ def BTB_main(network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
285306 not_visited = []
286307 visited = []
287308
288- for i in source :
309+ for i in sources :
289310 not_visited .append (i )
290- for j in target :
311+ for j in targets :
291312 not_visited .append (j )
292313
293314 # D is the distance matrix
294315 # Format
295316 D = {}
296- for i in source :
317+ for i in sources :
297318 # run a single_source_dijsktra to find the shortest path from source to every other nodes
298319 # val is the shortest distance from source to every other nodes
299320 # path is the shortest path from source to every other nodes
300321 val , path = nx .single_source_dijkstra (network , i )
301- for j in target :
322+ for j in targets :
302323 # if there is a path between i and j, then add the distance and the path to D
303324 if j in val :
304325 D [i , j ] = [val [j ], path [j ]]
@@ -307,8 +328,8 @@ def BTB_main(network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
307328
308329 # print(f'Original D: {D}')
309330
310- # source_target is the union of source and target
311- source_target = source + target
331+ # sources_targets is the union of sources and targets
332+ sources_targets = sources + targets
312333
313334 # Index is for debugging (will be removed later)
314335 index = 1
@@ -367,14 +388,11 @@ def BTB_main(network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
367388
368389 # If we successfully extract the path, then update the distance matrix (step 5)
369390
370- # TODO: this is the slow part
371391 for i in current_path :
372- if i not in source_target :
392+ if i not in sources_targets :
373393 # Since D is a matrix from Source to Target, we need to update the distance from source to i and from i to target
374- for s in source :
375- update_D (network , s , i , D )
376- for t in target :
377- update_D (network , i , t , D )
394+ update_D_multitarget (network_reverse , i , sources , D , reverse = True )
395+ update_D_multitarget (network , i , targets , D )
378396 # Update the distance from i to i
379397 D [(i , i )] = [float ("inf" ), []]
380398
0 commit comments