Skip to content

Commit c83f9e3

Browse files
committed
feature: support serialization and deserialization
1 parent 23079f4 commit c83f9e3

File tree

7 files changed

+221
-3
lines changed

7 files changed

+221
-3
lines changed

Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ name = "biosphere"
1717
ndarray = "0.17.2"
1818
rand = "0.9"
1919
rayon = "1.11"
20+
serde = { version = "1", features = ["derive"], optional = true }
21+
22+
[features]
23+
serde = ["dep:serde"]
2024

2125
[profile.bench]
2226
incremental = true
@@ -29,6 +33,7 @@ csv = "^1"
2933
ndarray-csv = "^0.5"
3034
criterion = "0.8"
3135
assert_approx_eq = "1.1"
36+
postcard = { version = "1", features = ["use-std"] }
3237

3338
[[bench]]
3439
name = "bench_utils"
@@ -45,3 +50,8 @@ harness = false
4550
[[bench]]
4651
name = "bench_ops"
4752
harness = false
53+
54+
[[bench]]
55+
name = "bench_tree_serde"
56+
harness = false
57+
required-features = ["serde"]

README.md

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,35 @@ Random forests with a runtime of `O(n d log(n) + n_estimators d n max_depth)` in
66

77
`biosphere` is available as a rust crate and as a Python package.
88

9+
## Serialize / deserialize a `DecisionTree`
10+
11+
Enable the `serde` feature and choose a serde format (here: `postcard`):
12+
13+
```toml
14+
# Cargo.toml
15+
biosphere = { version = "0.4.2", features = ["serde"] }
16+
postcard = { version = "1", features = ["use-std"] }
17+
```
18+
19+
```rust
20+
use biosphere::DecisionTree;
21+
22+
let X = ndarray::array![[0.0], [1.0], [2.0], [3.0]];
23+
let y = ndarray::array![0.0, 0.0, 1.0, 1.0];
24+
25+
let mut tree = DecisionTree::default();
26+
tree.fit(&X.view(), &y.view());
27+
28+
// serialize and deserialize the tree
29+
let bytes = postcard::to_stdvec(&tree).unwrap();
30+
// deserialize the tree from bytes
31+
let restored: DecisionTree = postcard::from_bytes(&bytes).unwrap();
32+
33+
assert_eq!(tree.predict(&X.view()), restored.predict(&X.view()));
34+
```
35+
36+
In this repo you can run: `cargo run --example decision_tree_serde --features serde`.
37+
938
## Benchmarks
1039

1140
Ran on an M1 Pro with `n_jobs=4`. Wall-time to fit a Random Forest including OOB score with 400 trees to
@@ -16,4 +45,4 @@ features.
1645
| model | 1000 | 2000 | 4000 | 8000 | 16000 | 32000 | 64000 | 128000 | 256000 | 512000 | 1024000 | 2048000 |
1746
|:-------------|:-------|:-------|:-------|:-------|:--------|:--------|:--------|:---------|:---------|:---------|:----------|:----------|
1847
| biosphere | 0.04s | 0.08s | 0.15s | 0.32s | 0.65s | 1.40s | 2.97s | 6.48s | 15.53s | 37.91s | 96.69s | 231.82s |
19-
| scikit-learn | 0.28s | 0.34s | 0.46s | 0.69s | 1.23s | 2.47s | 4.99s | 10.49s | 22.11s | 51.04s | 118.95s | 271.03s |
48+
| scikit-learn | 0.28s | 0.34s | 0.46s | 0.69s | 1.23s | 2.47s | 4.99s | 10.49s | 22.11s | 51.04s | 118.95s | 271.03s |

