Skip to content

Commit d3b9857

Browse files
authored
Refactor/chain (#172)
* define `Chain` interface * implement `AbstractChain` * refactor `svm` chain developement * refactor `neural network` chain developement * refactor `neighbors` chain developement * refactor `naive bayes` chain developement * refactor `decision tree` chain developement * refactor `cross decomposition` chain developement * refactor `clustering` chain developement * refactor and update `util.py` * update * update * update docstring * update `test_pymilo.py` to make compatible with new changes * refactor `ensemble` chain developement * refactor `linear` chain developement * reorder imports * refactor spacings * `autopep8.sh` applied * refactor concrete chain implementations and remove non-necessary object level functions to be out of class functions * refactorings * apply docstring feedbacks
1 parent 96d46ed commit d3b9857

13 files changed

+552
-1452
lines changed

pymilo/chains/chain.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
# -*- coding: utf-8 -*-
2+
"""PyMilo Chain Module."""
3+
4+
from traceback import format_exc
5+
from abc import ABC, abstractmethod
6+
7+
from ..utils.util import get_sklearn_type
8+
from ..transporters.transporter import Command
9+
from ..exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes
10+
from ..exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes
11+
12+
13+
class Chain(ABC):
14+
"""
15+
Chain Interface.
16+
17+
Each Chain serializes/deserializes the given model.
18+
"""
19+
20+
@abstractmethod
21+
def is_supported(self, model):
22+
"""
23+
Check if the given model is a sklearn's ML model supported by this chain.
24+
25+
:param model: a string name of an ML model or a sklearn object of it
26+
:type model: any object
27+
:return: check result as bool
28+
"""
29+
30+
@abstractmethod
31+
def transport(self, request, command, is_inner_model=False):
32+
"""
33+
Return the transported (serialized or deserialized) model.
34+
35+
:param request: given ML model to be transported
36+
:type request: any object
37+
:param command: command to specify whether the request should be serialized or deserialized
38+
:type command: transporter.Command
39+
:param is_inner_model: determines whether it is an inner model of a super ML model
40+
:type is_inner_model: boolean
41+
:return: the transported request as a json string or sklearn ML model
42+
"""
43+
44+
@abstractmethod
45+
def serialize(self, model):
46+
"""
47+
Return the serialized json string of the given model.
48+
49+
:param model: given ML model to be get serialized
50+
:type model: sklearn ML model
51+
:return: the serialized json string of the given ML model
52+
"""
53+
54+
@abstractmethod
55+
def deserialize(self, serialized_model, is_inner_model=False):
56+
"""
57+
Return the associated sklearn ML model of the given previously serialized ML model.
58+
59+
:param serialized_model: given json string of a ML model to get deserialized to associated sklearn ML model
60+
:type serialized_model: obj
61+
:param is_inner_model: determines whether it is an inner ML model of a super ML model
62+
:type is_inner_model: boolean
63+
:return: associated sklearn ML model
64+
"""
65+
66+
@abstractmethod
67+
def validate(self, model, command):
68+
"""
69+
Check if the provided inputs are valid in relation to each other.
70+
71+
:param model: a sklearn ML model or a json string of it, serialized through the pymilo export
72+
:type model: obj
73+
:param command: command to specify whether the request should be serialized or deserialized
74+
:type command: transporter.Command
75+
:return: None
76+
"""
77+
78+
79+
class AbstractChain(Chain):
80+
"""Abstract Chain with the general implementation of the Chain interface."""
81+
82+
def __init__(self, transporters, supported_models):
83+
"""
84+
Initialize the AbstractChain instance.
85+
86+
:param transporters: worker transporters dedicated to this chain
87+
:type transporters: transporter.AbstractTransporter[]
88+
:param supported_models: supported sklearn ML models belong to this chain
89+
:type supported_models: dict
90+
:return: an instance of the AbstractChain class
91+
"""
92+
self._transporters = transporters
93+
self._supported_models = supported_models
94+
95+
def is_supported(self, model):
96+
"""
97+
Check if the given model is a sklearn's ML model supported by this chain.
98+
99+
:param model: a string name of an ML model or a sklearn object of it
100+
:type model: any object
101+
:return: check result as bool
102+
"""
103+
model_name = model if isinstance(model, str) else get_sklearn_type(model)
104+
return model_name in self._supported_models
105+
106+
def transport(self, request, command, is_inner_model=False):
107+
"""
108+
Return the transported (serialized or deserialized) model.
109+
110+
:param request: given ML model to be transported
111+
:type request: any object
112+
:param command: command to specify whether the request should be serialized or deserialized
113+
:type command: transporter.Command
114+
:param is_inner_model: determines whether it is an inner model of a super ML model
115+
:type is_inner_model: boolean
116+
:return: the transported request as a json string or sklearn ML model
117+
"""
118+
if not is_inner_model:
119+
self.validate(request, command)
120+
121+
if command == Command.SERIALIZE:
122+
try:
123+
return self.serialize(request)
124+
except Exception as e:
125+
raise PymiloSerializationException(
126+
{
127+
'error_type': SerializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
128+
'error': {
129+
'Exception': repr(e),
130+
'Traceback': format_exc(),
131+
},
132+
'object': request,
133+
})
134+
135+
elif command == Command.DESERIALIZE:
136+
try:
137+
return self.deserialize(request, is_inner_model)
138+
except Exception as e:
139+
raise PymiloDeserializationException(
140+
{
141+
'error_type': DeserializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
142+
'error': {
143+
'Exception': repr(e),
144+
'Traceback': format_exc()},
145+
'object': request
146+
})
147+
148+
def serialize(self, model):
149+
"""
150+
Return the serialized json string of the given model.
151+
152+
:param model: given ML model to be get serialized
153+
:type model: sklearn ML model
154+
:return: the serialized json string of the given ML model
155+
"""
156+
for transporter in self._transporters:
157+
self._transporters[transporter].transport(model, Command.SERIALIZE)
158+
return model.__dict__
159+
160+
def deserialize(self, serialized_model, is_inner_model=False):
161+
"""
162+
Return the associated sklearn ML model of the given previously serialized ML model.
163+
164+
:param serialized_model: given json string of a ML model to get deserialized to associated sklearn ML model
165+
:type serialized_model: obj
166+
:param is_inner_model: determines whether it is an inner ML model of a super ML model
167+
:type is_inner_model: boolean
168+
:return: associated sklearn ML model
169+
"""
170+
raw_model = None
171+
data = None
172+
if is_inner_model:
173+
raw_model = self._supported_models[serialized_model["type"]]()
174+
data = serialized_model["data"]
175+
else:
176+
raw_model = self._supported_models[serialized_model.type]()
177+
data = serialized_model.data
178+
for transporter in self._transporters:
179+
self._transporters[transporter].transport(
180+
serialized_model, Command.DESERIALIZE, is_inner_model)
181+
for item in data:
182+
setattr(raw_model, item, data[item])
183+
return raw_model
184+
185+
def validate(self, model, command):
186+
"""
187+
Check if the provided inputs are valid in relation to each other.
188+
189+
:param model: a sklearn ML model or a json string of it, serialized through the pymilo export
190+
:type model: obj
191+
:param command: command to specify whether the request should be serialized or deserialized
192+
:type command: transporter.Command
193+
:return: None
194+
"""
195+
if command == Command.SERIALIZE:
196+
if self.is_supported(model):
197+
return
198+
else:
199+
raise PymiloSerializationException(
200+
{
201+
'error_type': SerializationErrorTypes.INVALID_MODEL,
202+
'object': model
203+
}
204+
)
205+
elif command == Command.DESERIALIZE:
206+
if self.is_supported(model.type):
207+
return
208+
else:
209+
raise PymiloDeserializationException(
210+
{
211+
'error_type': DeserializationErrorTypes.INVALID_MODEL,
212+
'object': model
213+
}
214+
)

pymilo/chains/clustering_chain.py

Lines changed: 8 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -1,158 +1,24 @@
11
# -*- coding: utf-8 -*-
2-
"""PyMilo chain for clustering models."""
3-
from ..transporters.transporter import Command
2+
"""PyMilo chain for Clustering models."""
43

5-
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
6-
from ..transporters.function_transporter import FunctionTransporter
4+
from ..chains.chain import AbstractChain
5+
from ..pymilo_param import SKLEARN_CLUSTERING_TABLE, NOT_SUPPORTED
76
from ..transporters.cfnode_transporter import CFNodeTransporter
7+
from ..transporters.function_transporter import FunctionTransporter
8+
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
89
from ..transporters.preprocessing_transporter import PreprocessingTransporter
910

10-
from ..utils.util import get_sklearn_type
11-
12-
from ..pymilo_param import SKLEARN_CLUSTERING_TABLE, NOT_SUPPORTED
13-
from ..exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes
14-
from ..exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes
15-
from traceback import format_exc
16-
17-
bisecting_kmeans_support = SKLEARN_CLUSTERING_TABLE["BisectingKMeans"] != NOT_SUPPORTED
1811
CLUSTERING_CHAIN = {
1912
"PreprocessingTransporter": PreprocessingTransporter(),
2013
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
2114
"FunctionTransporter": FunctionTransporter(),
2215
"CFNodeTransporter": CFNodeTransporter(),
2316
}
2417

25-
if bisecting_kmeans_support:
26-
from ..transporters.randomstate_transporter import RandomStateTransporter
18+
if SKLEARN_CLUSTERING_TABLE["BisectingKMeans"] != NOT_SUPPORTED:
2719
from ..transporters.bisecting_tree_transporter import BisectingTreeTransporter
20+
from ..transporters.randomstate_transporter import RandomStateTransporter
2821
CLUSTERING_CHAIN["RandomStateTransporter"] = RandomStateTransporter()
2922
CLUSTERING_CHAIN["BisectingTreeTransporter"] = BisectingTreeTransporter()
3023

31-
32-
def is_clusterer(model):
33-
"""
34-
Check if the input model is a sklearn's clustering model.
35-
36-
:param model: is a string name of a clusterer or a sklearn object of it
37-
:type model: any object
38-
:return: check result as bool
39-
"""
40-
if isinstance(model, str):
41-
return model in SKLEARN_CLUSTERING_TABLE
42-
else:
43-
return get_sklearn_type(model) in SKLEARN_CLUSTERING_TABLE.keys()
44-
45-
46-
def transport_clusterer(request, command, is_inner_model=False):
47-
"""
48-
Return the transported (Serialized or Deserialized) model.
49-
50-
:param request: given clusterer to be transported
51-
:type request: any object
52-
:param command: command to specify whether the request should be serialized or deserialized
53-
:type command: transporter.Command
54-
:param is_inner_model: determines whether it is an inner model of a super ml model
55-
:type is_inner_model: boolean
56-
:return: the transported request as a json string or sklearn clustering model
57-
"""
58-
if not is_inner_model:
59-
_validate_input(request, command)
60-
61-
if command == Command.SERIALIZE:
62-
try:
63-
return serialize_clusterer(request)
64-
except Exception as e:
65-
raise PymiloSerializationException(
66-
{
67-
'error_type': SerializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
68-
'error': {
69-
'Exception': repr(e),
70-
'Traceback': format_exc(),
71-
},
72-
'object': request,
73-
})
74-
75-
elif command == Command.DESERIALIZE:
76-
try:
77-
return deserialize_clusterer(request, is_inner_model)
78-
except Exception as e:
79-
raise PymiloDeserializationException(
80-
{
81-
'error_type': SerializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
82-
'error': {
83-
'Exception': repr(e),
84-
'Traceback': format_exc()},
85-
'object': request})
86-
87-
88-
def serialize_clusterer(clusterer_object):
89-
"""
90-
Return the serialized json string of the given clustering model.
91-
92-
:param clusterer_object: given model to be get serialized
93-
:type clusterer_object: any sklearn clustering model
94-
:return: the serialized json string of the given clusterer
95-
"""
96-
for transporter in CLUSTERING_CHAIN:
97-
CLUSTERING_CHAIN[transporter].transport(
98-
clusterer_object, Command.SERIALIZE)
99-
return clusterer_object.__dict__
100-
101-
102-
def deserialize_clusterer(clusterer, is_inner_model=False):
103-
"""
104-
Return the associated sklearn clustering model of the given clusterer.
105-
106-
:param clusterer: given json string of a clustering model to get deserialized to associated sklearn clustering model
107-
:type clusterer: obj
108-
:param is_inner_model: determines whether it is an inner model of a super ml model
109-
:type is_inner_model: boolean
110-
:return: associated sklearn clustering model
111-
"""
112-
raw_model = None
113-
data = None
114-
if is_inner_model:
115-
raw_model = SKLEARN_CLUSTERING_TABLE[clusterer["type"]]()
116-
data = clusterer["data"]
117-
else:
118-
raw_model = SKLEARN_CLUSTERING_TABLE[clusterer.type]()
119-
data = clusterer.data
120-
121-
for transporter in CLUSTERING_CHAIN:
122-
CLUSTERING_CHAIN[transporter].transport(
123-
clusterer, Command.DESERIALIZE, is_inner_model)
124-
for item in data:
125-
setattr(raw_model, item, data[item])
126-
return raw_model
127-
128-
129-
def _validate_input(model, command):
130-
"""
131-
Check if the provided inputs are valid in relation to each other.
132-
133-
:param model: a sklearn clusterer model or a json string of it, serialized through the pymilo export.
134-
:type model: obj
135-
:param command: command to specify whether the request should be serialized or deserialized
136-
:type command: transporter.Command
137-
:return: None
138-
"""
139-
if command == Command.SERIALIZE:
140-
if is_clusterer(model):
141-
return
142-
else:
143-
raise PymiloSerializationException(
144-
{
145-
'error_type': SerializationErrorTypes.INVALID_MODEL,
146-
'object': model
147-
}
148-
)
149-
elif command == Command.DESERIALIZE:
150-
if is_clusterer(model.type):
151-
return
152-
else:
153-
raise PymiloDeserializationException(
154-
{
155-
'error_type': DeserializationErrorTypes.INVALID_MODEL,
156-
'object': model
157-
}
158-
)
24+
clustering_chain = AbstractChain(CLUSTERING_CHAIN, SKLEARN_CLUSTERING_TABLE)

0 commit comments

Comments
 (0)