
Commit 3f668d9

phi-dbq authored and jkbradley committed
Porting Keras Estimator API and Reference Implementation (#35)
## What changes are proposed in this pull request?

Creating a Spark MLlib Estimator API for Keras models, with a reference implementation. It provides a taste of how to ingest images from URIs in a `DataFrame` and use them to train a Keras model. The changes consist of these components.

1. Extracted a few Params types for Keras Transformers/Estimators.
2. Keras utilities
   - Serialization: model <=> hdf5 <=> bytes (for broadcast); see the sketch below
   - Check available Keras options (optimizers, loss functions, etc.)
3. Keras Estimator.

## How is this patch tested?

- [x] Unit tests
- [x] Manual tests
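The serialization utilities in item 2 move a Keras model between its in-memory form, an HDF5 file, and raw bytes so it can be broadcast to workers and rebuilt there. Below is a minimal sketch of that round trip, calling the `sparkdl.utils.keras_model` helpers with the same signatures as their call sites in this diff (`model_to_bytes`, `bytes_to_model`, `bytes_to_h5file`); the driver/worker split shown is illustrative, not the estimator's literal code path.

```python
import sparkdl.utils.keras_model as kmutil
from keras.applications.resnet50 import ResNet50

model = ResNet50(weights=None)

# Driver: serialize the model to raw bytes (via HDF5) so it fits in a broadcast.
model_bytes = kmutil.model_to_bytes(model)

# Worker: rebuild a trainable Keras model from the broadcast bytes.
worker_model = kmutil.bytes_to_model(model_bytes)
worker_model.compile(optimizer="adam", loss="categorical_crossentropy")

# Driver, after collecting trained bytes: materialize an HDF5 file that a
# KerasImageFileTransformer can load via its modelFile param.
h5_path = kmutil.bytes_to_h5file(kmutil.model_to_bytes(worker_model))
```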
1 parent 20ce48f commit 3f668d9

File tree

15 files changed: +961 −200 lines changed
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
#
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
Lines changed: 317 additions & 0 deletions
@@ -0,0 +1,317 @@
#
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint: disable=protected-access
from __future__ import absolute_import, division, print_function

import logging
import numpy as np

from pyspark.ml import Estimator
import pyspark.ml.linalg as spla
from pyspark.ml.param import Param, Params, TypeConverters

from sparkdl.image.imageIO import imageStructToArray
from sparkdl.param import (
    keyword_only, CanLoadImage, HasKerasModel, HasKerasOptimizer, HasKerasLoss, HasOutputMode,
    HasInputCol, HasInputImageNodeName, HasLabelCol, HasOutputNodeName, HasOutputCol)
from sparkdl.transformers.keras_image import KerasImageFileTransformer
import sparkdl.utils.jvmapi as JVMAPI
import sparkdl.utils.keras_model as kmutil

__all__ = ['KerasImageFileEstimator']

logger = logging.getLogger('sparkdl')

class KerasImageFileEstimator(Estimator, HasInputCol, HasInputImageNodeName,
                              HasOutputCol, HasOutputNodeName, HasLabelCol,
                              HasKerasModel, HasKerasOptimizer, HasKerasLoss,
                              CanLoadImage, HasOutputMode):
    """
    Build an Estimator from a Keras model.

    First, create a model and save it to the file system.

    .. code-block:: python

        from keras.applications.resnet50 import ResNet50
        model = ResNet50(weights=None)
        model.save("path_to_my_model.h5")

    Then, create an image loading function that reads image data from a URI,
    preprocesses it, and returns the numerical tensor.

    .. code-block:: python

        def load_image_and_process(uri):
            import PIL.Image
            from keras.applications.imagenet_utils import preprocess_input

            original_image = PIL.Image.open(uri).convert('RGB')
            resized_image = original_image.resize((224, 224), PIL.Image.ANTIALIAS)
            image_array = np.array(resized_image).astype(np.float32)
            image_tensor = preprocess_input(image_array[np.newaxis, :])
            return image_tensor


    Assume the image URIs live in the following DataFrame.

    .. code-block:: python

        original_dataset = spark.createDataFrame([
            Row(imageUri="image1_uri", imageLabel="image1_label"),
            Row(imageUri="image2_uri", imageLabel="image2_label"),
            # and more rows ...
        ])
        stringIndexer = StringIndexer(inputCol="imageLabel", outputCol="categoryIndex")
        indexed_dataset = stringIndexer.fit(original_dataset).transform(original_dataset)
        encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec")
        image_dataset = encoder.transform(indexed_dataset)

    We can then create a Keras estimator that takes our saved model file and
    trains it using Spark.

    .. code-block:: python

        estimator = KerasImageFileEstimator(inputCol="imageUri",
                                            outputCol="name_of_result_column",
                                            labelCol="categoryVec",
                                            imageLoader=load_image_and_process,
                                            kerasOptimizer="adam",
                                            kerasLoss="categorical_crossentropy",
                                            kerasFitParams={"epochs": 5, "batch_size": 64},
                                            modelFile="path_to_my_model.h5")

        transformers = estimator.fit(image_dataset)

    """

    @keyword_only
    def __init__(self, inputCol=None, inputImageNodeName=None, outputCol=None,
                 outputNodeName=None, outputMode="vector", labelCol=None,
                 modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
                 kerasFitParams=None):
        """
        __init__(self, inputCol=None, inputImageNodeName=None, outputCol=None,
                 outputNodeName=None, outputMode="vector", labelCol=None,
                 modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
                 kerasFitParams=None)
        """
        # NOTE(phi-dbq): currently we ignore the output mode, as the actual outputs are the
        #                trained models and the Transformers built from them.
        super(KerasImageFileEstimator, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, inputImageNodeName=None, outputCol=None,
                  outputNodeName=None, outputMode="vector", labelCol=None,
                  modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
                  kerasFitParams=None):
        """
        setParams(self, inputCol=None, inputImageNodeName=None, outputCol=None,
                  outputNodeName=None, outputMode="vector", labelCol=None,
                  modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
                  kerasFitParams=None)
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def fit(self, dataset, params=None):
        """
        Fits a model to the input dataset with optional parameters.

        .. warning:: This returns the byte-serialized HDF5 file for each model to the driver.
                     If the model files are large, the driver may run out of memory.
                     As we cannot assume the existence of a sufficiently large (and writable)
                     file system, users are advised not to train too many models in a single
                     Spark job.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
                        The column `inputCol` should be of type `sparkdl.image.imageIO.imgSchema`.
        :param params: an optional param map that overrides embedded params. If a list/tuple of
                       param maps is given, this calls fit on each param map and returns a list of
                       models.
        :return: fitted model(s). If params includes a list of param maps, the order of these
                 models matches the order of the param maps.
        """
        self._validateParams()
        if params is None:
            paramMaps = [dict()]
        elif isinstance(params, (list, tuple)):
            if len(params) == 0:
                paramMaps = [dict()]
            else:
                self._validateFitParams(params)
                paramMaps = params
        elif isinstance(params, dict):
            self._validateFitParams(params)
            paramMaps = [params]
        else:
            raise ValueError("Params must be either a param map or a list/tuple of param maps, "
                             "but got %s." % type(params))
        return self._fitInParallel(dataset, paramMaps)

    def _validateParams(self):
        """
        Check Param values so we can throw errors on the driver, rather than workers.
        :return: True if parameters are valid
        """
        if not self.isDefined(self.inputCol):
            raise ValueError("Input column must be defined")
        if not self.isDefined(self.outputCol):
            raise ValueError("Output column must be defined")
        return True

    def _validateFitParams(self, params):
        """ Check if an input parameter set is valid """
        if isinstance(params, (list, tuple, dict)):
            assert self.getInputCol() not in params, \
                "params {} cannot contain the input column name {}".format(
                    params, self.getInputCol())
        return True

    def _getNumpyFeaturesAndLabels(self, dataset):
        """
        We assume the training data fits in memory on a single server.
        The input DataFrame is converted to numerical image features and
        broadcast to all the worker nodes.
        """
        image_uri_col = self.getInputCol()
        label_col = None
        if self.isDefined(self.labelCol) and self.getLabelCol() != "":
            label_col = self.getLabelCol()
        tmp_image_col = self._loadedImageCol()
        image_df = self.loadImagesInternal(dataset, image_uri_col).dropna(subset=[tmp_image_col])

        # Extract features
        localFeatures = []
        rows = image_df.collect()
        for row in rows:
            spimg = row[tmp_image_col]
            features = imageStructToArray(spimg)
            localFeatures.append(features)

        if not localFeatures:  # NOTE(phi-dbq): PEP 8 recommends against testing 0 == len(array)
            raise ValueError("Cannot extract any feature from dataset!")
        X = np.stack(localFeatures, axis=0)

        # Extract labels
        y = None
        if label_col is not None:
            label_schema = image_df.schema[label_col]
            label_dtype = label_schema.dataType
            assert isinstance(label_dtype, spla.VectorUDT), \
                "must encode labels in one-hot vector format, but got {}".format(label_dtype)

            localLabels = []
            for row in rows:
                try:
                    _keras_label = row[label_col].array
                except ValueError:
                    raise ValueError("Cannot extract encoded label array")
                localLabels.append(_keras_label)

            if not localLabels:
                raise ValueError("Failed to load any labels from dataset, but labels are required")

            y = np.stack(localLabels, axis=0)
            assert y.shape[0] == X.shape[0], \
                "number of features {} != number of labels {}".format(X.shape[0], y.shape[0])

        return X, y

    def _collectModels(self, kerasModelsBytesRDD):
        """
        Collect Keras models trained on workers and convert them to MLlib Transformers
        on the driver.
        :param kerasModelsBytesRDD: RDD of (param_map, model_bytes) tuples
        :return: list of MLlib Transformers, one per (param_map, model_bytes) pair
        """
        transformers = []
        for (param_map, model_bytes) in kerasModelsBytesRDD.collect():
            model_filename = kmutil.bytes_to_h5file(model_bytes)
            transformers.append({
                'paramMap': param_map,
                'transformer': KerasImageFileTransformer(modelFile=model_filename)})

        return transformers

    def _fitInParallel(self, dataset, paramMaps):
        """
        Fits len(paramMaps) models in parallel, one in each Spark task.
        :param paramMaps: non-empty list or tuple of ParamMaps (dict values)
        :return: list of fitted models, matching the order of paramMaps
        """
        sc = JVMAPI._curr_sc()
        paramMapsRDD = sc.parallelize(paramMaps, numSlices=len(paramMaps))

        # Extract image URIs from the provided dataset and create features as numpy arrays
        localFeatures, localLabels = self._getNumpyFeaturesAndLabels(dataset)
        localFeaturesBc = sc.broadcast(localFeatures)
        localLabelsBc = None if localLabels is None else sc.broadcast(localLabels)

        # Broadcast the Keras model (HDF5) file content as bytes
        modelBytes = self._loadModelAsBytes()
        modelBytesBc = sc.broadcast(modelBytes)

        # Obtain the params for this estimator instance
        baseParamMap = self.extractParamMap()
        baseParamDict = dict([(param.name, val) for param, val in baseParamMap.items()])
        baseParamDictBc = sc.broadcast(baseParamDict)

        def _local_fit(override_param_map):
            """
            Locally fit a model using this estimator's params, overridden by the
            parameters provided in the input.
            :param override_param_map: dict whose keys are MLlib Params.
                                       They are meant to override the base estimator's params.
            :return: serialized Keras HDF5 file bytes
            """
            # Update params (copy the broadcast dict so tasks do not mutate shared state)
            params = dict(baseParamDictBc.value)
            override_param_dict = dict([
                (param.name, val) for param, val in override_param_map.items()])
            params.update(override_param_dict)

            # Create the Keras model
            model = kmutil.bytes_to_model(modelBytesBc.value)
            model.compile(optimizer=params['kerasOptimizer'], loss=params['kerasLoss'])

            # Retrieve the features and labels and fit the Keras model
            features = localFeaturesBc.value
            labels = None if localLabelsBc is None else localLabelsBc.value
            _fit_params = params['kerasFitParams']
            model.fit(x=features, y=labels, **_fit_params)

            return kmutil.model_to_bytes(model)

        kerasModelBytesRDD = paramMapsRDD.map(lambda paramMap: (paramMap, _local_fit(paramMap)))
        return self._collectModels(kerasModelBytesRDD)

    def _loadModelAsBytes(self):
        """
        (usable on driver only)
        Load the Keras model file as a byte string.
        :return: str containing the model data
        """
        with open(self.getModelFile(), mode='rb') as fin:
            fileContent = fin.read()
        return fileContent

    def _fit(self, dataset):  # pylint: disable=unused-argument
        err_msgs = ["This function should not have been called",
                    "Please contact library maintainers to file a bug"]
        raise NotImplementedError('\n'.join(err_msgs))
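
Since `fit` accepts a list of param maps and `_fitInParallel` trains one model per Spark task, a single estimator can sweep several hyperparameter settings at once. A hedged usage sketch, reusing `image_dataset` and `load_image_and_process` from the class docstring and assuming `kerasFitParams` is exposed as a Param attribute by the `HasKerasModel` mixin (as the constructor signature suggests):

```python
# Usage sketch (assumptions noted above): sweep two batch sizes in parallel.
estimator = KerasImageFileEstimator(inputCol="imageUri",
                                    outputCol="name_of_result_column",
                                    labelCol="categoryVec",
                                    imageLoader=load_image_and_process,
                                    kerasOptimizer="adam",
                                    kerasLoss="categorical_crossentropy",
                                    modelFile="path_to_my_model.h5")

# One param map per setting; each map is fitted in its own Spark task.
param_maps = [
    {estimator.kerasFitParams: {"epochs": 5, "batch_size": 32}},
    {estimator.kerasFitParams: {"epochs": 5, "batch_size": 64}},
]
fitted = estimator.fit(image_dataset, params=param_maps)

# _collectModels returns one dict per param map, pairing it with the
# KerasImageFileTransformer built from the trained model bytes.
for entry in fitted:
    print(entry['paramMap'], entry['transformer'])
```

Note the warning in `fit`: every trained model's HDF5 bytes are collected back to the driver, so the number of param maps per job should stay modest.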

python/sparkdl/param/__init__.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from sparkdl.param.shared_params import (
    keyword_only, HasInputCol, HasOutputCol, HasLabelCol, HasKerasModel,
    HasKerasLoss, HasKerasOptimizer, HasOutputNodeName, SparkDLTypeConverters)
from sparkdl.param.image_params import (
    CanLoadImage, HasInputImageNodeName, HasOutputMode, OUTPUT_MODES)
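
These re-exports gather the shared Param mixins that `KerasImageFileEstimator` above composes. A minimal sketch (not part of this commit) of how a custom Transformer would wire the same mixins together, following the `@keyword_only` / `_input_kwargs` pattern used by the estimator; `IdentityTransformer` is a hypothetical name:

```python
from pyspark.ml import Transformer
from sparkdl.param import keyword_only, HasInputCol, HasOutputCol

class IdentityTransformer(Transformer, HasInputCol, HasOutputCol):
    """Copies inputCol to outputCol; exists only to show the mixin wiring."""

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(IdentityTransformer, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        # getInputCol/getOutputCol come from the mixins' Param accessors.
        return dataset.withColumn(self.getOutputCol(), dataset[self.getInputCol()])
```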
