Skip to content

Commit 4bb8fe9

Browse files
committed
Improve API/UX
1 parent 73d8eec commit 4bb8fe9

File tree

12 files changed

+711
-0
lines changed

12 files changed

+711
-0
lines changed

hypergraphx/communities/api.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Optional
4+
5+
import numpy as np
6+
7+
from hypergraphx import Hypergraph
8+
from hypergraphx.communities.results import (
9+
CorePeripheryResult,
10+
HypergraphMTResult,
11+
HyperlinkCommunitiesResult,
12+
HyMMSBMResult,
13+
HySCResult,
14+
hard_labels_from_memberships,
15+
)
16+
17+
18+
def run_core_periphery(
19+
hypergraph: Hypergraph,
20+
*,
21+
greedy_start: bool = False,
22+
n_iter: int = 1000,
23+
seed: int | None = None,
24+
rng: np.random.Generator | None = None,
25+
) -> CorePeripheryResult:
26+
"""
27+
Core-periphery coreness scores.
28+
29+
Returns
30+
-------
31+
CorePeripheryResult
32+
`scores`: dict mapping node -> coreness score.
33+
"""
34+
from hypergraphx.communities.core_periphery.model import core_periphery
35+
36+
scores = core_periphery(
37+
hypergraph,
38+
greedy_start=greedy_start,
39+
N_ITER=n_iter,
40+
seed=seed,
41+
rng=rng,
42+
)
43+
return CorePeripheryResult(scores=scores)
44+
45+
46+
def run_hyperlink_communities(
47+
hypergraph: Hypergraph,
48+
*,
49+
load_distances: str | None = None,
50+
save_distances: str | None = None,
51+
) -> HyperlinkCommunitiesResult:
52+
"""
53+
Hyperlink communities (hierarchical clustering over edge distances).
54+
55+
Returns
56+
-------
57+
HyperlinkCommunitiesResult
58+
`dendrogram`: SciPy hierarchical clustering dendrogram array.
59+
"""
60+
from hypergraphx.communities.hyperlink_comm.hyperlink_communities import (
61+
hyperlink_communities,
62+
)
63+
64+
dendrogram = hyperlink_communities(
65+
hypergraph, load_distances=load_distances, save_distances=save_distances
66+
)
67+
return HyperlinkCommunitiesResult(dendrogram=dendrogram)
68+
69+
70+
def fit_hysc(
71+
hypergraph: Hypergraph,
72+
*,
73+
k: int,
74+
seed: int = 0,
75+
weighted_laplacian: bool = False,
76+
out_inference: bool = False,
77+
out_folder: str = "../data/output/",
78+
end_file: str = "_sc.dat",
79+
) -> HySCResult:
80+
"""
81+
Hypergraph Spectral Clustering (HySC).
82+
83+
Returns
84+
-------
85+
HySCResult
86+
`memberships`: hard membership matrix u (N x K)
87+
`labels`: hard labels (N,)
88+
"""
89+
from hypergraphx.communities.hy_sc.model import HySC
90+
91+
model = HySC(
92+
seed=seed, out_inference=out_inference, out_folder=out_folder, end_file=end_file
93+
)
94+
memberships = model.fit(hypergraph, K=k, weighted_L=weighted_laplacian)
95+
labels = hard_labels_from_memberships(np.asarray(memberships))
96+
return HySCResult(memberships=np.asarray(memberships), labels=labels, model=model)
97+
98+
99+
def fit_hypergraph_mt(
100+
hypergraph: Hypergraph,
101+
*,
102+
k: int,
103+
seed: int | None = None,
104+
normalize_u: bool = False,
105+
baseline_r0: bool = True,
106+
**params: Any,
107+
) -> HypergraphMTResult:
108+
"""
109+
Hypergraph-MT mixed-membership inference.
110+
111+
Returns
112+
-------
113+
HypergraphMTResult
114+
`memberships`: membership matrix u (N x K)
115+
`affinity`: model affinity parameters w (shape depends on implementation)
116+
`max_loglik`: best achieved log-likelihood
117+
"""
118+
from hypergraphx.communities.hypergraph_mt.model import HypergraphMT
119+
120+
model = HypergraphMT(**params)
121+
memberships, affinity, max_loglik = model.fit(
122+
hypergraph,
123+
K=k,
124+
seed=seed,
125+
normalizeU=normalize_u,
126+
baseline_r0=baseline_r0,
127+
)
128+
memberships = np.asarray(memberships)
129+
labels = (
130+
hard_labels_from_memberships(memberships) if memberships.ndim == 2 else None
131+
)
132+
return HypergraphMTResult(
133+
memberships=memberships,
134+
affinity=np.asarray(affinity),
135+
max_loglik=float(max_loglik),
136+
labels=labels,
137+
model=model,
138+
)
139+
140+
141+
def fit_hy_mmsbm(
142+
hypergraph: Hypergraph,
143+
*,
144+
k: int,
145+
seed: int | None = None,
146+
n_iter: int = 500,
147+
tol: float | None = None,
148+
check_convergence_every: int = 10,
149+
**init_params: Any,
150+
) -> HyMMSBMResult:
151+
"""
152+
Hy-MMSBM Expectation-Maximization inference.
153+
154+
Returns
155+
-------
156+
HyMMSBMResult
157+
`memberships`: soft assignments u (N x K)
158+
`affinity`: affinity matrix w (K x K)
159+
`labels`: argmax hard labels (N,)
160+
"""
161+
from hypergraphx.communities.hy_mmsbm.model import HyMMSBM
162+
163+
model = HyMMSBM(K=k, seed=seed, **init_params)
164+
model.fit(
165+
hypergraph,
166+
n_iter=n_iter,
167+
tolerance=tol,
168+
check_convergence_every=check_convergence_every,
169+
)
170+
if model.u is None or model.w is None:
171+
raise RuntimeError("HyMMSBM.fit() did not produce u/w parameters.")
172+
memberships = np.asarray(model.u)
173+
labels = hard_labels_from_memberships(memberships)
174+
return HyMMSBMResult(
175+
memberships=memberships,
176+
affinity=np.asarray(model.w),
177+
trained=bool(getattr(model, "trained", True)),
178+
labels=labels,
179+
model=model,
180+
)

