1+ from dataclasses import dataclass
2+ from ..graph import Graph
3+ from .base import GraphOperator
4+ from distanceclosure .dijkstra import single_source_dijkstra_path_length
5+ from networkx .algorithms .shortest_paths .weighted import _weight_function
6+ from heapq import heappush , heappop
7+ from itertools import count
8+ import networkx as nx
9+ from typing import Literal
10+
11+ def single_source_dijkstra_path_length (G , source , weight_function , paths = None , disjunction = sum ):
12+ """Uses (a custom) Dijkstra's algorithm to find shortest weighted paths
13+
14+ Parameters
15+ ----------
16+ G : NetworkX graph
17+
18+ source : node
19+ Starting node for path.
20+
21+ weight_function: function
22+ Function with (u, v, data) input that returns that edges weight
23+
24+ paths: dict, optional (default=None)
25+ dict to store the path list from source to each node, keyed by node.
26+ If None, paths are not stored.
27+
28+ disjunction: function (default=sum)
29+ Whether to sum paths or use the max value.
30+ Use `sum` for metric and `max` for ultrametric.
31+
32+ Returns
33+ -------
34+ distance : dictionary
35+ A mapping from node to shortest distance to that node from one
36+ of the source nodes.
37+
38+ Raises
39+ ------
40+ NodeNotFound
41+ If `source` is not in `G`.
42+
43+ Note
44+ -----
45+ The optional predecessor and path dictionaries can be accessed by
46+ the caller through the original paths objects passed
47+ as arguments. No need to explicitly return paths.
48+
49+ """
50+ G_succ = G ._succ if G .is_directed () else G ._adj
51+
52+ push = heappush
53+ pop = heappop
54+ dist = {} # dictionary of final distances
55+ seen = {}
56+ # fringe is heapq with 3-tuples (distance,c,node)
57+ # use the count c to avoid comparing nodes (may not be able to)
58+ c = count ()
59+ fringe = []
60+ if source not in G :
61+ raise nx .NodeNotFound (f"Source { source } not in G" )
62+ seen [source ] = 0
63+ push (fringe , (0 , next (c ), source ))
64+ while fringe :
65+ (d , _ , v ) = pop (fringe )
66+ if v in dist :
67+ continue # already searched this node.
68+ dist [v ] = d
69+ for u , e in G_succ [v ].items ():
70+ cost = weight_function (v , u , e )
71+ if cost is None :
72+ continue
73+ vu_dist = disjunction ([dist [v ], cost ])
74+ if u in dist :
75+ u_dist = dist [u ]
76+ if vu_dist < u_dist :
77+ raise ValueError ("Contradictory paths found:" , "negative weights?" )
78+ elif u not in seen or vu_dist < seen [u ]:
79+ seen [u ] = vu_dist
80+ push (fringe , (vu_dist , next (c ), u ))
81+ if paths is not None :
82+ paths [u ] = paths [v ] + [u ]
83+ return dist
84+
85+ Mode = Literal ["distance" , "similarity" ]
86+
87+ @dataclass (slots = True )
88+ class MetricDistanceFilter (GraphOperator ):
89+ """
90+ Metric Distance Backbone Filter for similarity graphs.
91+ Code: https://github.com/CASCI-lab/distanceclosure/blob/master/distanceclosure/backbone.py
92+
93+ Parameters
94+ ----------
95+ weight : str, optional
96+ Edge property containing distance values, by default 'weight'
97+ distortion : bool, optional
98+ Whether to compute and return distortion values, by default False
99+ verbose : bool, optional
100+ Prints statements as it computes, by default False
101+ """
102+ weight : str = 'weight'
103+ distortion : bool = False
104+ verbose : bool = False
105+ mode : Mode = "distance"
106+ supported_modes = ["distance" , "similarity" ]
107+
108+ @staticmethod
109+ def _compute_distortions (D : GraphOperator , B , weight = 'weight' , disjunction = sum ):
110+ G = D .copy ()
111+
112+ G .remove_edges_from (B .edges ())
113+ weight_function = _weight_function (B , weight )
114+
115+ svals = dict ()
116+ for u in G .nodes ():
117+ metric_dist = single_source_dijkstra_path_length (B , source = u , weight_function = weight_function , disjunction = disjunction )
118+ for v in G .neighbors (u ):
119+ svals [(u , v )] = G [u ][v ][weight ]/ metric_dist [v ]
120+
121+ return svals
122+
123+ def _directed_filter (self , G : Graph ) -> Graph :
124+ pass
125+
126+ def _undirected_filter (self , D ):
127+ disjunction = sum
128+
129+ D = D .to_networkx ()
130+ G = D .copy ()
131+ weight_function = _weight_function (G , self .weight )
132+
133+ if self .verbose :
134+ total = G .number_of_nodes ()
135+ i = 0
136+
137+ for u , _ in sorted (G .degree (weight = self .weight ), key = lambda x : x [1 ]):
138+ if self .verbose :
139+ i += 1
140+ per = i / total
141+ print ("Backbone: Dijkstra: {i:d} of {total:d} ({per:.2%})" .format (i = i , total = total , per = per ))
142+
143+ metric_dist = single_source_dijkstra_path_length (G , source = u , weight_function = weight_function , disjunction = disjunction )
144+ for v in list (G .neighbors (u )):
145+ if metric_dist [v ] < G [u ][v ][self .weight ]:
146+ G .remove_edge (u , v )
147+
148+
149+ sparse_adj = nx .to_scipy_sparse_array (G )
150+ if self .distortion :
151+ svals = self ._compute_distortions (D , G , weight = self .weight , disjunction = disjunction )
152+ return Graph (sparse_adj , False , True , self .mode ), svals
153+ else :
154+ return Graph (sparse_adj , False , True , self .mode )
155+
156+ def apply (self , G : Graph ) -> Graph :
157+ self ._check_mode_supported (G )
158+ if G .directed :
159+ return self ._directed_filter (G )
160+ else :
161+ return self ._undirected_filter (G )
0 commit comments