Skip to content

Commit 51dc6a4

Browse files
committed
add stricter tests on backboning
1 parent 8a13364 commit 51dc6a4

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

tests/test_enhanced_configuration_model.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,72 @@ def test_neg_log_likelihood_grad_returns_finite_arrays():
236236
assert grad_y.shape == y.shape
237237
assert np.all(np.isfinite(grad_x))
238238
assert np.all(np.isfinite(grad_y))
239+
240+
241+
@pytest.mark.parametrize("alphas", [
242+
[0.001, 0.01, 0.05, 0.1, 0.5, 1.0],
243+
])
244+
def test_ecm_retained_edges_monotone_in_alpha(small_undirected_graph, alphas):
245+
"""
246+
Add mock-based tests to verify that the alpha thresholding logic in ECM's apply() method
247+
correctly retains edges based on the p-values computed from the optimization output.
248+
"""
249+
counts = []
250+
for alpha in alphas:
251+
np.random.seed(123)
252+
Gp = EnhancedConfigurationModelFilter(alpha=alpha).apply(small_undirected_graph)
253+
# undirected graph stores both directions explicitly
254+
counts.append(Gp.adj.nnz)
255+
256+
assert counts == sorted(counts)
257+
258+
259+
def test_ecm_alpha_thresholding_exact_edge_counts(monkeypatch):
260+
A = _csr(
261+
data=[4, 4, 2, 2, 1, 1],
262+
rows=[0, 1, 0, 2, 1, 2],
263+
cols=[1, 0, 2, 0, 2, 1],
264+
n=3,
265+
)
266+
G = Graph.from_csr(A, directed=False, weighted=True, mode="similarity")
267+
268+
# Lower-triangle edges of this graph are:
269+
# (1,0), (2,0), (2,1)
270+
fake_pvals = np.array([0.001, 0.02, 0.2], dtype=np.float64)
271+
272+
def fake_pval_matrix_data(x, y, row, col, weights):
273+
assert len(weights) == 3
274+
return fake_pvals.copy()
275+
276+
def fake_make_objective(*args, **kwargs):
277+
def fun(v):
278+
return 0.0
279+
def jac(v):
280+
return np.zeros_like(v)
281+
return fun, jac
282+
283+
class FakeResult:
284+
success = True
285+
message = "ok"
286+
x = np.zeros(2 * G.n_nodes, dtype=np.float64)
287+
288+
def fake_minimize(*args, **kwargs):
289+
return FakeResult()
290+
291+
import graphconstructor.operators.enhanced_configuration_model as ecm_mod
292+
293+
monkeypatch.setattr(ecm_mod, "_pval_matrix_data", fake_pval_matrix_data)
294+
monkeypatch.setattr(ecm_mod, "_make_objective", fake_make_objective)
295+
monkeypatch.setattr(ecm_mod.so, "minimize", fake_minimize)
296+
297+
cases = [
298+
(0.0005, 0), # keep none
299+
(0.01, 2), # keep one undirected edge -> 2 stored entries
300+
(0.05, 4), # keep two undirected edges -> 4 stored entries
301+
(0.5, 6), # keep all three undirected edges -> 6 stored entries
302+
]
303+
304+
for alpha, expected_nnz in cases:
305+
op = EnhancedConfigurationModelFilter(alpha=alpha)
306+
Gp = op.apply(G)
307+
assert Gp.adj.nnz == expected_nnz

0 commit comments

Comments
 (0)