Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ name = "biosphere"
ndarray = "0.17.1"
rand = "0.9"
rayon = "1.11"
serde = { version = "1", features = ["derive"], optional = true }

[features]
serde = ["dep:serde"]

[profile.bench]
incremental = true
Expand All @@ -30,6 +34,7 @@ csv = "^1"
ndarray-csv = "^0.5"
criterion = "0.8"
assert_approx_eq = "1.1"
postcard = { version = "1", features = ["use-std"] }

[[bench]]
name = "bench_utils"
Expand All @@ -42,3 +47,8 @@ harness = false
[[bench]]
name = "bench_tree"
harness = false

[[bench]]
name = "bench_tree_serde"
harness = false
required-features = ["serde"]
31 changes: 30 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,35 @@ Random forests with a runtime of `O(n d log(n) + n_estimators d n max_depth)` in

`biosphere` is available as a rust crate and as a Python package.

## Serialize / deserialize a `DecisionTree`

Enable the `serde` feature and choose a serde format (here: `postcard`):

```toml
# Cargo.toml
biosphere = { version = "0.4.2", features = ["serde"] }
postcard = { version = "1", features = ["use-std"] }
```

```rust
use biosphere::DecisionTree;

let X = ndarray::array![[0.0], [1.0], [2.0], [3.0]];
let y = ndarray::array![0.0, 0.0, 1.0, 1.0];

let mut tree = DecisionTree::default();
tree.fit(&X.view(), &y.view());

// serialize the tree to bytes
let bytes = postcard::to_stdvec(&tree).unwrap();
// deserialize the tree from bytes
let restored: DecisionTree = postcard::from_bytes(&bytes).unwrap();

assert_eq!(tree.predict(&X.view()), restored.predict(&X.view()));
```

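In this repo you can run: `cargo run --example decision_tree_serde --features serde`.

Note: `MaxFeatures::Callable` holds a function pointer and cannot be serialized; attempting to serialize it returns an error.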

## Benchmarks

Ran on an M1 Pro with `n_jobs=4`. Wall-time to fit a Random Forest including OOB score with 400 trees to
Expand All @@ -16,4 +45,4 @@ features.
| model | 1000 | 2000 | 4000 | 8000 | 16000 | 32000 | 64000 | 128000 | 256000 | 512000 | 1024000 | 2048000 |
|:-------------|:-------|:-------|:-------|:-------|:--------|:--------|:--------|:---------|:---------|:---------|:----------|:----------|
| biosphere | 0.04s | 0.08s | 0.15s | 0.32s | 0.65s | 1.40s | 2.97s | 6.48s | 15.53s | 37.91s | 96.69s | 231.82s |
| scikit-learn | 0.28s | 0.34s | 0.46s | 0.69s | 1.23s | 2.47s | 4.99s | 10.49s | 22.11s | 51.04s | 118.95s | 271.03s |
| scikit-learn | 0.28s | 0.34s | 0.46s | 0.69s | 1.23s | 2.47s | 4.99s | 10.49s | 22.11s | 51.04s | 118.95s | 271.03s |
72 changes: 72 additions & 0 deletions benches/bench_tree_serde.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use biosphere::{DecisionTree, DecisionTreeParameters, MaxFeatures};

#[cfg(test)]
use criterion::{criterion_group, criterion_main, Criterion};
use ndarray::{Array1, Array2};
use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;
use std::hint::black_box;

/// Generates a synthetic regression data set with `n` rows and `d` columns.
///
/// Every feature value and every noise value is drawn with `rng.random::<f64>()`;
/// the feature matrix is filled first, then the noise vector. The target is
/// `noise + column 0 + (x - x * x)` applied element-wise to column 1, so `d`
/// must be at least 2 for the column accesses to be in bounds.
///
/// Returns the `(n, d)` feature matrix together with the length-`n` target.
pub fn data(n: usize, d: usize, rng: &mut impl Rng) -> (Array2<f64>, Array1<f64>) {
    let features = Array2::from_shape_fn((n, d), |_| rng.random::<f64>());
    let noise = Array1::from_shape_fn(n, |_| rng.random::<f64>());

    let linear_term = features.column(0);
    let curved_term = features.column(1).map(|x| x - x * x);
    let targets = noise + linear_term + curved_term;

    (features, targets)
}

/// Compares the cost of fitting a `DecisionTree` with the cost of restoring an
/// already fitted tree from its `postcard` encoding.
///
/// Both measurements are registered in the `tree_fit_vs_deserialize` group and
/// share the same synthetic data set (`n` rows, `d` columns) and parameters.
pub fn benchmark_tree_fit_vs_deserialize(c: &mut Criterion) {
    let seed = 0;
    let n = 10_000usize;
    let d = 10usize;
    let max_depth = 8usize;

    let mut rng = StdRng::seed_from_u64(seed);
    let (features, targets) = data(n, d, &mut rng);
    let (features_view, targets_view) = (features.view(), targets.view());

    let parameters = DecisionTreeParameters::default()
        .with_max_depth(Some(max_depth))
        .with_max_features(MaxFeatures::Value(d))
        .with_random_state(seed);

    // Fit once up front so the deserialize benchmark has real bytes to decode.
    let mut reference_tree = DecisionTree::new(parameters.clone());
    reference_tree.fit(&features_view, &targets_view);
    let serialized_tree = postcard::to_stdvec(&reference_tree).unwrap();

    let fit_id = format!("fit_n={n}, d={d}, max_depth={max_depth}");
    let deserialize_id = format!("deserialize_n={n}, d={d}, max_depth={max_depth}");

    let mut group = c.benchmark_group("tree_fit_vs_deserialize");

    // Baseline: build and fit a fresh tree on every iteration.
    group.bench_function(fit_id, |bencher| {
        bencher.iter(|| {
            let mut tree = DecisionTree::new(parameters.clone());
            tree.fit(&features_view, &targets_view);
            black_box(tree);
        })
    });

    // Candidate: decode the pre-serialized tree on every iteration.
    group.bench_function(deserialize_id, |bencher| {
        bencher.iter(|| {
            let tree: DecisionTree =
                postcard::from_bytes(black_box(serialized_tree.as_slice())).unwrap();
            black_box(tree);
        })
    });

    group.finish();
}

// Registers the benchmark function in a criterion group named `tree_serde`.
// The group is configured with a sample size of 10 rather than the
// `Criterion::default()` setting.
criterion_group!(
name = tree_serde;
config = Criterion::default().sample_size(10);
targets = benchmark_tree_fit_vs_deserialize
);

// Generates the entry point that runs the `tree_serde` group; the bench target
// is declared with `harness = false`, so criterion supplies `main`.
criterion_main!(tree_serde);
26 changes: 26 additions & 0 deletions examples/decision_tree_serde.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#![allow(non_snake_case)]

/// Fallback entry point compiled when the `serde` feature is disabled.
///
/// It does no work beyond printing, to stderr, the command that runs the
/// example with the `serde` feature enabled.
#[cfg(not(feature = "serde"))]
fn main() {
eprintln!(
"This example requires the `serde` feature.\n\nRun:\n cargo run --example decision_tree_serde --features serde"
);
}

/// Fits a small tree, round-trips it through `postcard`, and asserts that the
/// restored tree predicts exactly what the original tree predicts.
///
/// Any `postcard` error is returned to the caller as a boxed error.
#[cfg(feature = "serde")]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    use biosphere::DecisionTree;

    let features = ndarray::array![[0.0], [1.0], [2.0], [3.0]];
    let targets = ndarray::array![0.0, 0.0, 1.0, 1.0];

    let mut tree = DecisionTree::default();
    tree.fit(&features.view(), &targets.view());

    // Encode the fitted tree, then decode it into a second, independent tree.
    let encoded = postcard::to_stdvec(&tree)?;
    let restored: DecisionTree = postcard::from_bytes(&encoded)?;

    let expected = tree.predict(&features.view());
    let actual = restored.predict(&features.view());
    assert_eq!(expected, actual);

    println!("Serialized {} bytes", encoded.len());
    Ok(())
}
22 changes: 22 additions & 0 deletions src/tree/decision_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use ndarray::{Array1, ArrayView1, ArrayView2};
use rand::rngs::StdRng;
use rand::SeedableRng;

#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct DecisionTree {
decision_tree_parameters: DecisionTreeParameters,
node: DecisionTreeNode,
Expand Down Expand Up @@ -122,3 +123,24 @@ mod tests {
assert_eq!(predictions - another_predictions, Array1::<f64>::zeros(150));
}
}

// Round-trip test: a fitted tree that is serialized with `postcard` and then
// deserialized must produce the same predictions as the original tree.
// Only compiled when the `serde` feature is enabled.
//
// NOTE(review): in this diff the test is added after the two closing braces
// that appear to end `mod tests`. If it really sits outside that module,
// `load_iris`, `s!` and `MaxFeatures` must be in scope at this level — confirm
// that `cargo test --features serde` compiles.
#[cfg(feature = "serde")]
#[test]
fn test_serialized_deserialized_tree_predicts_same_as_fit_tree() {
let data = load_iris();
// Columns 0..4 are used as features and column 4 as the target.
let X = data.slice(s![.., 0..4]);
let y = data.slice(s![.., 4]);

let parameters = DecisionTreeParameters::default()
.with_max_depth(Some(4))
.with_max_features(MaxFeatures::Value(2))
.with_random_state(123);
let mut tree = DecisionTree::new(parameters);
tree.fit(&X, &y);
// Reference predictions from the tree that was actually fitted.
let predictions = tree.predict(&X);

// Encode to bytes and decode into a separate tree instance.
let bytes = postcard::to_stdvec(&tree).unwrap();
let restored_tree: DecisionTree = postcard::from_bytes(bytes.as_slice()).unwrap();

assert_eq!(predictions, restored_tree.predict(&X));
}
1 change: 1 addition & 0 deletions src/tree/decision_tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ static MIN_GAIN_TO_SPLIT: f64 = 1e-12;
static FEATURE_THRESHOLD: f64 = 1e-14;

#[derive(Default)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct DecisionTreeNode {
pub left_child: Option<Box<DecisionTreeNode>>,
pub right_child: Option<Box<DecisionTreeNode>>,
Expand Down
58 changes: 58 additions & 0 deletions src/tree/decision_tree_parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,63 @@ pub enum MaxFeatures {
Callable(fn(usize) -> usize),
}

#[cfg(feature = "serde")]
impl serde::Serialize for MaxFeatures {
    /// Serializes every variant except `Callable` through a private mirror enum.
    ///
    /// `Callable` wraps a function pointer, which has no serialized form, so
    /// that variant yields a custom serializer error instead.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::Serialize as _;

        // Data-only mirror of `MaxFeatures`; the derive does the encoding.
        #[derive(serde::Serialize)]
        enum MaxFeaturesRepr {
            None,
            Fraction(f64),
            Value(usize),
            Sqrt,
        }

        match *self {
            MaxFeatures::None => MaxFeaturesRepr::None.serialize(serializer),
            MaxFeatures::Fraction(fraction) => {
                MaxFeaturesRepr::Fraction(fraction).serialize(serializer)
            }
            MaxFeatures::Value(value) => MaxFeaturesRepr::Value(value).serialize(serializer),
            MaxFeatures::Sqrt => MaxFeaturesRepr::Sqrt.serialize(serializer),
            MaxFeatures::Callable(_) => Err(serde::ser::Error::custom(
                "MaxFeatures::Callable cannot be serialized",
            )),
        }
    }
}

#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for MaxFeatures {
    /// Deserializes a `MaxFeatures` through a private mirror enum.
    ///
    /// The mirror enum also accepts a payload-free `Callable` variant so that
    /// such input is rejected with a dedicated error message.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::Deserialize as _;

        // Data-only mirror of `MaxFeatures`; the derive does the decoding.
        #[derive(serde::Deserialize)]
        enum MaxFeaturesRepr {
            None,
            Fraction(f64),
            Value(usize),
            Sqrt,
            Callable,
        }

        let max_features = match MaxFeaturesRepr::deserialize(deserializer)? {
            MaxFeaturesRepr::Callable => {
                return Err(serde::de::Error::custom(
                    "MaxFeatures::Callable cannot be deserialized",
                ));
            }
            MaxFeaturesRepr::None => MaxFeatures::None,
            MaxFeaturesRepr::Fraction(fraction) => MaxFeatures::Fraction(fraction),
            MaxFeaturesRepr::Value(value) => MaxFeatures::Value(value),
            MaxFeaturesRepr::Sqrt => MaxFeatures::Sqrt,
        };

        Ok(max_features)
    }
}

impl MaxFeatures {
pub fn from_n_features(&self, n_features: usize) -> usize {
let value = match self {
Expand All @@ -28,6 +85,7 @@ impl MaxFeatures {
}

#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct DecisionTreeParameters {
// Maximum depth of the tree. If `None`, nodes are expanded until all leaves are
// pure or contain fewer than `min_samples_split` samples.
Expand Down
Loading