Skip to content

Commit 7959c60

Browse files
committed
add tests for MetricDistanceFilter
1 parent 9e5a1ce commit 7959c60

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

tests/test_metric_distance.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import numpy as np
2+
import networkx as nx
3+
import scipy.sparse as sp
4+
from graphconstructor import Graph
5+
from graphconstructor.operators import MetricDistanceFilter
6+
7+
def _csr(data, rows, cols, n):
8+
return sp.csr_matrix(
9+
(np.asarray(data, float), (np.asarray(rows, int), np.asarray(cols, int))),
10+
shape=(n, n),
11+
)
12+
13+
def simple_undirected_graph():
14+
A = _csr(
15+
data=[0.5, 0.5, 0.3, 0.3, 0.8, 0.8],
16+
rows=[0, 1, 0, 2, 1, 2],
17+
cols=[1, 0, 2, 0, 2, 1],
18+
n=3,
19+
)
20+
21+
return Graph.from_csr(A, directed=False, weighted=True, mode="similarity")
22+
23+
def test_basic_undirected_filtering():
24+
G0 = simple_undirected_graph()
25+
26+
out = MetricDistanceFilter(distortion=False, verbose=False).apply(G0)
27+
28+
assert isinstance(out, Graph)
29+
assert out.directed == False
30+
assert out.weighted == True
31+
32+
original_edges = G0.to_networkx().number_of_edges()
33+
result_edges = out.to_networkx().number_of_edges()
34+
assert result_edges <= original_edges
35+
36+
def test_undirected_filtering_distortion():
37+
G0 = simple_undirected_graph()
38+
39+
out = MetricDistanceFilter(distortion=True, verbose=False).apply(G0)
40+
41+
assert isinstance(out, tuple)
42+
assert len(out) == 2
43+
44+
filtered_graph, svals = out
45+
assert isinstance(filtered_graph, Graph)
46+
assert isinstance(svals, dict)
47+
48+
if svals:
49+
key = next(iter(svals.keys()))
50+
assert isinstance(key, tuple)
51+
assert len(key) == 2
52+
53+
def test_directed_graph_not_implemented():
54+
G0 = simple_undirected_graph()
55+
out = MetricDistanceFilter().apply(G0)
56+
assert out is None
57+
58+
def test_edge_removal_logic():
59+
G0 = simple_undirected_graph()
60+
out = MetricDistanceFilter().apply(G0)
61+
62+
original_nx = G0.to_networkx()
63+
out_nx = G0.to_networkx()
64+
65+
assert out_nx.number_of_edges() <= original_nx.number_of_edges()
66+
67+
if nx.is_connected(original_nx):
68+
assert nx.is_connected(out_nx)
69+
70+
def test_isolated_nodes():
71+
A = _csr(
72+
data=[0.5, 0.5],
73+
rows=[0, 1],
74+
cols=[1, 0],
75+
n=3,
76+
)
77+
G0 = Graph.from_csr(A, directed=False, weighted=True, mode="distance")
78+
out = MetricDistanceFilter().apply(G0)
79+
80+
assert out.to_networkx().number_of_nodes() == 3
81+
assert 2 in out.to_networkx().nodes()
82+
83+
def test_empty_graph():
84+
A = _csr(data=[], rows=[], cols=[], n=3)
85+
G0 = Graph.from_csr(A, directed=False, weighted=True, mode="distance")
86+
87+
out = MetricDistanceFilter().apply(G0)
88+
89+
assert out.to_networkx().number_of_edges() == 0
90+
assert out.to_networkx().number_of_nodes() == 3
91+

0 commit comments

Comments
 (0)