@@ -13,6 +13,63 @@ pub enum MaxFeatures {
1313 Callable ( fn ( usize ) -> usize ) ,
1414}
1515
16+ #[ cfg( feature = "serde" ) ]
17+ impl serde:: Serialize for MaxFeatures {
18+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
19+ where
20+ S : serde:: Serializer ,
21+ {
22+ #[ derive( serde:: Serialize ) ]
23+ enum MaxFeaturesRepr {
24+ None ,
25+ Fraction ( f64 ) ,
26+ Value ( usize ) ,
27+ Sqrt ,
28+ }
29+
30+ let repr = match self {
31+ MaxFeatures :: None => MaxFeaturesRepr :: None ,
32+ MaxFeatures :: Fraction ( fraction) => MaxFeaturesRepr :: Fraction ( * fraction) ,
33+ MaxFeatures :: Value ( value) => MaxFeaturesRepr :: Value ( * value) ,
34+ MaxFeatures :: Sqrt => MaxFeaturesRepr :: Sqrt ,
35+ MaxFeatures :: Callable ( _) => {
36+ return Err ( serde:: ser:: Error :: custom (
37+ "MaxFeatures::Callable cannot be serialized" ,
38+ ) ) ;
39+ }
40+ } ;
41+
42+ serde:: Serialize :: serialize ( & repr, serializer)
43+ }
44+ }
45+
46+ #[ cfg( feature = "serde" ) ]
47+ impl < ' de > serde:: Deserialize < ' de > for MaxFeatures {
48+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
49+ where
50+ D : serde:: Deserializer < ' de > ,
51+ {
52+ #[ derive( serde:: Deserialize ) ]
53+ enum MaxFeaturesRepr {
54+ None ,
55+ Fraction ( f64 ) ,
56+ Value ( usize ) ,
57+ Sqrt ,
58+ Callable ,
59+ }
60+
61+ match <MaxFeaturesRepr as serde:: Deserialize >:: deserialize ( deserializer) ? {
62+ MaxFeaturesRepr :: None => Ok ( MaxFeatures :: None ) ,
63+ MaxFeaturesRepr :: Fraction ( fraction) => Ok ( MaxFeatures :: Fraction ( fraction) ) ,
64+ MaxFeaturesRepr :: Value ( value) => Ok ( MaxFeatures :: Value ( value) ) ,
65+ MaxFeaturesRepr :: Sqrt => Ok ( MaxFeatures :: Sqrt ) ,
66+ MaxFeaturesRepr :: Callable => Err ( serde:: de:: Error :: custom (
67+ "MaxFeatures::Callable cannot be deserialized" ,
68+ ) ) ,
69+ }
70+ }
71+ }
72+
1673impl MaxFeatures {
1774 pub fn from_n_features ( & self , n_features : usize ) -> usize {
1875 let value = match self {
@@ -28,6 +85,7 @@ impl MaxFeatures {
2885}
2986
3087#[ derive( Clone , Debug ) ]
88+ #[ cfg_attr( feature = "serde" , derive( serde:: Deserialize , serde:: Serialize ) ) ]
3189pub struct DecisionTreeParameters {
3290 // Maximum depth of the tree. If `None`, nodes are expanded until all leaves are
3391 // pure or contain fewer than `min_samples_split` samples.
0 commit comments