benches/bench_tree_serde.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use biosphere::{DecisionTree, DecisionTreeParameters, MaxFeatures};
2+
3+
#[cfg(test)]
4+
use criterion::{criterion_group, criterion_main, Criterion};
5+
use ndarray::{Array1, Array2};
6+
use rand::rngs::StdRng;
7+
use rand::Rng;
8+
use rand::SeedableRng;
9+
use std::hint::black_box;
10+
11+
#[allow(non_snake_case)]
12+
pub fn data(n: usize, d: usize, rng: &mut impl Rng) -> (Array2<f64>, Array1<f64>) {
13+
let X = Array2::from_shape_fn((n, d), |_| rng.random::<f64>());
14+
15+
let y = Array1::from_shape_fn(n, |_| rng.random::<f64>());
16+
let y = y + X.column(0) + X.column(1).map(|x| x - x * x);
17+
18+
(X, y)
19+
}
20+
21+
#[allow(non_snake_case)]
22+
pub fn benchmark_tree_fit_vs_deserialize(c: &mut Criterion) {
23+
let seed = 0;
24+
let mut rng = StdRng::seed_from_u64(seed);
25+
26+
let n = 10_000usize;
27+
let d = 10usize;
28+
let max_depth = 8usize;
29+
30+
let (X, y) = data(n, d, &mut rng);
31+
let X_view = X.view();
32+
let y_view = y.view();
33+
34+
let parameters = DecisionTreeParameters::default()
35+
.with_max_depth(Some(max_depth))
36+
.with_max_features(MaxFeatures::Value(d))
37+
.with_random_state(seed);
38+
39+
let mut fitted_tree = DecisionTree::new(parameters.clone());
40+
fitted_tree.fit(&X_view, &y_view);
41+
let fitted_tree_bytes = postcard::to_stdvec(&fitted_tree).unwrap();
42+
43+
let mut group = c.benchmark_group("tree_fit_vs_deserialize");
44+
group.bench_function(format!("fit_n={n}, d={d}, max_depth={max_depth}"), |b| {
45+
b.iter(|| {
46+
let mut tree = DecisionTree::new(parameters.clone());
47+
tree.fit(&X_view, &y_view);
48+
black_box(tree);
49+
})
50+
});
51+
52+
group.bench_function(
53+
format!("deserialize_n={n}, d={d}, max_depth={max_depth}"),
54+
|b| {
55+
b.iter(|| {
56+
let tree: DecisionTree =
57+
postcard::from_bytes(black_box(fitted_tree_bytes.as_slice())).unwrap();
58+
black_box(tree);
59+
})
60+
},
61+
);
62+
63+
group.finish();
64+
}
65+
66+
criterion_group!(
67+
name = tree_serde;
68+
config = Criterion::default().sample_size(10);
69+
targets = benchmark_tree_fit_vs_deserialize
70+
);
71+
72+
criterion_main!(tree_serde);

examples/decision_tree_serde.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#![allow(non_snake_case)]
2+
3+
#[cfg(not(feature = "serde"))]
4+
fn main() {
5+
eprintln!(
6+
"This example requires the `serde` feature.\n\nRun:\n cargo run --example decision_tree_serde --features serde"
7+
);
8+
}
9+
10+
#[cfg(feature = "serde")]
11+
fn main() -> Result<(), Box<dyn std::error::Error>> {
12+
use biosphere::DecisionTree;
13+
14+
let X = ndarray::array![[0.0], [1.0], [2.0], [3.0]];
15+
let y = ndarray::array![0.0, 0.0, 1.0, 1.0];
16+
17+
let mut tree = DecisionTree::default();
18+
tree.fit(&X.view(), &y.view());
19+
20+
let bytes = postcard::to_stdvec(&tree)?;
21+
let restored: DecisionTree = postcard::from_bytes(&bytes)?;
22+
23+
assert_eq!(tree.predict(&X.view()), restored.predict(&X.view()));
24+
println!("Serialized {} bytes", bytes.len());
25+
Ok(())
26+
}

src/tree/decision_tree.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use ndarray::{Array1, ArrayView1, ArrayView2};
55
use rand::rngs::StdRng;
66
use rand::SeedableRng;
77

