Skip to content

Commit 35bd833

Browse files
committed
fix linting
1 parent 2fd70be commit 35bd833

File tree

3 files changed

+48
-36
lines changed

3 files changed

+48
-36
lines changed

graphconstructor/operators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from .knn_selector import KNNSelector
55
from .locally_adaptive_sparsification import LocallyAdaptiveSparsification
66
from .marginal_likelihood import MarginalLikelihoodFilter
7+
from .metric_distance import MetricDistanceFilter
78
from .minimum_spanning_tree import MinimumSpanningTree
89
from .noise_corrected import NoiseCorrected
910
from .weight_threshold import WeightThreshold
10-
from .metric_distance import MetricDistanceFilter
1111

1212

1313
__all__ = [
Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from dataclasses import dataclass
2-
from ..graph import Graph
3-
from .base import GraphOperator
2+
from typing import Literal
3+
import networkx as nx
44
from distanceclosure.dijkstra import single_source_dijkstra_path_length
55
from networkx.algorithms.shortest_paths.weighted import _weight_function
6-
import networkx as nx
7-
from typing import Literal
6+
from ..graph import Graph
7+
from .base import GraphOperator
8+
89

910
Mode = Literal["distance", "similarity"]
1011

12+
1113
@dataclass(slots=True)
1214
class MetricDistanceFilter(GraphOperator):
1315
"""
@@ -23,65 +25,67 @@ class MetricDistanceFilter(GraphOperator):
2325
verbose : bool, optional
2426
Prints statements as it computes, by default False
2527
"""
26-
weight: str = 'weight'
28+
29+
weight: str = "weight"
2730
distortion: bool = False
2831
verbose: bool = False
2932
mode: Mode = "distance"
3033
supported_modes = ["distance", "similarity"]
3134

3235
@staticmethod
33-
def _compute_distortions(D: GraphOperator, B, weight='weight', disjunction=sum):
36+
def _compute_distortions(D: GraphOperator, B, weight="weight", disjunction=sum):
3437
G = D.copy()
35-
38+
3639
G.remove_edges_from(B.edges())
3740
weight_function = _weight_function(B, weight)
3841

39-
svals = dict()
42+
svals = dict()
4043
for u in G.nodes():
41-
metric_dist = single_source_dijkstra_path_length(B, source=u, weight_function=weight_function, disjunction=disjunction)
44+
metric_dist = single_source_dijkstra_path_length(
45+
B, source=u, weight_function=weight_function, disjunction=disjunction
46+
)
4247
for v in G.neighbors(u):
43-
svals[(u, v)] = G[u][v][weight]/metric_dist[v]
44-
48+
svals[(u, v)] = G[u][v][weight] / metric_dist[v]
49+
4550
return svals
46-
51+
4752
def _directed_filter(self, G: Graph) -> Graph:
48-
raise NotImplementedError(
49-
"MetricDistanceFilter is defined only for undirected graphs."
50-
)
51-
52-
def _undirected_filter(self, D):
53+
raise NotImplementedError("MetricDistanceFilter is defined only for undirected graphs.")
54+
55+
def _undirected_filter(self, D):
5356
disjunction = sum
54-
57+
5558
D = D.to_networkx()
5659
G = D.copy()
5760
weight_function = _weight_function(G, self.weight)
58-
61+
5962
if self.verbose:
6063
total = G.number_of_nodes()
6164
i = 0
62-
65+
6366
for u, _ in sorted(G.degree(weight=self.weight), key=lambda x: x[1]):
6467
if self.verbose:
6568
i += 1
66-
per = i/total
69+
per = i / total
6770
print("Backbone: Dijkstra: {i:d} of {total:d} ({per:.2%})".format(i=i, total=total, per=per))
68-
69-
metric_dist = single_source_dijkstra_path_length(G, source=u, weight_function=weight_function, disjunction=disjunction)
71+
72+
metric_dist = single_source_dijkstra_path_length(
73+
G, source=u, weight_function=weight_function, disjunction=disjunction
74+
)
7075
for v in list(G.neighbors(u)):
7176
if metric_dist[v] < G[u][v][self.weight]:
7277
G.remove_edge(u, v)
73-
7478

7579
sparse_adj = nx.to_scipy_sparse_array(G)
7680
if self.distortion:
77-
svals = self._compute_distortions(D, G, weight=self.weight, disjunction=disjunction)
81+
svals = self._compute_distortions(D, G, weight=self.weight, disjunction=disjunction)
7882
return Graph(sparse_adj, False, True, self.mode), svals
7983
else:
8084
return Graph(sparse_adj, False, True, self.mode)
81-
85+
8286
def apply(self, G: Graph) -> Graph:
8387
self._check_mode_supported(G)
8488
if G.directed:
8589
return self._directed_filter(G)
8690
else:
87-
return self._undirected_filter(G)
91+
return self._undirected_filter(G)

tests/test_metric_distance.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
import numpy as np
21
import networkx as nx
3-
import scipy.sparse as sp
2+
import numpy as np
43
import pytest
4+
import scipy.sparse as sp
55
from graphconstructor import Graph
66
from graphconstructor.operators import MetricDistanceFilter
77

8+
89
def _csr(data, rows, cols, n):
910
return sp.csr_matrix(
1011
(np.asarray(data, float), (np.asarray(rows, int), np.asarray(cols, int))),
1112
shape=(n, n),
1213
)
1314

15+
1416
def simple_undirected_graph():
1517
A = _csr(
1618
data=[0.5, 0.5, 0.3, 0.3, 0.8, 0.8],
@@ -21,6 +23,7 @@ def simple_undirected_graph():
2123

2224
return Graph.from_csr(A, directed=False, weighted=True, mode="similarity")
2325

26+
2427
def simple_directed_graph():
2528
A = _csr(
2629
data=[0.5, 0.5, 0.3],
@@ -31,19 +34,21 @@ def simple_directed_graph():
3134

3235
return Graph.from_csr(A, directed=True, weighted=True, mode="similarity")
3336

37+
3438
def test_basic_undirected_filtering():
3539
G0 = simple_undirected_graph()
3640

3741
out = MetricDistanceFilter(distortion=False, verbose=False).apply(G0)
3842

3943
assert isinstance(out, Graph)
40-
assert out.directed == False
41-
assert out.weighted == True
44+
assert not out.directed
45+
assert out.weighted
4246

4347
original_edges = G0.to_networkx().number_of_edges()
4448
result_edges = out.to_networkx().number_of_edges()
4549
assert result_edges <= original_edges
46-
50+
51+
4752
def test_undirected_filtering_distortion():
4853
G0 = simple_undirected_graph()
4954

@@ -55,29 +60,32 @@ def test_undirected_filtering_distortion():
5560
filtered_graph, svals = out
5661
assert isinstance(filtered_graph, Graph)
5762
assert isinstance(svals, dict)
58-
63+
5964
if svals:
6065
key = next(iter(svals.keys()))
6166
assert isinstance(key, tuple)
6267
assert len(key) == 2
6368

69+
6470
def test_directed_graph_not_implemented():
6571
G0 = simple_directed_graph()
6672
with pytest.raises(NotImplementedError):
6773
MetricDistanceFilter().apply(G0)
6874

75+
6976
def test_edge_removal_logic():
7077
G0 = simple_undirected_graph()
7178
out = MetricDistanceFilter().apply(G0)
7279

7380
original_nx = G0.to_networkx()
74-
out_nx = G0.to_networkx()
81+
out_nx = out.to_networkx()
7582

7683
assert out_nx.number_of_edges() <= original_nx.number_of_edges()
7784

7885
if nx.is_connected(original_nx):
7986
assert nx.is_connected(out_nx)
8087

88+
8189
def test_isolated_nodes():
8290
A = _csr(
8391
data=[0.5, 0.5],
@@ -91,6 +99,7 @@ def test_isolated_nodes():
9199
assert out.to_networkx().number_of_nodes() == 3
92100
assert 2 in out.to_networkx().nodes()
93101

102+
94103
def test_empty_graph():
95104
A = _csr(data=[], rows=[], cols=[], n=3)
96105
G0 = Graph.from_csr(A, directed=False, weighted=True, mode="distance")
@@ -99,4 +108,3 @@ def test_empty_graph():
99108

100109
assert out.to_networkx().number_of_edges() == 0
101110
assert out.to_networkx().number_of_nodes() == 3
102-

0 commit comments

Comments
 (0)