1
1
# -*- coding: utf-8 -*-
2
2
"""PyMilo Tree(from sklearn.tree._tree) object transporter."""
3
+ import numpy as np
3
4
from sklearn .tree ._tree import Tree
4
-
5
+ from .. pymilo_param import NUMPY_TYPE_DICT
5
6
from .transporter import AbstractTransporter
7
+ from ..utils .util import check_str_in_iterable
6
8
from .general_data_structure_transporter import GeneralDataStructureTransporter
7
- from ..pymilo_param import NUMPY_TYPE_DICT
8
-
9
- import numpy as np
10
9
11
10
12
11
class TreeTransporter (AbstractTransporter ):
@@ -28,10 +27,9 @@ def serialize(self, data, key, model_type):
28
27
gdst = GeneralDataStructureTransporter ()
29
28
tree = data [key ]
30
29
tree_inner_state = tree .__getstate__ ()
31
-
32
30
data [key ] = {
33
31
'pymilo-bypass' : True ,
34
- 'params ' : {
32
+ 'pymilo-tree ' : {
35
33
'internal_state' : {
36
34
"max_depth" : tree_inner_state ["max_depth" ],
37
35
"node_count" : tree_inner_state ["node_count" ],
@@ -47,7 +45,6 @@ def serialize(self, data, key, model_type):
47
45
'n_outputs' : tree .n_outputs ,
48
46
}
49
47
}
50
-
51
48
return data [key ]
52
49
53
50
def deserialize (self , data , key , model_type ):
@@ -69,19 +66,12 @@ def deserialize(self, data, key, model_type):
69
66
:return: pymilo deserialized output of data[key]
70
67
"""
71
68
content = data [key ]
72
-
73
- if (key == "tree_" and
74
- (model_type == "DecisionTreeRegressor"
75
- or model_type == "DecisionTreeClassifier"
76
- or model_type == "ExtraTreeRegressor"
77
- or model_type == "ExtraTreeClassifier"
78
- )):
69
+ if check_str_in_iterable ('pymilo-tree' , content ):
79
70
gdst = GeneralDataStructureTransporter ()
80
- tree_params = content ['params' ]
81
-
71
+ tree_params = content ['pymilo-tree' ]
82
72
tree_internal_state = tree_params ["internal_state" ]
83
-
84
73
nodes_dtype_spec = []
74
+
85
75
for idx , node_type in enumerate (tree_internal_state ["nodes" ]["types" ]):
86
76
nodes_dtype_spec .append (
87
77
(tree_internal_state ["nodes" ]["field-names" ][idx ], NUMPY_TYPE_DICT ["numpy." + node_type ]))
@@ -106,10 +96,7 @@ def deserialize(self, data, key, model_type):
106
96
n_classes ,
107
97
tree_params ["n_outputs" ]
108
98
)
109
-
110
99
_tree .__setstate__ (tree_internal_state )
111
-
112
100
return _tree
113
-
114
101
else :
115
102
return content
0 commit comments