Skip to content

Commit ff101c1

Browse files
committed
feat: multitarget dijkstra
1 parent 7bb1929 commit ff101c1

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

btb.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,27 @@ def update_D(network: nx.DiGraph, i: str, j: str, D: dict) -> None:
183183
D[(i, j)] = [float("inf"), []]
184184
# print(f"There is no path between {i} and {j}")
185185

186+
def update_D_multitarget(network: nx.DiGraph, source: str, targets: list[str], D: dict, reverse=False) -> None:
187+
# adapted from multi_source_dijkstra
188+
paths = {source: [source]}
189+
weight = lambda u, v, data: data.get(weight, 1)
190+
dist = dijkstra_multisource_multitarget(network, {source}, weight, paths=paths, targets=targets)
191+
192+
for target in targets:
193+
if target in dist:
194+
path = paths[target]
195+
if reverse:
196+
path.reverse()
197+
D[(target, source)] = [dist[target], paths[target]]
198+
else:
199+
D[(source, target)] = [dist[target], paths[target]]
200+
else:
201+
if reverse:
202+
D[(target, source)] = [float("inf"), []]
203+
else:
204+
D[(source, target)] = [float("inf"), []]
205+
# print(f"There is no path between {i} and {j}")
206+
186207
def add_path_to_P(path: list, P: nx.DiGraph) -> None:
187208
for i in range(len(path) - 1):
188209
P.add_edge(path[i], path[i + 1])
@@ -245,7 +266,7 @@ def check_not_visited_not_visited(not_visited: list, D: dict) -> tuple:
245266
current_t = not_visited[i]
246267
return current_path, current_s, current_t, min_value
247268

248-
def BTB_main(network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
269+
def BTB_main(network: nx.DiGraph, sources: list, targets: list) -> nx.DiGraph:
249270
# We do this to do avoid re-implementing a reverse multi-target dijkstra. TODO: This is more
250271
# expensive on memory. Also see an issue on why we needed to implement a multi-target dijkstra:
251272
# https://github.com/networkx/networkx/issues/703.
@@ -254,8 +275,8 @@ def BTB_main(network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
254275
# P is the returned pathway
255276
P = nx.DiGraph()
256277

257-
P.add_nodes_from(source)
258-
P.add_nodes_from(target)
278+
P.add_nodes_from(sources)
279+
P.add_nodes_from(targets)
259280

260281
weights = {}
261282
if not nx.is_weighted(network):
@@ -285,20 +306,20 @@ def BTB_main(network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
285306
not_visited = []
286307
visited = []
287308

288-
for i in source:
309+
for i in sources:
289310
not_visited.append(i)
290-
for j in target:
311+
for j in targets:
291312
not_visited.append(j)
292313

293314
# D is the distance matrix
294315
# Format
295316
D = {}
296-
for i in source:
317+
for i in sources:
297318
# run a single_source_dijsktra to find the shortest path from source to every other nodes
298319
# val is the shortest distance from source to every other nodes
299320
# path is the shortest path from source to every other nodes
300321
val, path = nx.single_source_dijkstra(network, i)
301-
for j in target:
322+
for j in targets:
302323
# if there is a path between i and j, then add the distance and the path to D
303324
if j in val:
304325
D[i, j] = [val[j], path[j]]
@@ -307,8 +328,8 @@ def BTB_main(network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
307328

308329
# print(f'Original D: {D}')
309330

310-
# source_target is the union of source and target
311-
source_target = source + target
331+
# sources_targets is the union of sources and targets
332+
sources_targets = sources + targets
312333

313334
# Index is for debugging (will be removed later)
314335
index = 1
@@ -367,14 +388,11 @@ def BTB_main(network: nx.DiGraph, source: list, target: list) -> nx.DiGraph:
367388

368389
# If we successfully extract the path, then update the distance matrix (step 5)
369390

370-
# TODO: this is the slow part
371391
for i in current_path:
372-
if i not in source_target:
392+
if i not in sources_targets:
373393
# Since D is a matrix from Source to Target, we need to update the distance from source to i and from i to target
374-
for s in source:
375-
update_D(network, s, i, D)
376-
for t in target:
377-
update_D(network, i, t, D)
394+
update_D_multitarget(network_reverse, i, sources, D, reverse=True)
395+
update_D_multitarget(network, i, targets, D)
378396
# Update the distance from i to i
379397
D[(i, i)] = [float("inf"), []]
380398

0 commit comments

Comments
 (0)