Skip to content

Commit a3cffac

Browse files
authored
minor refactorings in naming conventions (#116)
1 parent 5966627 commit a3cffac

File tree

7 files changed

+47
-66
lines changed

7 files changed

+47
-66
lines changed

pymilo/chains/ensemble_chain.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -188,19 +188,16 @@ def serialize_ensemble(ensemble_object):
188188
}
189189

190190
elif isinstance(value, dict):
191-
if not check_str_in_iterable("pymilo-bypass", value):
192-
if check_str_in_iterable(
193-
"pymiloed-data-structure",
194-
value) and value["pymiloed-data-structure"] == "Bunch":
195-
new_value = {}
196-
for inner_key, inner_value in value["pymiloed-data"].items():
197-
new_value[inner_key] = serialize_possible_ml_model(inner_value)[1]
198-
value["pymiloed-data"] = new_value
199-
else:
200-
new_value = {}
201-
for inner_key, inner_value in value.items():
202-
new_value[inner_key] = serialize_possible_ml_model(inner_value)[1]
203-
ensemble_object.__dict__[key] = new_value
191+
if check_str_in_iterable("pymilo-bunch", value):
192+
new_value = {}
193+
for inner_key, inner_value in value["pymilo-bunch"].items():
194+
new_value[inner_key] = serialize_possible_ml_model(inner_value)[1]
195+
value["pymilo-bunch"] = new_value
196+
else:
197+
new_value = {}
198+
for inner_key, inner_value in value.items():
199+
new_value[inner_key] = serialize_possible_ml_model(inner_value)[1]
200+
ensemble_object.__dict__[key] = new_value
204201

205202
elif isinstance(value, ndarray):
206203
has_inner_model, result = serialize_models_in_ndarray(value)

pymilo/chains/linear_model_chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def serialize_linear_model(linear_model_object):
104104
linear_model_object.__dict__[key] = {
105105
"pymilo-inner-model-data": transport_linear_model(linear_model_object.__dict__[key], Command.SERIALIZE, True),
106106
"pymilo-inner-model-type": get_sklearn_type(linear_model_object.__dict__[key]),
107-
"pymilo-by-pass": True
107+
"pymilo-bypass": True
108108
}
109109
# now serializing non-linear model fields
110110
for transporter in LINEAR_MODEL_CHAIN:

pymilo/transporters/adamoptimizer_transporter.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# -*- coding: utf-8 -*-
22
"""PyMilo Adam optimizer object transporter."""
3-
from sklearn.neural_network._stochastic_optimizers import AdamOptimizer
43
from .transporter import AbstractTransporter
4+
from ..utils.util import check_str_in_iterable
5+
from sklearn.neural_network._stochastic_optimizers import AdamOptimizer
56

67

