Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ keywords = ["machine-learning", "random-forest", "decision-tree", "ensemble", "t
name = "biosphere"

[dependencies]
ndarray = "0.17.1"
ndarray = "0.17.2"
rand = "0.9"
rayon = "1.11"

Expand Down
2 changes: 1 addition & 1 deletion biosphere-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ name = "biosphere"
[dependencies]
numpy = "0.27.1"
biosphere = { path = "../" }
ndarray = "0.17.1"
ndarray = "0.17.2"
pyo3 = {version = "0.27", features = ["extension-module"]}
2 changes: 1 addition & 1 deletion biosphere-py/src/random_forest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ impl RandomForest {
random_state,
max_depth,
max_features.value,
min_samples_split,
min_samples_leaf,
min_samples_split,
n_jobs,
);
Ok(RandomForest {
Expand Down
169 changes: 168 additions & 1 deletion src/tree/decision_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ impl DecisionTree {
mod tests {
use super::*;
use crate::testing::load_iris;
use ndarray::s;
use crate::MaxFeatures;
use ndarray::{s, Array2};
use rstest::*;
use std::collections::HashSet;

#[rstest]
fn test_tree() {
Expand All @@ -121,4 +123,169 @@ mod tests {
// perfectly replicate these with another decision tree.
assert_eq!(predictions - another_predictions, Array1::<f64>::zeros(150));
}

#[test]
fn test_fit_with_constant_features_predicts_mean() {
    // With every feature constant there is no valid split, so the tree must
    // collapse to a single leaf predicting the overall target mean (0.5 here).
    let features = Array2::<f64>::zeros((10, 3));
    let targets = Array1::from_shape_fn(10, |i| if i < 5 { 0. } else { 1. });

    let mut model = DecisionTree::default();
    model.fit(&features.view(), &targets.view());

    let expected = Array1::<f64>::from_elem(10, 0.5);
    assert_eq!(model.predict(&features.view()), expected);
}

#[test]
fn test_min_samples_leaf_prevents_split() {
    // Ten distinct feature values with targets split 5/5. A min_samples_leaf
    // of 6 makes every candidate split illegal (one child would receive fewer
    // than 6 samples), so the tree must stay a single leaf predicting the mean.
    let features = Array2::from_shape_fn((10, 1), |(row, _)| row as f64);
    let targets = Array1::from_shape_fn(10, |i| if i < 5 { 0. } else { 1. });

    let parameters = DecisionTreeParameters::new(None, MaxFeatures::None, 2, 6, 0);
    let mut model = DecisionTree::new(parameters);
    model.fit(&features.view(), &targets.view());

    assert_eq!(
        model.predict(&features.view()),
        Array1::<f64>::from_elem(10, 0.5)
    );
}

#[test]
fn test_min_samples_leaf_allows_valid_split() {
    // Four samples, perfectly separable at the midpoint: a split that leaves
    // exactly min_samples_leaf = 2 samples in each child is permitted, so the
    // tree should reproduce the targets exactly.
    let features = Array2::from_shape_fn((4, 1), |(row, _)| row as f64);
    let targets = Array1::from_vec(vec![0., 0., 1., 1.]);

    let parameters = DecisionTreeParameters::new(None, MaxFeatures::None, 2, 2, 0);
    let mut model = DecisionTree::new(parameters);
    model.fit(&features.view(), &targets.view());

    assert_eq!(model.predict(&features.view()), targets);
}

#[test]
fn test_constant_features_do_not_leak_across_sibling_nodes() {
    // Feature 1 is constant within the left half (rows 0..5) but varies within
    // the right half (rows 5..10). If "constant feature" bookkeeping leaked
    // from one sibling node to the other, the right child could not split on
    // feature 1 and the fit would be imperfect.
    let features = Array2::from_shape_fn((10, 2), |(row, col)| {
        let threshold = if col == 0 { 5 } else { 7 };
        if row >= threshold { 1. } else { 0. }
    });
    let targets = Array1::from_vec(vec![0., 0., 0., 0., 0., 9., 9., 11., 11., 11.]);

    let mut model = DecisionTree::default();
    model.fit(&features.view(), &targets.view());

    assert_eq!(model.predict(&features.view()), targets);
}

#[test]
fn test_predicts_three_classes_on_separable_data() {
    // 30 samples on a line, partitioned into three contiguous blocks of ten
    // (classes 0, 1, 2); a default tree should separate them exactly.
    let n = 30;
    let features = Array2::from_shape_fn((n, 1), |(row, _)| row as f64);
    let targets = Array1::from_shape_fn(n, |i| (i / 10) as f64);

    let mut model = DecisionTree::default();
    model.fit(&features.view(), &targets.view());

    assert_eq!(model.predict(&features.view()), targets);
}

#[test]
fn test_refit_overwrites_previous_tree() {
    // First fit: a perfectly separable step function the tree can memorize.
    let features = Array2::from_shape_fn((10, 1), |(row, _)| row as f64);
    let targets = Array1::from_shape_fn(10, |i| if i < 5 { 0. } else { 10. });

    let mut model = DecisionTree::default();
    model.fit(&features.view(), &targets.view());
    assert_eq!(model.predict(&features.view()), targets);

    // Second fit on constant features must fully replace the first tree: with
    // no valid split, the only possible prediction is the new targets' mean.
    let constant_features = Array2::<f64>::zeros((10, 1));
    let new_targets = Array1::from_shape_fn(10, |i| (i + 1) as f64);
    let mean = new_targets.sum() / new_targets.len() as f64;

    model.fit(&constant_features.view(), &new_targets.view());

    assert_eq!(
        model.predict(&constant_features.view()),
        Array1::<f64>::from_elem(10, mean)
    );
}

#[test]
fn test_large_tree_multiclass_grid_predictions_are_stable() {
    // A 32x32 training grid partitioned into an 8x8 layout of 4x4 blocks;
    // each block is its own class (64 classes, 1024 samples). Sample k at
    // grid cell (i, j) sits at the cell center (i + 0.5, j + 0.5).
    const BLOCKS_PER_DIM: usize = 8;
    const BLOCK_EDGE: usize = 4;
    let grid_edge = BLOCKS_PER_DIM * BLOCK_EDGE;
    let n = grid_edge * grid_edge;

    let mut features = Array2::<f64>::zeros((n, 2));
    let mut targets = Array1::<f64>::zeros(n);
    for (idx, (i, j)) in (0..grid_edge)
        .flat_map(|i| (0..grid_edge).map(move |j| (i, j)))
        .enumerate()
    {
        features[[idx, 0]] = i as f64 + 0.5;
        features[[idx, 1]] = j as f64 + 0.5;
        targets[idx] = (i / BLOCK_EDGE + BLOCKS_PER_DIM * (j / BLOCK_EDGE)) as f64;
    }

    let mut model = DecisionTree::default();
    model.fit(&features.view(), &targets.view());

    // The training grid itself must be reproduced exactly.
    assert_eq!(model.predict(&features.view()), targets);

    // Probe four off-grid points near the corners of every block; each probe
    // must still land in its block's class.
    let probes_per_block = 4usize;
    let n_probes = BLOCKS_PER_DIM * BLOCKS_PER_DIM * probes_per_block;
    let mut probe_features = Array2::<f64>::zeros((n_probes, 2));
    let mut probe_targets = Array1::<f64>::zeros(n_probes);

    let mut idx = 0;
    for block_y in 0..BLOCKS_PER_DIM {
        for block_x in 0..BLOCKS_PER_DIM {
            let origin_x = (block_x * BLOCK_EDGE) as f64;
            let origin_y = (block_y * BLOCK_EDGE) as f64;
            let class = (block_x + BLOCKS_PER_DIM * block_y) as f64;

            for &(dx, dy) in &[(0.1, 0.1), (0.1, 3.9), (3.9, 0.1), (3.9, 3.9)] {
                probe_features[[idx, 0]] = origin_x + dx;
                probe_features[[idx, 1]] = origin_y + dy;
                probe_targets[idx] = class;
                idx += 1;
            }
        }
    }

    let predicted = model.predict(&probe_features.view()).mapv(f64::round);
    assert_eq!(predicted, probe_targets);

    // Every one of the 64 classes must actually appear among the predictions.
    let distinct: HashSet<i32> = predicted.iter().map(|&p| p as i32).collect();
    assert_eq!(distinct.len(), BLOCKS_PER_DIM * BLOCKS_PER_DIM);
}
}
26 changes: 23 additions & 3 deletions src/tree/decision_tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ pub struct DecisionTreeNode {

impl DecisionTreeNode {
fn leaf_node(&mut self, label: f64) {
    // Turn this node into a leaf: record the prediction label and drop any
    // split state left over from a previous fit, so that a reused
    // `DecisionTreeNode` (e.g. when a tree is refitted) cannot retain stale
    // children or split metadata.
    self.label = Some(label);
    self.feature_index = None;
    self.feature_value = None;
    self.left_child = None;
    self.right_child = None;
}

Expand Down Expand Up @@ -75,8 +81,14 @@ impl DecisionTreeNode {
continue;
}

let (split, split_val, gain, left_sum) =
self.find_best_split(X, y, feature, samples[feature], sum);
let (split, split_val, gain, left_sum) = self.find_best_split(
X,
y,
feature,
samples[feature],
sum,
parameters.min_samples_leaf,
);

if gain > best_gain {
best_gain = gain;
Expand Down Expand Up @@ -141,8 +153,15 @@ impl DecisionTreeNode {
feature: usize,
samples: &[usize],
sum: f64,
min_samples_leaf: usize,
) -> (usize, f64, f64, f64) {
let n = samples.len();
// If we can't create two children with at least `min_samples_leaf` samples,
// no split is valid.
if n < 2 * min_samples_leaf {
return (0, 0., 0., 0.);
}

let mut cumsum = 0.;
let mut max_proxy_gain = 0.;
let mut proxy_gain: f64;
Expand Down Expand Up @@ -370,6 +389,7 @@ mod tests {
feature,
&samples,
y.slice(s![start..stop]).sum(),
1,
);

assert_eq!((expected_split, expected_split_val), (split, split_val));
Expand All @@ -388,7 +408,7 @@ mod tests {
samples.sort_unstable_by(|a, b| X[[*a, 0]].partial_cmp(&X[[*b, 0]]).unwrap());

let (split, split_val, gain, sum) =
node.find_best_split(&X.view(), &y.view(), 0, &samples, 0.);
node.find_best_split(&X.view(), &y.view(), 0, &samples, 0., 1);
assert_eq!((split, split_val, gain, sum), (0, 0., 0., 0.));
}

Expand Down
Loading