|
14 | 14 |
|
15 | 15 | use std::hash::Hash; |
16 | 16 |
|
| 17 | +use ndarray::ArrayView2; |
17 | 18 | use petgraph::data::{Build, Create}; |
18 | 19 | use petgraph::visit::{ |
19 | 20 | Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdgesDirected, |
@@ -305,6 +306,131 @@ where |
305 | 306 | Ok(graph) |
306 | 307 | } |
307 | 308 |
|
| 309 | +/// Generate a graph from the stochastic block model. |
| 310 | +/// |
| 311 | +/// The stochastic block model is a generalization of the G<sub>np</sub> random graph |
| 312 | +/// (see [gnp_random_graph] ). The connection probability of |
| 313 | +/// nodes `u` and `v` depends on their block and is given by |
| 314 | +/// `probabilities[blocks[u]][blocks[v]]`, where `blocks[u]` is the block membership |
| 315 | +/// of vertex `u`. The number of nodes and the number of blocks are inferred from |
| 316 | +/// `sizes`. |
| 317 | +/// |
| 318 | +/// Arguments: |
| 319 | +/// |
| 320 | +/// * `sizes` - Number of nodes in each block. |
| 321 | +/// * `probabilities` - B x B array that contains the connection probability between |
| 322 | +/// nodes of different blocks. Must be symmetric for undirected graphs. |
| 323 | +/// * `loops` - Determines whether the graph can have loops or not. |
| 324 | +/// * `seed` - An optional seed to use for the random number generator. |
| 325 | +/// * `default_node_weight` - A callable that will return the weight to use |
| 326 | +/// for newly created nodes. |
| 327 | +/// * `default_edge_weight` - A callable that will return the weight object |
| 328 | +/// to use for newly created edges. |
| 329 | +/// |
| 330 | +/// # Example |
| 331 | +/// ```rust |
| 332 | +/// use ndarray::arr2; |
| 333 | +/// use rustworkx_core::petgraph; |
| 334 | +/// use rustworkx_core::generators::sbm_random_graph; |
| 335 | +/// |
| 336 | +/// let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>( |
| 337 | +/// &vec![1, 2], |
| 338 | +/// &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(), |
| 339 | +/// true, |
| 340 | +/// Some(10), |
| 341 | +/// || (), |
| 342 | +/// || (), |
| 343 | +/// ) |
| 344 | +/// .unwrap(); |
| 345 | +/// assert_eq!(g.node_count(), 3); |
| 346 | +/// assert_eq!(g.edge_count(), 6); |
| 347 | +/// ``` |
| 348 | +pub fn sbm_random_graph<G, T, F, H, M>( |
| 349 | + sizes: &[usize], |
| 350 | + probabilities: &ndarray::ArrayView2<f64>, |
| 351 | + loops: bool, |
| 352 | + seed: Option<u64>, |
| 353 | + mut default_node_weight: F, |
| 354 | + mut default_edge_weight: H, |
| 355 | +) -> Result<G, InvalidInputError> |
| 356 | +where |
| 357 | + G: Build + Create + Data<NodeWeight = T, EdgeWeight = M> + NodeIndexable + GraphProp, |
| 358 | + F: FnMut() -> T, |
| 359 | + H: FnMut() -> M, |
| 360 | + G::NodeId: Eq + Hash, |
| 361 | +{ |
| 362 | + let num_nodes: usize = sizes.iter().sum(); |
| 363 | + if num_nodes == 0 { |
| 364 | + return Err(InvalidInputError {}); |
| 365 | + } |
| 366 | + let num_communities = sizes.len(); |
| 367 | + if probabilities.nrows() != num_communities |
| 368 | + || probabilities.ncols() != num_communities |
| 369 | + || probabilities.iter().any(|&x| !(0. ..=1.).contains(&x)) |
| 370 | + { |
| 371 | + return Err(InvalidInputError {}); |
| 372 | + } |
| 373 | + |
| 374 | + let mut graph = G::with_capacity(num_nodes, num_nodes); |
| 375 | + let directed = graph.is_directed(); |
| 376 | + if !directed && !symmetric_array(probabilities) { |
| 377 | + return Err(InvalidInputError {}); |
| 378 | + } |
| 379 | + |
| 380 | + for _ in 0..num_nodes { |
| 381 | + graph.add_node(default_node_weight()); |
| 382 | + } |
| 383 | + let mut rng: Pcg64 = match seed { |
| 384 | + Some(seed) => Pcg64::seed_from_u64(seed), |
| 385 | + None => Pcg64::from_entropy(), |
| 386 | + }; |
| 387 | + let mut blocks = Vec::new(); |
| 388 | + { |
| 389 | + let mut block = 0; |
| 390 | + let mut vertices_left = sizes[0]; |
| 391 | + for _ in 0..num_nodes { |
| 392 | + while vertices_left == 0 { |
| 393 | + block += 1; |
| 394 | + vertices_left = sizes[block]; |
| 395 | + } |
| 396 | + blocks.push(block); |
| 397 | + vertices_left -= 1; |
| 398 | + } |
| 399 | + } |
| 400 | + |
| 401 | + let between = Uniform::new(0.0, 1.0); |
| 402 | + for v in 0..(if directed || loops { |
| 403 | + num_nodes |
| 404 | + } else { |
| 405 | + num_nodes - 1 |
| 406 | + }) { |
| 407 | + for w in ((if directed { 0 } else { v })..num_nodes).filter(|&w| w != v || loops) { |
| 408 | + if &between.sample(&mut rng) |
| 409 | + < probabilities.get((blocks[v], blocks[w])).unwrap_or(&0_f64) |
| 410 | + { |
| 411 | + graph.add_edge( |
| 412 | + graph.from_index(v), |
| 413 | + graph.from_index(w), |
| 414 | + default_edge_weight(), |
| 415 | + ); |
| 416 | + } |
| 417 | + } |
| 418 | + } |
| 419 | + Ok(graph) |
| 420 | +} |
| 421 | + |
| 422 | +fn symmetric_array<T: std::cmp::PartialEq>(mat: &ArrayView2<T>) -> bool { |
| 423 | + let n = mat.nrows(); |
| 424 | + for (i, row) in mat.rows().into_iter().enumerate().take(n - 1) { |
| 425 | + for (j, m_ij) in row.iter().enumerate().skip(i + 1) { |
| 426 | + if m_ij != mat.get((j, i)).unwrap() { |
| 427 | + return false; |
| 428 | + } |
| 429 | + } |
| 430 | + } |
| 431 | + true |
| 432 | +} |
| 433 | + |
308 | 434 | #[inline] |
309 | 435 | fn pnorm(x: f64, p: f64) -> f64 { |
310 | 436 | if p == 1.0 || p == std::f64::INFINITY { |
@@ -749,7 +875,7 @@ mod tests { |
749 | 875 | use crate::generators::InvalidInputError; |
750 | 876 | use crate::generators::{ |
751 | 877 | barabasi_albert_graph, gnm_random_graph, gnp_random_graph, hyperbolic_random_graph, |
752 | | - path_graph, random_bipartite_graph, random_geometric_graph, |
| 878 | + path_graph, random_bipartite_graph, random_geometric_graph, sbm_random_graph, |
753 | 879 | }; |
754 | 880 | use crate::petgraph; |
755 | 881 |
|
@@ -879,6 +1005,165 @@ mod tests { |
879 | 1005 | }; |
880 | 1006 | } |
881 | 1007 |
|
// Test sbm_random_graph

// Directed graph, loops allowed: with probabilities restricted to 0 and 1 the
// outcome is deterministic, so the exact edge set can be asserted.
#[test]
fn test_sbm_directed_complete_blocks_loops() {
    let probabilities = ndarray::arr2(&[[0., 1.], [0., 1.]]);
    let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &probabilities.view(),
        true,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 6);
    let expected_edges = [(1, 1), (1, 2), (2, 1), (2, 2), (0, 1), (0, 2)];
    for (u, v) in expected_edges {
        assert!(g.contains_edge(u.into(), v.into()));
    }
    // Nothing points back into block 0.
    assert!(!g.contains_edge(1.into(), 0.into()));
    assert!(!g.contains_edge(2.into(), 0.into()));
}
| 1028 | + |
// Undirected graph, loops allowed: every pair touching block 1 is connected,
// block 0 has no internal edges (and therefore no self-loop on node 0).
#[test]
fn test_sbm_undirected_complete_blocks_loops() {
    let probabilities = ndarray::arr2(&[[0., 1.], [1., 1.]]);
    let g = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &probabilities.view(),
        true,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 5);
    let expected_edges = [(1, 1), (1, 2), (2, 2), (0, 1), (0, 2)];
    for (u, v) in expected_edges {
        assert!(g.contains_edge(u.into(), v.into()));
    }
    assert!(!g.contains_edge(0.into(), 0.into()));
}
| 1047 | + |
// Directed graph, loops disabled: same block structure as the loops variant,
// but no node may be connected to itself.
#[test]
fn test_sbm_directed_complete_blocks_noloops() {
    let g = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &vec![1, 2],
        &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(),
        false,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 4);
    for (u, v) in [(1, 2), (2, 1), (0, 1), (0, 2)] {
        assert_eq!(g.contains_edge(u.into(), v.into()), true);
    }
    assert_eq!(g.contains_edge(1.into(), 0.into()), false);
    assert_eq!(g.contains_edge(2.into(), 0.into()), false);
    // Check every one of the three nodes for self-loops (the range previously
    // stopped at node 1 and never looked at node 2).
    for u in 0..3 {
        assert_eq!(g.contains_edge(u.into(), u.into()), false);
    }
}
| 1070 | + |
// Undirected graph, loops disabled: only the three edges between distinct
// nodes that touch block 1 may exist.
#[test]
fn test_sbm_undirected_complete_blocks_noloops() {
    let g = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &vec![1, 2],
        &ndarray::arr2(&[[0., 1.], [1., 1.]]).view(),
        false,
        Some(10),
        || (),
        || (),
    )
    .unwrap();
    assert_eq!(g.node_count(), 3);
    assert_eq!(g.edge_count(), 3);
    for (u, v) in [(1, 2), (0, 1), (0, 2)] {
        assert_eq!(g.contains_edge(u.into(), v.into()), true);
    }
    // Check every one of the three nodes for self-loops (the range previously
    // stopped at node 1 and never looked at node 2).
    for u in 0..3 {
        assert_eq!(g.contains_edge(u.into(), u.into()), false);
    }
}
| 1091 | + |
// A probability array with more rows than there are blocks must be rejected.
#[test]
fn test_sbm_bad_array_rows_error() {
    let probabilities = ndarray::arr2(&[[0., 1.], [1., 1.], [1., 1.]]);
    let result = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &probabilities.view(),
        true,
        Some(10),
        || (),
        || (),
    );
    if let Err(e) = result {
        assert_eq!(e, InvalidInputError);
    } else {
        panic!("Returned a non-error");
    }
}
// A probability array with more columns than there are blocks must be rejected.
#[test]
fn test_sbm_bad_array_cols_error() {
    let probabilities = ndarray::arr2(&[[0., 1., 1.], [1., 1., 1.]]);
    let result = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &probabilities.view(),
        true,
        Some(10),
        || (),
        || (),
    );
    if let Err(e) = result {
        assert_eq!(e, InvalidInputError);
    } else {
        panic!("Returned a non-error");
    }
}
| 1121 | + |
// Undirected graphs require a symmetric probability array.
#[test]
fn test_sbm_asymmetric_array_error() {
    let probabilities = ndarray::arr2(&[[0., 1.], [0., 1.]]);
    let result = sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &[1, 2],
        &probabilities.view(),
        true,
        Some(10),
        || (),
        || (),
    );
    if let Err(e) = result {
        assert_eq!(e, InvalidInputError);
    } else {
        panic!("Returned a non-error");
    }
}
| 1136 | + |
// An entry outside [0, 1] must be rejected.
//
// The array is deliberately symmetric: on an undirected graph an asymmetric
// array is rejected as well, which would let this test pass even if the range
// validation were missing.
#[test]
fn test_sbm_invalid_probability_error() {
    match sbm_random_graph::<petgraph::graph::UnGraph<(), ()>, (), _, _, ()>(
        &vec![1, 2],
        &ndarray::arr2(&[[0., 1.], [1., -1.]]).view(),
        true,
        Some(10),
        || (),
        || (),
    ) {
        Ok(_) => panic!("Returned a non-error"),
        Err(e) => assert_eq!(e, InvalidInputError),
    };
}
| 1151 | + |
// With no blocks at all there are no nodes to create, which is an error.
#[test]
fn test_sbm_empty_error() {
    let result = sbm_random_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
        &[],
        &ndarray::arr2(&[[]]).view(),
        true,
        Some(10),
        || (),
        || (),
    );
    if let Err(e) = result {
        assert_eq!(e, InvalidInputError);
    } else {
        panic!("Returned a non-error");
    }
}
| 1166 | + |
882 | 1167 | // Test random_geometric_graph |
883 | 1168 |
|
884 | 1169 | #[test] |
|
0 commit comments