8+
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
89
pub struct DecisionTree {
910
decision_tree_parameters: DecisionTreeParameters,
1011
node: DecisionTreeNode,
@@ -95,12 +96,12 @@ impl DecisionTree {
9596

9697
#[cfg(test)]
9798
mod tests {
98-
use std::collections::HashSet;
9999
use super::*;
100100
use crate::testing::load_iris;
101+
use crate::MaxFeatures;
101102
use ndarray::{s, Array2};
102103
use rstest::*;
103-
use crate::MaxFeatures;
104+
use std::collections::HashSet;
104105

105106
#[rstest]
106107
fn test_tree() {
@@ -288,4 +289,25 @@ mod tests {
288289
let predicted_classes: HashSet<i32> = y_pred.iter().map(|x| *x as i32).collect();
289290
assert_eq!(predicted_classes.len(), blocks_per_dim * blocks_per_dim);
290291
}
292+
293+
#[cfg(feature = "serde")]
294+
#[test]
295+
fn test_serialized_deserialized_tree_predicts_same_as_fit_tree() {
296+
let data = load_iris();
297+
let X = data.slice(s![.., 0..4]);
298+
let y = data.slice(s![.., 4]);
299+
300+
let parameters = DecisionTreeParameters::default()
301+
.with_max_depth(Some(4))
302+
.with_max_features(MaxFeatures::Value(2))
303+
.with_random_state(123);
304+
let mut tree = DecisionTree::new(parameters);
305+
tree.fit(&X, &y);
306+
let predictions = tree.predict(&X);
307+
308+
let bytes = postcard::to_stdvec(&tree).unwrap();
309+
let restored_tree: DecisionTree = postcard::from_bytes(bytes.as_slice()).unwrap();
310+
311+
assert_eq!(predictions, restored_tree.predict(&X));
312+
}
291313
}

src/tree/decision_tree_node.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ static MIN_GAIN_TO_SPLIT: f64 = 1e-12;
88
static FEATURE_THRESHOLD: f64 = 1e-14;
99

1010
#[derive(Default)]
11+
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
1112
pub struct DecisionTreeNode {
1213
pub left_child: Option<Box<DecisionTreeNode>>,
1314
pub right_child: Option<Box<DecisionTreeNode>>,

src/tree/decision_tree_parameters.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,63 @@ pub enum MaxFeatures {
1313
Callable(fn(usize) -> usize),
1414
}
1515

16+
#[cfg(feature = "serde")]
17+
impl serde::Serialize for MaxFeatures {
18+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
19+
where
20+
S: serde::Serializer,
21+
{
22+
#[derive(serde::Serialize)]
23+
enum MaxFeaturesRepr {
24+
None,
25+
Fraction(f64),
26+
Value(usize),
27+
Sqrt,
28+
}
29+
30+
let repr = match self {
31+
MaxFeatures::None => MaxFeaturesRepr::None,
32+
MaxFeatures::Fraction(fraction) => MaxFeaturesRepr::Fraction(*fraction),
33+
MaxFeatures::Value(value) => MaxFeaturesRepr::Value(*value),
34+
MaxFeatures::Sqrt => MaxFeaturesRepr::Sqrt,
35+
MaxFeatures::Callable(_) => {
36+
return Err(serde::ser::Error::custom(
37+
"MaxFeatures::Callable cannot be serialized",
38+
));
39+
}
40+
};
41+
42+
serde::Serialize::serialize(&repr, serializer)
43+
}
44+
}
45+
46+
#[cfg(feature = "serde")]
47+
impl<'de> serde::Deserialize<'de> for MaxFeatures {
48+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
49+
where
50+
D: serde::Deserializer<'de>,
51+
{
52+
#[derive(serde::Deserialize)]
53+
enum MaxFeaturesRepr {
54+
None,
55+
Fraction(f64),
56+
Value(usize),
57+
Sqrt,
58+
Callable,
59+
}
60+
61+
match <MaxFeaturesRepr as serde::Deserialize>::deserialize(deserializer)? {
62+
MaxFeaturesRepr::None => Ok(MaxFeatures::None),
63+
MaxFeaturesRepr::Fraction(fraction) => Ok(MaxFeatures::Fraction(fraction)),
64+
MaxFeaturesRepr::Value(value) => Ok(MaxFeatures::Value(value)),
65+
MaxFeaturesRepr::Sqrt => Ok(MaxFeatures::Sqrt),
66+
MaxFeaturesRepr::Callable => Err(serde::de::Error::custom(
67+
"MaxFeatures::Callable cannot be deserialized",
68+
)),
69+
}
70+
}
71+
}
72+
1673
impl MaxFeatures {
1774
pub fn from_n_features(&self, n_features: usize) -> usize {
1875
let value = match self {
@@ -28,6 +85,7 @@ impl MaxFeatures {
2885
}
2986

3087
#[derive(Clone, Debug)]
88+
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
3189
pub struct DecisionTreeParameters {
3290
// Maximum depth of the tree. If `None`, nodes are expanded until all leaves are
3391
// pure or contain fewer than `min_samples_split` samples.

0 commit comments

Comments
 (0)