78
class AdamOptimizerTransporter(AbstractTransporter):
@@ -24,7 +25,8 @@ def serialize(self, data, key, model_type):
2425
if isinstance(data[key], AdamOptimizer):
2526
optimizer = data[key]
2627
data[key] = {
27-
'params': {
28+
"pymilo-bypass": True,
29+
'pymilo-adamoptimizer': {
2830
"params": data["coefs_"] + data["intercepts_"],
2931
'type': "AdamOptimizer",
3032
'beta_1': optimizer.beta_1,
@@ -55,10 +57,8 @@ def deserialize(self, data, key, model_type):
5557
:return: pymilo deserialized output of data[key]
5658
"""
5759
content = data[key]
58-
59-
if (key == "_optimizer" and (model_type ==
60-
"MLPRegressor" or model_type == "MLPClassifier")):
61-
optimizer = content['params']
60+
if check_str_in_iterable("pymilo-adamoptimizer", content):
61+
optimizer = content["pymilo-adamoptimizer"]
6262
if (optimizer["type"] == "AdamOptimizer"):
6363
return AdamOptimizer(
6464
params=optimizer["params"],

pymilo/transporters/bunch_transporter.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
"""PyMilo Bunch transporter."""
3-
4-
from ..utils.util import check_str_in_iterable
53
from .transporter import AbstractTransporter
4+
from ..utils.util import check_str_in_iterable
5+
66
bunch_support = False
77
try:
88
from sklearn.utils._bunch import Bunch
@@ -28,13 +28,13 @@ def serialize(self, data, key, model_type):
2828
"""
2929
if bunch_support and isinstance(data[key], Bunch):
3030
bunch = data[key]
31-
dicted_bunch = {}
32-
dicted_bunch["pymiloed-data-structure"] = "Bunch"
3331
_dict = {}
3432
for key, value in bunch.items():
3533
_dict[key] = value
36-
dicted_bunch["pymiloed-data"] = _dict
37-
return dicted_bunch
34+
return {
35+
"pymilo-bypass": True,
36+
"pymilo-bunch": _dict,
37+
}
3838

3939
return data[key]
4040

@@ -56,12 +56,9 @@ def deserialize(self, data, key, model_type):
5656
:return: pymilo deserialized output of data[key]
5757
"""
5858
content = data[key]
59-
if bunch_support and check_str_in_iterable(
60-
"pymiloed-data-structure",
61-
content) and content["pymiloed-data-structure"] == "Bunch":
59+
if bunch_support and check_str_in_iterable("pymilo-bunch", content):
6260
bunch = Bunch()
63-
dicted_bunch = content["pymiloed-data"]
64-
for key, value in dicted_bunch.items():
61+
for key, value in content["pymilo-bunch"].items():
6562
bunch[key] = value
6663
return bunch
6764
else:

pymilo/transporters/randomstate_transporter.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""PyMilo RandomState(MT19937) object transporter."""
33
import numpy as np
44
from .transporter import AbstractTransporter
5-
5+
from ..utils.util import check_str_in_iterable
66

77
class RandomStateTransporter(AbstractTransporter):
88
"""Customized PyMilo Transporter developed to handle RandomState field."""
@@ -24,14 +24,15 @@ def serialize(self, data, key, model_type):
2424
if isinstance(data[key], np.random.RandomState):
2525
inner_random_state = data[key]
2626
data[key] = {
27-
'state': (
27+
"pymilo-bypass": True,
28+
"pymilo-randomstate": (
2829
inner_random_state.get_state()[0],
2930
inner_random_state.get_state()[1].tolist(),
3031
inner_random_state.get_state()[2],
3132
inner_random_state.get_state()[3],
32-
inner_random_state.get_state()[4]
33-
)
34-
}
33+
inner_random_state.get_state()[4],
34+
),
35+
}
3536
return data[key]
3637

3738
def deserialize(self, data, key, model_type):
@@ -56,15 +57,14 @@ def deserialize(self, data, key, model_type):
5657
"""
5758
content = data[key]
5859

59-
if key == "_random_state" and (
60-
model_type == "MLPRegressor" or model_type == "MLPClassifier"):
61-
inner_random_state = content['state']
60+
if check_str_in_iterable("pymilo-randomstate", content):
61+
inner_random_state = content["pymilo-randomstate"]
6262
inner_random_state = (
6363
inner_random_state[0],
6464
np.array(inner_random_state[1]),
6565
inner_random_state[2],
6666
inner_random_state[3],
67-
inner_random_state[4]
67+
inner_random_state[4],
6868
)
6969
_random_state = np.random.RandomState()
7070
_random_state.set_state(inner_random_state)

pymilo/transporters/sgdoptimizer_transporter.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# -*- coding: utf-8 -*-
22
"""PyMilo SGDOptimizer object transporter."""
3-
from sklearn.neural_network._stochastic_optimizers import SGDOptimizer
43
from .transporter import AbstractTransporter
4+
from ..utils.util import check_str_in_iterable
5+
from sklearn.neural_network._stochastic_optimizers import SGDOptimizer
56

67

78
class SGDOptimizerTransporter(AbstractTransporter):
@@ -24,7 +25,8 @@ def serialize(self, data, key, model_type):
2425
if isinstance(data[key], SGDOptimizer):
2526
optimizer = data[key]
2627
data[key] = {
27-
'params': {
28+
"pymilo-bypass": True,
29+
'pymilo-sgdoptimizer': {
2830
'type': "SGDOptimizer",
2931
'learning_rate': optimizer.learning_rate,
3032
'momentum': optimizer.momentum,
@@ -55,10 +57,8 @@ def deserialize(self, data, key, model_type):
5557
:return: pymilo deserialized output of data[key]
5658
"""
5759
content = data[key]
58-
59-
if (key == "_optimizer" and (model_type ==
60-
"MLPRegressor" or model_type == "MLPClassifier")):
61-
optimizer = content['params']
60+
if check_str_in_iterable('pymilo-sgdoptimizer', content):
61+
optimizer = content['pymilo-sgdoptimizer']
6262
if (optimizer["type"] == "SGDOptimizer"):
6363
return SGDOptimizer(
6464
learning_rate=optimizer['learning_rate'],

pymilo/transporters/tree_transporter.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# -*- coding: utf-8 -*-
22
"""PyMilo Tree(from sklearn.tree._tree) object transporter."""
3+
import numpy as np
34
from sklearn.tree._tree import Tree
4-
5+
from ..pymilo_param import NUMPY_TYPE_DICT
56
from .transporter import AbstractTransporter
7+
from ..utils.util import check_str_in_iterable
68
from .general_data_structure_transporter import GeneralDataStructureTransporter
7-
from ..pymilo_param import NUMPY_TYPE_DICT
8-
9-
import numpy as np
109

1110

1211
class TreeTransporter(AbstractTransporter):
@@ -28,10 +27,9 @@ def serialize(self, data, key, model_type):
2827
gdst = GeneralDataStructureTransporter()
2928
tree = data[key]
3029
tree_inner_state = tree.__getstate__()
31-
3230
data[key] = {
3331
'pymilo-bypass': True,
34-
'params': {
32+
'pymilo-tree': {
3533
'internal_state': {
3634
"max_depth": tree_inner_state["max_depth"],
3735
"node_count": tree_inner_state["node_count"],
@@ -47,7 +45,6 @@ def serialize(self, data, key, model_type):
4745
'n_outputs': tree.n_outputs,
4846
}
4947
}
50-
5148
return data[key]
5249

5350
def deserialize(self, data, key, model_type):
@@ -69,19 +66,12 @@ def deserialize(self, data, key, model_type):
6966
:return: pymilo deserialized output of data[key]
7067
"""
7168
content = data[key]
72-
73-
if (key == "tree_" and
74-
(model_type == "DecisionTreeRegressor"
75-
or model_type == "DecisionTreeClassifier"
76-
or model_type == "ExtraTreeRegressor"
77-
or model_type == "ExtraTreeClassifier"
78-
)):
69+
if check_str_in_iterable('pymilo-tree', content):
7970
gdst = GeneralDataStructureTransporter()
80-
tree_params = content['params']
81-
71+
tree_params = content['pymilo-tree']
8272
tree_internal_state = tree_params["internal_state"]
83-
8473
nodes_dtype_spec = []
74+
8575
for idx, node_type in enumerate(tree_internal_state["nodes"]["types"]):
8676
nodes_dtype_spec.append(
8777
(tree_internal_state["nodes"]["field-names"][idx], NUMPY_TYPE_DICT["numpy." + node_type]))
@@ -106,10 +96,7 @@ def deserialize(self, data, key, model_type):
10696
n_classes,
10797
tree_params["n_outputs"]
10898
)
109-
11099
_tree.__setstate__(tree_internal_state)
111-
112100
return _tree
113-
114101
else:
115102
return content

0 commit comments

Comments
 (0)