Skip to content

Commit 4b45a1f

Browse files
committed
improve generalised filtering nodes
1 parent 9045dc0 commit 4b45a1f

File tree

10 files changed

+46
-13
lines changed

10 files changed

+46
-13
lines changed

docs/source/notebooks/0.3-Generalised_filtering.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@
308308
},
309309
{
310310
"cell_type": "code",
311-
"execution_count": 7,
311+
"execution_count": null,
312312
"id": "1798765e-3d65-4bfd-964b-7f9b6b0902be",
313313
"metadata": {},
314314
"outputs": [
@@ -358,7 +358,7 @@
358358
},
359359
{
360360
"cell_type": "code",
361-
"execution_count": 8,
361+
"execution_count": null,
362362
"id": "2d921e51-a940-42b2-88f2-e25bd7ab5a01",
363363
"metadata": {
364364
"editable": true,

examples/exponential.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ fn main() {
77

88
// create a network with two exponential family state nodes
99
network.add_nodes(
10-
"exponential-state",
10+
"ef-state",
1111
None,
1212
None,
1313
None,

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
pub mod model;
22
pub mod utils;
3-
pub mod math;
3+
pub mod maths;
44
pub mod updates;

src/math.rs

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/maths/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod sufficient_statistics;

src/maths/sufficient_statistics.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
pub fn normal(x: &f64) -> Vec<f64> {
2+
vec![*x, x.powf(2.0)]
3+
}
4+
5+
pub fn multivariate_normal(x: &Vec<f64>) -> Vec<f64> {
6+
vec![*x, x.powf(2.0)]
7+
}
8+
9+
pub fn get_sufficient_statistics_fn(distribution: String) {
10+
if distribution == "normal" {
11+
normal;
12+
} else if distribution == "multivariate_normal" {
13+
multivariate_normal;
14+
}
15+
}

src/model.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,21 @@ impl Network {
7373
/// * `value_children` - The indexes of the node's value children.
7474
/// * `volatility_children` - The indexes of the node's volatility children.
7575
/// * `volatility_parents` - The indexes of the node's volatility parents.
76-
#[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_parents=None, volatility_children=None,))]
77-
pub fn add_nodes(&mut self, kind: &str, value_parents: Option<Vec<usize>>,
76+
#[pyo3(signature = (
77+
kind="continuous-state",
78+
value_parents=None,
79+
value_children=None,
80+
volatility_parents=None,
81+
volatility_children=None,
82+
ef_dimension=None,
83+
ef_distribution=None,
84+
ef_learning=None,
85+
)
86+
)]
87+
pub fn add_nodes(
88+
&mut self,
89+
kind: &str,
90+
value_parents: Option<Vec<usize>>,
7891
value_children: Option<Vec<usize>>,
7992
volatility_parents: Option<Vec<usize>>, volatility_children: Option<Vec<usize>>, ) {
8093

@@ -86,6 +99,7 @@ impl Network {
8699
self.inputs.push(node_id);
87100
}
88101

102+
// Update the edges variable
89103
let edges = AdjacencyLists{
90104
node_type: String::from(kind),
91105
value_parents: value_parents,
@@ -94,6 +108,11 @@ impl Network {
94108
volatility_children: volatility_children,
95109
};
96110

111+
// Add emtpy adjacency lists in the new node
112+
self.edges.insert(node_id, edges);
113+
114+
// TODO: Update the edges of parents and children accordingly
115+
97116
// add edges and attributes
98117
if kind == "continuous-state" {
99118

@@ -107,7 +126,6 @@ impl Network {
107126
(String::from("autoconnection_strength"), 1.0)].into_iter().collect();
108127

109128
self.attributes.floats.insert(node_id, attributes);
110-
self.edges.insert(node_id, edges);
111129

112130
} else if kind == "ef-state" {
113131

@@ -123,6 +141,7 @@ impl Network {
123141

124142
}
125143
}
144+
}
126145

127146
pub fn set_update_sequence(&mut self) {
128147
self.update_sequence = set_update_sequence(self);

src/updates/prediction/continuous.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::model::Network;
33
/// Prediction from a continuous state node
44
///
55
/// # Arguments
6-
/// * `network` - The main network containing the node.
6+
/// * `network` - The main network structure.
77
/// * `node_idx` - The node index.
88
///
99
/// # Returns

src/updates/prediction_error/exponential.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
use crate::model::Network;
2-
use crate::math::sufficient_statistics;
32

43
/// Updating an exponential family state node
54
///
65
/// # Arguments
76
/// * `network` - The main network containing the node.
87
/// * `node_idx` - The node index.
8+
/// * `sufficient_statistics` - A function computing the sufficient statistics of an exponential family distribution.
99
///
1010
/// # Returns
1111
/// * `network` - The network after message passing.
12-
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) {
12+
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize, sufficient_statistics: fn(&f64) -> Vec<f64>) {
1313

1414
let floats_attributes = network.attributes.floats.get_mut(&node_idx).expect("No floats attributes");
1515
let vectors_attributes = network.attributes.vectors.get_mut(&node_idx).expect("No vector attributes");

src/utils/set_sequence.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::{model::{AdjacencyLists, Network, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}};
22
use crate::utils::function_pointer::FnType;
3+
use crate::maths::sufficient_statistics::get_sufficient_statistics_fn;
34

45
pub fn set_update_sequence(network: &Network) -> UpdateSequence {
56
let predictions = get_predictions_sequence(network);

0 commit comments

Comments
 (0)