|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +"""PyMilo Feature Extraction transporter.""" |
| 3 | +from scipy.sparse import csr_matrix |
| 4 | + |
| 5 | +from ..pymilo_param import SKLEARN_FEATURE_EXTRACTION_TABLE |
| 6 | +from ..utils.util import check_str_in_iterable, get_sklearn_type |
| 7 | +from .transporter import AbstractTransporter, Command |
| 8 | +from .general_data_structure_transporter import GeneralDataStructureTransporter |
| 9 | +from .randomstate_transporter import RandomStateTransporter |
| 10 | + |
# Chain of transporter instances applied to feature extraction modules.
# Each key is the name of the transporter class it maps to; the chain is
# iterated in insertion order when (de)serializing a module.
FEATURE_EXTRACTION_CHAIN = dict(
    GeneralDataStructureTransporter=GeneralDataStructureTransporter(),
    RandomStateTransporter=RandomStateTransporter(),
)
| 15 | + |
| 16 | + |
class FeatureExtractorTransporter(AbstractTransporter):
    """Feature Extractor object dedicated Transporter."""

    def serialize(self, data, key, model_type):
        """
        Serialize Feature Extractor object.

        serialize the data[key] of the given model which type is model_type.
        if data[key] is a sklearn feature extraction module it gets serialized,
        any other value is returned untouched.

        :param data: the internal data dictionary of the given model
        :type data: dict
        :param key: the special key of the data param, which we're going to serialize its value(data[key])
        :type key: object
        :param model_type: the model type of the ML model, which data dictionary is given as the data param
        :type model_type: str
        :return: pymilo serialized output of data[key]
        """
        value = data[key]
        if self.is_fe_module(value):
            return self.serialize_fe_module(value)
        return value

    def deserialize(self, data, key, model_type):
        """
        Deserialize previously pymilo serialized feature extraction object.

        deserialize the data[key] of the given model which type is model_type.
        if data[key] is a pymilo serialized feature extraction module it gets rebuilt,
        any other value is returned untouched.

        :param data: the internal data dictionary of the associated json file of the ML model which is generated
            previously by pymilo export.
        :type data: dict
        :param key: the special key of the data param, which we're going to deserialize its value(data[key])
        :type key: object
        :param model_type: the model type of the ML model, which internal serialized data dictionary is given as the
            data param
        :type model_type: str
        :return: pymilo deserialized output of data[key]
        """
        content = data[key]
        if self.is_fe_module(content):
            return self.deserialize_fe_module(content)
        return content

    def is_fe_module(self, fe_module):
        """
        Check whether the given module is a sklearn Feature Extraction module or not.

        a dict is treated as a (serialized) feature extraction module when it carries a
        "pymilo-feature_extraction-type" entry which is listed in SKLEARN_FEATURE_EXTRACTION_TABLE,
        any other object is looked up in that table through its sklearn type.

        :param fe_module: given object
        :type fe_module: any
        :return: bool
        """
        if isinstance(fe_module, dict):
            type_key = "pymilo-feature_extraction-type"
            return check_str_in_iterable(type_key, fe_module) and \
                fe_module[type_key] in SKLEARN_FEATURE_EXTRACTION_TABLE
        return get_sklearn_type(fe_module) in SKLEARN_FEATURE_EXTRACTION_TABLE

    def serialize_fe_module(self, fe_module):
        """
        Serialize Feature Extraction object.

        the given module is modified in place: its inner feature extraction modules and csr_matrix
        attributes are replaced with their serialized form before the transporter chain is applied.

        :param fe_module: given sklearn feature extraction module
        :type fe_module: sklearn.feature_extraction
        :return: pymilo serialized fe_module
        """
        general_transporter = FEATURE_EXTRACTION_CHAIN["GeneralDataStructureTransporter"]
        attributes = fe_module.__dict__
        # iterate over a snapshot, the entries of the dictionary are re-assigned inside the loop
        for name, value in list(attributes.items()):
            if self.is_fe_module(value):
                # add one depth inner feature extraction module population
                attributes[name] = self.serialize_fe_module(value)
            elif isinstance(value, csr_matrix):
                attributes[name] = {
                    "pymilo-bypass": True,
                    "pymilo-csr_matrix": general_transporter.serialize_dict(value.__dict__),
                }

        # NOTE(review): assumed to run once per module, after every attribute has been visited - confirm
        for transporter in FEATURE_EXTRACTION_CHAIN.values():
            transporter.transport(fe_module, Command.SERIALIZE)
        return {
            "pymilo-bypass": True,
            "pymilo-feature_extraction-type": get_sklearn_type(fe_module),
            "pymilo-feature_extraction-data": fe_module.__dict__,
        }

    def deserialize_fe_module(self, serialized_fe_module):
        """
        Deserialize Feature Extraction object.

        the "pymilo-feature_extraction-data" dictionary of the given input is updated in place while its
        values are getting deserialized.

        :param serialized_fe_module: serialized feature extraction module(by pymilo)
        :type serialized_fe_module: dict
        :return: retrieved associated sklearn.feature_extraction module
        """
        data = serialized_fe_module["pymilo-feature_extraction-data"]
        associated_type = SKLEARN_FEATURE_EXTRACTION_TABLE[serialized_fe_module["pymilo-feature_extraction-type"]]
        retrieved_fe_module = associated_type()
        for key in data:
            value = data[key]
            if self.is_fe_module(value):
                # add one depth inner feature extraction module population
                data[key] = self.deserialize_fe_module(value)
            elif isinstance(value, dict) and check_str_in_iterable("pymilo-csr_matrix", value):
                # only a dict can be a serialized csr_matrix wrapper, the isinstance guard keeps other
                # containers which merely include the marker text from being subscripted by it
                data[key] = self._deserialize_csr_matrix(value["pymilo-csr_matrix"])
            # NOTE(review): assumed to run for every key (not only inside the csr_matrix branch) - confirm
            for transporter in FEATURE_EXTRACTION_CHAIN.values():
                data[key] = transporter.deserialize(data, key, "")
        for key, value in data.items():
            setattr(retrieved_fe_module, key, value)
        return retrieved_fe_module

    def _deserialize_csr_matrix(self, serialized_attributes):
        """
        Rebuild a csr_matrix from its pymilo serialized attribute dictionary.

        :param serialized_attributes: serialized __dict__ of the original csr_matrix
        :type serialized_attributes: dict
        :return: retrieved scipy.sparse.csr_matrix
        """
        attributes = FEATURE_EXTRACTION_CHAIN["GeneralDataStructureTransporter"].get_deserialized_dict(
            serialized_attributes)
        # create an empty matrix from the stored shape, then restore every stored attribute on it
        matrix = csr_matrix(attributes['_shape'])
        for name, value in attributes.items():
            setattr(matrix, name, value)
        return matrix
0 commit comments