Skip to content

Commit f3b45f7

Browse files
SILIZ4 and IvanIsCoding authored
Add stochastic block model generator (#1200)
* wip generator first version * add sbm generator to rustworkx and rustworkx-core * fix formatting error * change for 2d arrays, use community sizes instead of memberships * Use cargo workspace for ndarray deps --------- Co-authored-by: Ivan Carvalho <[email protected]>
1 parent 955d989 commit f3b45f7

File tree

12 files changed

+514
-5
lines changed

12 files changed

+514
-5
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ ahash = "0.8.6"
2828
fixedbitset = "0.4.2"
2929
hashbrown = { version = ">=0.13, <0.15", features = ["rayon"] }
3030
indexmap = { version = ">=1.9, <3", features = ["rayon"] }
31+
ndarray = { version = "0.15.6", features = ["rayon"] }
3132
num-traits = "0.2"
3233
numpy = "0.21.0"
3334
petgraph = "0.6.5"
@@ -44,6 +45,7 @@ ahash.workspace = true
4445
fixedbitset.workspace = true
4546
hashbrown.workspace = true
4647
indexmap.workspace = true
48+
ndarray.workspace = true
4749
ndarray-stats = "0.5.1"
4850
num-bigint = "0.4"
4951
num-complex = "0.4"
@@ -63,10 +65,6 @@ rustworkx-core = { path = "rustworkx-core", version = "=0.15.0" }
6365
version = "0.21.2"
6466
features = ["abi3-py38", "extension-module", "hashbrown", "num-bigint", "num-complex", "indexmap"]
6567

66-
[dependencies.ndarray]
67-
version = "^0.15.6"
68-
features = ["rayon"]
69-
7068
[dependencies.sprs]
7169
version = "^0.11"
7270
features = ["multi_thread"]

docs/source/api/random_graph_generator_functions.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ Random Graph Generator Functions
1010
rustworkx.undirected_gnp_random_graph
1111
rustworkx.directed_gnm_random_graph
1212
rustworkx.undirected_gnm_random_graph
13+
rustworkx.directed_sbm_random_graph
14+
rustworkx.undirected_sbm_random_graph
1315
rustworkx.random_geometric_graph
1416
rustworkx.hyperbolic_random_graph
1517
rustworkx.barabasi_albert_graph
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
features:
2+
- |
3+
Adds a new random graph generator to rustworkx for the stochastic block model.
4+
There is a generator for directed :func:`.directed_sbm_random_graph` and
5+
undirected graphs :func:`.undirected_sbm_random_graph`.
6+
- |
7+
Adds a new function ``sbm_random_graph`` to the rustworkx-core module
8+
``rustworkx_core::generators`` that samples a graph from the stochastic
9+
block model.

rustworkx-core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ ahash.workspace = true
1616
fixedbitset.workspace = true
1717
hashbrown.workspace = true
1818
indexmap.workspace = true
19+
ndarray.workspace = true
1920
num-traits.workspace = true
2021
petgraph.workspace = true
2122
priority-queue = "2.0"

rustworkx-core/src/generators/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,5 @@ pub use random_graph::gnp_random_graph;
6262
pub use random_graph::hyperbolic_random_graph;
6363
pub use random_graph::random_bipartite_graph;
6464
pub use random_graph::random_geometric_graph;
65+
pub use random_graph::sbm_random_graph;
6566
pub use star_graph::star_graph;

rustworkx-core/src/generators/random_graph.rs

Lines changed: 286 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
use std::hash::Hash;
1616

17+
use ndarray::ArrayView2;
1718
use petgraph::data::{Build, Create};
1819
use petgraph::visit::{
1920
Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdgesDirected,
@@ -305,6 +306,131 @@ where
305306
Ok(graph)
306307
}
307308

309+
/// Generate a graph from the stochastic block model.
310+
///
311+
/// The stochastic block model is a generalization of the G<sub>np</sub> random graph
312+
/// (see [gnp_random_graph] ). The connection probability of
313+
/// nodes `u` and `v` depends on their block and is given by
314+
/// `probabilities[blocks[u]][blocks[v]]`, where `blocks[u]` is the block membership
315+
/// of vertex `u`. The number of nodes and the number of blocks are inferred from
316+
/// `sizes`.
317+
///
318+
/// Arguments:
319+
///
320+
/// * `sizes` - Number of nodes in each block.
321+
/// * `probabilities` - B x B array that contains the connection probability between
322+
/// nodes of different blocks. Must be symmetric for undirected graphs.
323+
/// * `loops` - Determines whether the graph can have loops or not.
324+
/// * `seed` - An optional seed to use for the random number generator.
325+
/// * `default_node_weight` - A callable that will return the weight to use
326+
/// for newly created nodes.
327+
/// * `default_edge_weight` - A callable that will return the weight object
328+
/// to use for newly created edges.
329+
///
330+
/// # Example
331+
/// ```rust
332+
/// use ndarray::arr2;
333+
/// use rustworkx_core::petgraph;
334+
/// use rustworkx_core::generators::sbm_random_graph;
335+
///
336+
/// let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
337+
/// &vec![1, 2],
338+
/// &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(),
339+
/// true,
340+
/// Some(10),
341+
/// || (),
342+
/// || (),
343+
/// )
344+
/// .unwrap();
345+
/// assert_eq!(g.node_count(), 3);
346+
/// assert_eq!(g.edge_count(), 6);
347+
/// ```
348+
pub fn sbm_random_graph<G, T, F, H, M>(
349+
sizes: &[usize],
350+
probabilities: &ndarray::ArrayView2<f64>,
351+
loops: bool,
352+
seed: Option<u64>,
353+
mut default_node_weight: F,
354+
mut default_edge_weight: H,
355+
) -> Result<G, InvalidInputError>
356+
where
357+
G: Build + Create + Data<NodeWeight = T, EdgeWeight = M> + NodeIndexable + GraphProp,
358+
F: FnMut() -> T,
359+
H: FnMut() -> M,
360+
G::NodeId: Eq + Hash,
361+
{
362+
let num_nodes: usize = sizes.iter().sum();
363+
if num_nodes == 0 {
364+
return Err(InvalidInputError {});
365+
}
366+
let num_communities = sizes.len();
367+
if probabilities.nrows() != num_communities
368+
|| probabilities.ncols() != num_communities
369+
|| probabilities.iter().any(|&x| !(0. ..=1.).contains(&x))
370+
{
371+
return Err(InvalidInputError {});
372+
}
373+
374+
let mut graph = G::with_capacity(num_nodes, num_nodes);
375+
let directed = graph.is_directed();
376+
if !directed && !symmetric_array(probabilities) {
377+
return Err(InvalidInputError {});
378+
}
379+
380+
for _ in 0..num_nodes {
381+
graph.add_node(default_node_weight());
382+
}
383+
let mut rng: Pcg64 = match seed {
384+
Some(seed) => Pcg64::seed_from_u64(seed),
385+
None => Pcg64::from_entropy(),
386+
};
387+
let mut blocks = Vec::new();
388+
{
389+
let mut block = 0;
390+
let mut vertices_left = sizes[0];
391+
for _ in 0..num_nodes {
392+
while vertices_left == 0 {
393+
block += 1;
394+
vertices_left = sizes[block];
395+
}
396+
blocks.push(block);
397+
vertices_left -= 1;
398+
}
399+
}
400+
401+
let between = Uniform::new(0.0, 1.0);
402+
for v in 0..(if directed || loops {
403+
num_nodes
404+
} else {
405+
num_nodes - 1
406+
}) {
407+
for w in ((if directed { 0 } else { v })..num_nodes).filter(|&w| w != v || loops) {
408+
if &between.sample(&mut rng)
409+
< probabilities.get((blocks[v], blocks[w])).unwrap_or(&0_f64)
410+
{
411+
graph.add_edge(
412+
graph.from_index(v),
413+
graph.from_index(w),
414+
default_edge_weight(),
415+
);
416+
}
417+
}
418+
}
419+
Ok(graph)
420+
}
421+
422+
fn symmetric_array<T: std::cmp::PartialEq>(mat: &ArrayView2<T>) -> bool {
423+
let n = mat.nrows();
424+
for (i, row) in mat.rows().into_iter().enumerate().take(n - 1) {
425+
for (j, m_ij) in row.iter().enumerate().skip(i + 1) {
426+
if m_ij != mat.get((j, i)).unwrap() {
427+
return false;
428+
}
429+
}
430+
}
431+
true
432+
}
433+
308434
#[inline]
309435
fn pnorm(x: f64, p: f64) -> f64 {
310436
if p == 1.0 || p == std::f64::INFINITY {
@@ -749,7 +875,7 @@ mod tests {
749875
use crate::generators::InvalidInputError;
750876
use crate::generators::{
751877
barabasi_albert_graph, gnm_random_graph, gnp_random_graph, hyperbolic_random_graph,
752-
path_graph, random_bipartite_graph, random_geometric_graph,
878+
path_graph, random_bipartite_graph, random_geometric_graph, sbm_random_graph,
753879
};
754880
use crate::petgraph;
755881

@@ -879,6 +1005,165 @@ mod tests {
8791005
};
8801006
}
8811007

1008+
// Test sbm_random_graph
1009+
#[test]
fn test_sbm_directed_complete_blocks_loops() {
    // Block 0 = {0}, block 1 = {1, 2}. All probabilities are 0 or 1, so the
    // outcome is deterministic: every edge into block 1 exists (self-loops
    // included) and no edge into block 0 exists.
    let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(),
        true,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 6);
    for (u, v) in [(1, 1), (1, 2), (2, 1), (2, 2), (0, 1), (0, 2)] {
        assert!(g.contains_edge(u.into(), v.into()));
    }
    assert!(!g.contains_edge(1.into(), 0.into()));
    assert!(!g.contains_edge(2.into(), 0.into()));
}
1028+
1029+
#[test]
fn test_sbm_undirected_complete_blocks_loops() {
    // Block 0 = {0}, block 1 = {1, 2}. Within block 0 the probability is 0,
    // everywhere else it is 1, so only the self-loop on node 0 is missing.
    let g = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &ndarray::arr2(&[[0., 1.], [1., 1.]]).view(),
        true,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 5);
    for (u, v) in [(1, 1), (1, 2), (2, 2), (0, 1), (0, 2)] {
        assert!(g.contains_edge(u.into(), v.into()));
    }
    assert!(!g.contains_edge(0.into(), 0.into()));
}
1047+
1048+
#[test]
fn test_sbm_directed_complete_blocks_noloops() {
    // Same deterministic setup as the loops variant, but with loops disabled
    // the two self-loops in block 1 must disappear (6 - 2 = 4 edges).
    let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(),
        false,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 4);
    for (u, v) in [(1, 2), (2, 1), (0, 1), (0, 2)] {
        assert!(g.contains_edge(u.into(), v.into()));
    }
    assert!(!g.contains_edge(1.into(), 0.into()));
    assert!(!g.contains_edge(2.into(), 0.into()));
    // Check every node (0, 1 and 2) for the absence of a self-loop.
    for u in 0..3 {
        assert!(!g.contains_edge(u.into(), u.into()));
    }
}
1070+
1071+
#[test]
fn test_sbm_undirected_complete_blocks_noloops() {
    // With loops disabled only the three edges between distinct nodes remain.
    let g = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &ndarray::arr2(&[[0., 1.], [1., 1.]]).view(),
        false,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 3);
    for (u, v) in [(1, 2), (0, 1), (0, 2)] {
        assert!(g.contains_edge(u.into(), v.into()));
    }
    // Check every node (0, 1 and 2) for the absence of a self-loop.
    for u in 0..3 {
        assert!(!g.contains_edge(u.into(), u.into()));
    }
}
1091+
1092+
#[test]
fn test_sbm_bad_array_rows_error() {
    // Two blocks but a 3 x 2 probability matrix: the row count is wrong.
    let probabilities = ndarray::arr2(&[[0., 1.], [1., 1.], [1., 1.]]);
    let result = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &probabilities.view(),
        true,
        Some(10),
        || (),
        || (),
    );
    match result {
        Err(e) => assert_eq!(e, InvalidInputError),
        Ok(_) => panic!("Returned a non-error"),
    };
}
1106+
#[test]
fn test_sbm_bad_array_cols_error() {
    // Two blocks but a 2 x 3 probability matrix: the column count is wrong.
    let probabilities = ndarray::arr2(&[[0., 1., 1.], [1., 1., 1.]]);
    let result = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &probabilities.view(),
        true,
        Some(10),
        || (),
        || (),
    );
    match result {
        Err(e) => assert_eq!(e, InvalidInputError),
        Ok(_) => panic!("Returned a non-error"),
    };
}
1121+
1122+
#[test]
fn test_sbm_asymmetric_array_error() {
    // An undirected graph requires a symmetric matrix; entries (0, 1) and
    // (1, 0) differ here.
    let probabilities = ndarray::arr2(&[[0., 1.], [0., 1.]]);
    let result = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &probabilities.view(),
        true,
        Some(10),
        || (),
        || (),
    );
    match result {
        Err(e) => assert_eq!(e, InvalidInputError),
        Ok(_) => panic!("Returned a non-error"),
    };
}
1136+
1137+
#[test]
fn test_sbm_invalid_probability_error() {
    // The matrix is symmetric and correctly shaped, so the only thing wrong
    // with it is the out-of-range entry -1. This pins the probability-range
    // validation rather than the symmetry validation.
    match sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &ndarray::arr2(&[[0., 1.], [1., -1.]]).view(),
        true,
        Some(10),
        || (),
        || (),
    ) {
        Ok(_) => panic!("Returned a non-error"),
        Err(e) => assert_eq!(e, InvalidInputError),
    };
}
1151+
1152+
#[test]
fn test_sbm_empty_error() {
    // No blocks means no nodes, which the generator rejects.
    let probabilities = ndarray::arr2(&[[]]);
    let result = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[],
        &probabilities.view(),
        true,
        Some(10),
        || (),
        || (),
    );
    match result {
        Err(e) => assert_eq!(e, InvalidInputError),
        Ok(_) => panic!("Returned a non-error"),
    };
}
1166+
8821167
// Test random_geometric_graph
8831168

8841169
#[test]

rustworkx/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ from .rustworkx import directed_gnm_random_graph as directed_gnm_random_graph
127127
from .rustworkx import undirected_gnm_random_graph as undirected_gnm_random_graph
128128
from .rustworkx import directed_gnp_random_graph as directed_gnp_random_graph
129129
from .rustworkx import undirected_gnp_random_graph as undirected_gnp_random_graph
130+
from .rustworkx import directed_sbm_random_graph as directed_sbm_random_graph
131+
from .rustworkx import undirected_sbm_random_graph as undirected_sbm_random_graph
130132
from .rustworkx import random_geometric_graph as random_geometric_graph
131133
from .rustworkx import hyperbolic_random_graph as hyperbolic_random_graph
132134
from .rustworkx import barabasi_albert_graph as barabasi_albert_graph

0 commit comments

Comments
 (0)