2020import threading
2121import numpy as np
2222
23- import pyspark
2423from pyspark .ml import Estimator
2524import pyspark .ml .linalg as spla
2625
2726from sparkdl .image .imageIO import imageStructToArray
2827from sparkdl .param import (
2928 keyword_only , CanLoadImage , HasKerasModel , HasKerasOptimizer , HasKerasLoss , HasOutputMode ,
30- HasInputCol , HasInputImageNodeName , HasLabelCol , HasOutputNodeName , HasOutputCol )
29+ HasInputCol , HasLabelCol , HasOutputCol )
3130from sparkdl .transformers .keras_image import KerasImageFileTransformer
3231import sparkdl .utils .jvmapi as JVMAPI
3332import sparkdl .utils .keras_model as kmutil
    def next(self):
        # Python 2 compatibility shim: delegate to the Python 3
        # iterator protocol method defined on this class.
        return self.__next__()
77- class KerasImageFileEstimator (Estimator , HasInputCol , HasInputImageNodeName ,
78- HasOutputCol , HasOutputNodeName , HasLabelCol ,
79- HasKerasModel , HasKerasOptimizer , HasKerasLoss ,
80- CanLoadImage , HasOutputMode ):
76+ class KerasImageFileEstimator (Estimator , HasInputCol , HasOutputCol , HasLabelCol , HasKerasModel ,
77+ HasKerasOptimizer , HasKerasLoss , CanLoadImage , HasOutputMode ):
8178 """
8279 Build a Estimator from a Keras model.
8380
@@ -138,13 +135,11 @@ def load_image_and_process(uri):
138135 """
139136
140137 @keyword_only
141- def __init__ (self , inputCol = None , inputImageNodeName = None , outputCol = None ,
142- outputNodeName = None , outputMode = "vector" , labelCol = None ,
138+ def __init__ (self , inputCol = None , outputCol = None , outputMode = "vector" , labelCol = None ,
143139 modelFile = None , imageLoader = None , kerasOptimizer = None , kerasLoss = None ,
144140 kerasFitParams = None ):
145141 """
146- __init__(self, inputCol=None, inputImageNodeName=None, outputCol=None,
147- outputNodeName=None, outputMode="vector", labelCol=None,
142+ __init__(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
148143 modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
149144 kerasFitParams=None)
150145 """
@@ -155,13 +150,11 @@ def __init__(self, inputCol=None, inputImageNodeName=None, outputCol=None,
155150 self .setParams (** kwargs )
156151
157152 @keyword_only
158- def setParams (self , inputCol = None , inputImageNodeName = None , outputCol = None ,
159- outputNodeName = None , outputMode = "vector" , labelCol = None ,
153+ def setParams (self , inputCol = None , outputCol = None , outputMode = "vector" , labelCol = None ,
160154 modelFile = None , imageLoader = None , kerasOptimizer = None , kerasLoss = None ,
161155 kerasFitParams = None ):
162156 """
163- setParams(self, inputCol=None, inputImageNodeName=None, outputCol=None,
164- outputNodeName=None, outputMode="vector", labelCol=None,
157+ setParams(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
165158 modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
166159 kerasFitParams=None)
167160 """
@@ -174,12 +167,23 @@ def _validateParams(self, paramMap):
174167 :param paramMap: Dict[pyspark.ml.param.Param, object]
175168 :return: True if parameters are valid
176169 """
177- if not self .isDefined (self .inputCol ):
178- raise ValueError ("Input column must be defined" )
179- if not self .isDefined (self .outputCol ):
180- raise ValueError ("Output column must be defined" )
181- if self .inputCol in paramMap :
182- raise ValueError ("Input column can not be fine tuned" )
170+ model_params = [self .kerasOptimizer , self .kerasLoss , self .kerasFitParams ]
171+ output_params = [self .outputCol , self .outputMode ]
172+
173+ params = self .params
174+ undefined = set ([p for p in params if not self .isDefined (p )])
175+ undefined_tunable = undefined .intersection (model_params + output_params )
176+ failed_define = [p .name for p in undefined .difference (undefined_tunable )]
177+ failed_tune = [p .name for p in undefined_tunable if p not in paramMap ]
178+
179+ if failed_define or failed_tune :
180+ msg = "Following Params must be"
181+ if failed_define :
182+ msg += " defined: [" + ", " .join (failed_define ) + "]"
183+ if failed_tune :
184+ msg += " defined or tuned: [" + ", " .join (failed_tune ) + "]"
185+ raise ValueError (msg )
186+
183187 return True
184188
185189 def _getNumpyFeaturesAndLabels (self , dataset ):
@@ -236,8 +240,7 @@ def _collectModels(self, kerasModelBytesRDD):
236240 """
237241 Collect Keras models on workers to MLlib Models on the driver.
238242 :param kerasModelBytesRDD: RDD of (param_map, model_bytes) tuples
239- :param paramMaps: list of ParamMaps matching the maps in `kerasModelsRDD`
240- :return: list of MLlib models
243+ :return: generator of (index, MLlib model) tuples
241244 """
242245 for (i , param_map , model_bytes ) in kerasModelBytesRDD .collect ():
243246 model_filename = kmutil .bytes_to_h5file (model_bytes )
@@ -264,7 +267,6 @@ def _name_value_map(paramMap):
264267 """takes a dictionary {param -> value} and returns a map of {param.name -> value}"""
265268 return {param .name : val for param , val in paramMap .items ()}
266269
267-
268270 sc = JVMAPI ._curr_sc ()
269271 paramNameMaps = list (enumerate (map (_name_value_map , paramMaps )))
270272 num_models = len (paramNameMaps )
0 commit comments