Skip to content

Commit 168d5ba

Browse files
committed
Add MetricDistanceFilter algorithm and update export info
1 parent 7a844c0 commit 168d5ba

File tree

2 files changed

+163
-0
lines changed

2 files changed

+163
-0
lines changed

graphconstructor/operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .minimum_spanning_tree import MinimumSpanningTree
88
from .noise_corrected import NoiseCorrected
99
from .weight_threshold import WeightThreshold
10+
from .metric_distance import MetricDistanceFilter
1011

1112

1213
__all__ = [
@@ -16,6 +17,7 @@
1617
"KNNSelector",
1718
"LocallyAdaptiveSparsification",
1819
"MarginalLikelihoodFilter",
20+
"MetricDistanceFilter",
1921
"MinimumSpanningTree",
2022
"NoiseCorrected",
2123
"WeightThreshold",
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from dataclasses import dataclass
2+
from ..graph import Graph
3+
from .base import GraphOperator
4+
from distanceclosure.dijkstra import single_source_dijkstra_path_length
5+
from networkx.algorithms.shortest_paths.weighted import _weight_function
6+
from heapq import heappush, heappop
7+
from itertools import count
8+
import networkx as nx
9+
from typing import Literal
10+
11+
def single_source_dijkstra_path_length(G, source, weight_function, paths=None, disjunction=sum):
12+
"""Uses (a custom) Dijkstra's algorithm to find shortest weighted paths
13+
14+
Parameters
15+
----------
16+
G : NetworkX graph
17+
18+
source : node
19+
Starting node for path.
20+
21+
weight_function: function
22+
Function with (u, v, data) input that returns that edges weight
23+
24+
paths: dict, optional (default=None)
25+
dict to store the path list from source to each node, keyed by node.
26+
If None, paths are not stored.
27+
28+
disjunction: function (default=sum)
29+
Whether to sum paths or use the max value.
30+
Use `sum` for metric and `max` for ultrametric.
31+
32+
Returns
33+
-------
34+
distance : dictionary
35+
A mapping from node to shortest distance to that node from one
36+
of the source nodes.
37+
38+
Raises
39+
------
40+
NodeNotFound
41+
If `source` is not in `G`.
42+
43+
Note
44+
-----
45+
The optional predecessor and path dictionaries can be accessed by
46+
the caller through the original paths objects passed
47+
as arguments. No need to explicitly return paths.
48+
49+
"""
50+
G_succ = G._succ if G.is_directed() else G._adj
51+
52+
push = heappush
53+
pop = heappop
54+
dist = {} # dictionary of final distances
55+
seen = {}
56+
# fringe is heapq with 3-tuples (distance,c,node)
57+
# use the count c to avoid comparing nodes (may not be able to)
58+
c = count()
59+
fringe = []
60+
if source not in G:
61+
raise nx.NodeNotFound(f"Source {source} not in G")
62+
seen[source] = 0
63+
push(fringe, (0, next(c), source))
64+
while fringe:
65+
(d, _, v) = pop(fringe)
66+
if v in dist:
67+
continue # already searched this node.
68+
dist[v] = d
69+
for u, e in G_succ[v].items():
70+
cost = weight_function(v, u, e)
71+
if cost is None:
72+
continue
73+
vu_dist = disjunction([dist[v], cost])
74+
if u in dist:
75+
u_dist = dist[u]
76+
if vu_dist < u_dist:
77+
raise ValueError("Contradictory paths found:", "negative weights?")
78+
elif u not in seen or vu_dist < seen[u]:
79+
seen[u] = vu_dist
80+
push(fringe, (vu_dist, next(c), u))
81+
if paths is not None:
82+
paths[u] = paths[v] + [u]
83+
return dist
84+
85+
Mode = Literal["distance", "similarity"]
86+
87+
@dataclass(slots=True)
88+
class MetricDistanceFilter(GraphOperator):
89+
"""
90+
Metric Distance Backbone Filter for similarity graphs.
91+
Code: https://github.com/CASCI-lab/distanceclosure/blob/master/distanceclosure/backbone.py
92+
93+
Parameters
94+
----------
95+
weight : str, optional
96+
Edge property containing distance values, by default 'weight'
97+
distortion : bool, optional
98+
Whether to compute and return distortion values, by default False
99+
verbose : bool, optional
100+
Prints statements as it computes, by default False
101+
"""
102+
weight: str = 'weight'
103+
distortion: bool = False
104+
verbose: bool = False
105+
mode: Mode = "distance"
106+
supported_modes = ["distance", "similarity"]
107+
108+
@staticmethod
109+
def _compute_distortions(D: GraphOperator, B, weight='weight', disjunction=sum):
110+
G = D.copy()
111+
112+
G.remove_edges_from(B.edges())
113+
weight_function = _weight_function(B, weight)
114+
115+
svals = dict()
116+
for u in G.nodes():
117+
metric_dist = single_source_dijkstra_path_length(B, source=u, weight_function=weight_function, disjunction=disjunction)
118+
for v in G.neighbors(u):
119+
svals[(u, v)] = G[u][v][weight]/metric_dist[v]
120+
121+
return svals
122+
123+
def _directed_filter(self, G: Graph) -> Graph:
124+
pass
125+
126+
def _undirected_filter(self, D):
127+
disjunction = sum
128+
129+
D = D.to_networkx()
130+
G = D.copy()
131+
weight_function = _weight_function(G, self.weight)
132+
133+
if self.verbose:
134+
total = G.number_of_nodes()
135+
i = 0
136+
137+
for u, _ in sorted(G.degree(weight=self.weight), key=lambda x: x[1]):
138+
if self.verbose:
139+
i += 1
140+
per = i/total
141+
print("Backbone: Dijkstra: {i:d} of {total:d} ({per:.2%})".format(i=i, total=total, per=per))
142+
143+
metric_dist = single_source_dijkstra_path_length(G, source=u, weight_function=weight_function, disjunction=disjunction)
144+
for v in list(G.neighbors(u)):
145+
if metric_dist[v] < G[u][v][self.weight]:
146+
G.remove_edge(u, v)
147+
148+
149+
sparse_adj = nx.to_scipy_sparse_array(G)
150+
if self.distortion:
151+
svals = self._compute_distortions(D, G, weight=self.weight, disjunction=disjunction)
152+
return Graph(sparse_adj, False, True, self.mode), svals
153+
else:
154+
return Graph(sparse_adj, False, True, self.mode)
155+
156+
def apply(self, G: Graph) -> Graph:
157+
self._check_mode_supported(G)
158+
if G.directed:
159+
return self._directed_filter(G)
160+
else:
161+
return self._undirected_filter(G)

0 commit comments

Comments
 (0)