|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | - |
16 | | -import pickle |
| 15 | +import os |
| 16 | +import dill |
17 | 17 | import numpy as np |
18 | 18 | import pandas as pd |
| 19 | +import autokeras as ak |
| 20 | +import tensorflow as tf |
19 | 21 | from blobcity.store import DictClass |
20 | 22 | from blobcity.utils import get_dataframe_type,dataCleaner |
21 | 23 | from blobcity.utils import AutoFeatureSelection as AFS |
22 | 24 | from blobcity.main.modelSelection import model_search |
23 | 25 | from blobcity.code_gen import yml_reader,code_generator |
24 | 26 | from sklearn.preprocessing import MinMaxScaler |
25 | 27 | from sklearn.feature_selection import SelectKBest,f_regression,f_classif |
26 | | -def train(file=None, df=None, target=None,features=None,accuracy_criteria=0.99): |
| 28 | +def train(file=None, df=None, target=None,features=None,use_neural=False,accuracy_criteria=0.99): |
27 | 29 | """ |
28 | 30 | param1: string: dataset file path |
29 | 31 |
|
@@ -56,35 +58,41 @@ def train(file=None, df=None, target=None,features=None,accuracy_criteria=0.99): |
56 | 58 | CleanedDF=dataCleaner(dataframe,features,target,dict_class) |
57 | 59 | #model search space |
58 | 60 | accuracy_criteria= accuracy_criteria if accuracy_criteria<=1.0 else (accuracy_criteria/100) |
59 | | - modelClass = model_search(CleanedDF,target,dict_class,use_neural=False,accuracy_criteria=accuracy_criteria) |
| 61 | + modelClass = model_search(CleanedDF,target,dict_class,use_neural=use_neural,accuracy_criteria=accuracy_criteria) |
60 | 62 | modelClass.yamldata=dict_class.getdict() |
61 | 63 | modelClass.feature_importance_=dict_class.feature_importance if(features==None) else calculate_feature_importance(CleanedDF.drop(target,axis=1),CleanedDF[target],dict_class) |
62 | 64 | dict_class.resetVar() |
63 | 65 | return modelClass |
64 | 66 |
|
def load(model_path=None):
    """
    param1: string: (required) the filepath to the stored model. Supports .pkl models.
    returns: Model file

    Loads the serialized model from .pkl format into a usable object.
    For TensorFlow-based models (yamldata model type 'TF'/'tf'/'Tensorflow'),
    the network itself is additionally restored from a sibling file sharing
    the pickle's base path: either ``<base>.h5`` (save_type 'h5') or a
    SavedModel directory ``<base>`` (save_type 'pb', loaded with AutoKeras
    custom objects).

    raises:
        TypeError: if model_path is None/empty, is not a .pkl file, or the
            stored TF save_type is not supported.
        FileNotFoundError: if the expected .h5 file or SavedModel folder
            next to the pickle is missing.
    """
    # Guard clause: fail fast on a missing path instead of nesting the body.
    if model_path in [None, ""]:
        raise TypeError(f"{model_path}, path can't be None or Null")

    # os.path.splitext is robust to dots in directory names and to
    # extension-less paths; the previous '.'-split approach raised
    # IndexError when the path contained no '.' at all.
    base_path, extension = os.path.splitext(model_path)
    if extension != '.pkl':
        raise TypeError(f"{extension}, file type must be .pkl")

    # Context manager closes the file handle deterministically (the
    # previous bare open() leaked it).
    with open(model_path, 'rb') as model_file:
        model = dill.load(model_file)

    if model.yamldata['model']['type'] in ['TF', 'tf', 'Tensorflow']:
        save_type = model.yamldata['model']['save_type']
        if save_type == 'h5':
            h5_path = base_path + ".h5"
            if not os.path.isfile(h5_path):
                raise FileNotFoundError(f"{h5_path} file does not exist in the directory")
            model.model = tf.keras.models.load_model(h5_path)
        elif save_type == 'pb':
            if not os.path.isdir(base_path):
                raise FileNotFoundError(f"{base_path} folder does not exist")
            # SavedModel directories produced by AutoKeras need its custom
            # objects registered to deserialize correctly.
            model.model = tf.keras.models.load_model(base_path, custom_objects=ak.CUSTOM_OBJECTS)
        else:
            raise TypeError(f"{save_type}, not supported save format")
    return model
88 | 96 |
|
89 | 97 | def spill(filepath,yaml_path=None,doc=None): |
90 | 98 | """ |
|
0 commit comments