hypergraphx/communities/results.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any, Dict, Optional
5+
6+
import numpy as np
7+
8+
9+
def hard_labels_from_memberships(u: np.ndarray) -> np.ndarray:
10+
"""
11+
Convert a membership matrix `u` (N x K) to hard labels via argmax.
12+
"""
13+
if u.ndim != 2:
14+
raise ValueError("u must be a 2D array of shape (N, K).")
15+
if u.shape[1] == 0:
16+
raise ValueError("u must have K>0 columns.")
17+
return np.asarray(np.argmax(u, axis=1), dtype=int)
18+
19+
20+
@dataclass(frozen=True)
21+
class CorePeripheryResult:
22+
"""Result for core-periphery scoring.
23+
24+
Attributes
25+
----------
26+
scores : dict
27+
Mapping `node -> coreness score` (float).
28+
"""
29+
30+
scores: Dict[Any, float]
31+
32+
33+
@dataclass(frozen=True)
34+
class HyperlinkCommunitiesResult:
35+
"""Result for hyperlink communities.
36+
37+
Attributes
38+
----------
39+
dendrogram : np.ndarray
40+
SciPy hierarchical clustering dendrogram.
41+
Use a cut height to extract flat edge-cluster labels.
42+
"""
43+
44+
dendrogram: np.ndarray
45+
46+
47+
@dataclass(frozen=True)
48+
class HySCResult:
49+
"""Result for Hypergraph Spectral Clustering (HySC).
50+
51+
Attributes
52+
----------
53+
memberships : np.ndarray
54+
Hard-membership matrix `u` of shape (N, K).
55+
labels : np.ndarray
56+
Hard labels of shape (N,), derived from memberships.
57+
model : object
58+
The fitted HySC model instance.
59+
"""
60+
61+
memberships: np.ndarray
62+
labels: np.ndarray
63+
model: Any
64+
65+
66+
@dataclass(frozen=True)
67+
class HypergraphMTResult:
68+
"""Result for Hypergraph-MT.
69+
70+
Attributes
71+
----------
72+
memberships : np.ndarray
73+
Membership matrix `u` of shape (N, K).
74+
affinity : np.ndarray
75+
Affinity parameters `w` as returned by the model.
76+
max_loglik : float
77+
Best achieved log-likelihood across realizations.
78+
labels : np.ndarray | None
79+
Optional hard labels derived from memberships (argmax). Present when the
80+
returned `memberships` has shape (N, K) with K>0.
81+
model : object
82+
The fitted HypergraphMT model instance.
83+
"""
84+
85+
memberships: np.ndarray
86+
affinity: np.ndarray
87+
max_loglik: float
88+
labels: Optional[np.ndarray]
89+
model: Any
90+
91+
92+
@dataclass(frozen=True)
93+
class HyMMSBMResult:
94+
"""Result for Hy-MMSBM.
95+
96+
Attributes
97+
----------
98+
memberships : np.ndarray
99+
Soft assignments `u` of shape (N, K).
100+
affinity : np.ndarray
101+
Affinity matrix `w` of shape (K, K).
102+
trained : bool
103+
Whether the model reports itself as trained.
104+
labels : np.ndarray
105+
Hard labels derived from memberships (argmax).
106+
model : object
107+
The fitted HyMMSBM model instance.
108+
"""
109+
110+
memberships: np.ndarray
111+
affinity: np.ndarray
112+
trained: bool
113+
labels: np.ndarray
114+
model: Any

0 commit comments

Comments
 (0)