Skip to content

Commit 16415f7

Browse files
yogeshgsmurching
authored and committed
Add style checks and refactor suggestions (#121)
* [ML-3487] add accepted and suggested pylint rc files * space, comments, messages * refactor: rename or ignore invalid-names; remove unused import * remove unused type * rename globals * ignore a variable name * bugfix: variables named without updating in error message * space, comment, alignment changes * pylint disable protected-access * ignore style case wise * bugfix * pylint snake_case and camelCase attribute names allowed * ignore style case wise * add spaces, indent, expressions to simplify reading * pylint disable what can be kept * if len if -> if any; imports, ignore importing issues * add spaces, indent, group imports * requires refactoring * Undefined variable name 'imageType' in __all__ * add spaces, indent, group imports * add and ignore todos * no len as comparison, indent, simple expressions * add spaces, indent * no-else-return seems weird on local machine, disabling * group imports, ignore g, op names * remove trailing white spaces * ignore import error, add refactor todo, lazy logging * fix imports, disable some * use output of _validateParam, ignore too few in ThreadSafeIterator * optimize imports, ignore fixme, no-self-use * optimize imports * optimize imports, space * fix cyclic import (sparkdl.param -> sparkdl.param.image_params) * optimize imports * fixmes in code are acceptable; throw away value of validate_params; sz -> resize_image
1 parent a36b705 commit 16415f7

23 files changed

+1370
-212
lines changed

python/.pylint/accepted.rc

Lines changed: 557 additions & 0 deletions
Large diffs are not rendered by default.

python/.pylint/suggested.rc

Lines changed: 547 additions & 0 deletions
Large diffs are not rendered by default.

python/sparkdl/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@
2222
from .estimators.keras_image_file_estimator import KerasImageFileEstimator
2323

2424
__all__ = [
25-
'imageType', 'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
26-
'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer', 'KerasTransformer',
25+
'TFImageTransformer', 'TFInputGraph', 'TFTransformer', 'DeepImagePredictor',
26+
'DeepImageFeaturizer', 'KerasImageFileTransformer', 'KerasTransformer',
2727
'imageInputPlaceholder', 'KerasImageFileEstimator']

python/sparkdl/estimators/keras_image_file_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
__all__ = ['KerasImageFileEstimator']
3535

3636

37+
# pylint: disable=too-few-public-methods
3738
class _ThreadSafeIterator(object):
3839
"""
3940
Utility iterator class used by KerasImageFileEstimator.fitMultiple to serve models in a thread
@@ -263,7 +264,7 @@ def fitMultiple(self, dataset, paramMaps):
263264
existence of a sufficiently large (and writable) file system, users are
264265
advised to not train too many models in a single Spark job.
265266
"""
266-
[self._validateParams(pm) for pm in paramMaps]
267+
_ = [self._validateParams(pm) for pm in paramMaps]
267268

268269
def _get_tunable_name_value_map(param_map, tunable):
269270
"""takes a dictionary {`Param` -> value} and a list [`Param`], select keys that are

python/sparkdl/graph/builder.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
import logging
1717
import os
1818
import shutil
19-
import six
2019
from tempfile import mkdtemp
2120

2221
import keras.backend as K
2322
from keras.models import Model as KerasModel, load_model
23+
import six
2424
import tensorflow as tf
2525

2626
import sparkdl.graph.utils as tfx
@@ -83,15 +83,16 @@ def asGraphFunction(self, inputs, outputs, strip_and_freeze=True):
8383
8484
:param inputs: list, graph elements representing the inputs
8585
:param outputs: list, graph elements representing the outputs
86-
:param strip_and_freeze: bool, should we remove unused part of the graph and freee its values
86+
:param strip_and_freeze: bool, should we remove unused part of the graph and freeze its
87+
values
8788
"""
8889
if strip_and_freeze:
8990
gdef = tfx.strip_and_freeze_until(outputs, self.graph, self.sess)
9091
else:
9192
gdef = self.graph.as_graph_def(add_shapes=True)
92-
return GraphFunction(graph_def=gdef,
93-
input_names=[tfx.validated_input(elem, self.graph) for elem in inputs],
94-
output_names=[tfx.validated_output(elem, self.graph) for elem in outputs])
93+
input_names = [tfx.validated_input(elem, self.graph) for elem in inputs]
94+
output_names = [tfx.validated_output(elem, self.graph) for elem in outputs]
95+
return GraphFunction(graph_def=gdef, input_names=input_names, output_names=output_names)
9596

9697
def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs):
9798
"""
@@ -100,9 +101,11 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_k
100101
101102
.. _a link: https://www.tensorflow.org/api_docs/python/tf/import_graph_def
102103
103-
:param gfn: GraphFunction, an object representing a TensorFlow graph and its inputs and outputs
104+
:param gfn: GraphFunction, an object representing a TensorFlow graph and its inputs and
105+
outputs
104106
:param input_map: dict, mapping from input names to existing graph elements
105-
:param prefix: str, the scope for all the variables in the :py:class:`GraphFunction` elements
107+
:param prefix: str, the scope for all the variables in the :py:class:`GraphFunction`
108+
elements
106109
107110
.. _a link: https://www.tensorflow.org/programmers_guide/variable_scope
108111
@@ -119,13 +122,11 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_k
119122
input_names = gfn.input_names
120123
output_names = gfn.output_names
121124
scope_name = prefix
122-
if prefix is not None:
125+
if prefix:
123126
scope_name = prefix.strip()
124-
if len(scope_name) > 0:
125-
output_names = [
126-
scope_name + '/' + op_name for op_name in gfn.output_names]
127-
input_names = [
128-
scope_name + '/' + op_name for op_name in gfn.input_names]
127+
if scope_name:
128+
output_names = [scope_name + '/' + op_name for op_name in gfn.output_names]
129+
input_names = [scope_name + '/' + op_name for op_name in gfn.input_names]
129130

130131
# When importing, provide the original output op names
131132
tf.import_graph_def(gfn.graph_def,
@@ -142,7 +143,8 @@ class GraphFunction(object):
142143
"""
143144
Represent a TensorFlow graph with its GraphDef, input and output operation names.
144145
145-
:param graph_def: GraphDef, a static ProtocolBuffer object holding informations of a TensorFlow graph
146+
:param graph_def: GraphDef, a static ProtocolBuffer object holding information of a
147+
TensorFlow graph
146148
:param input_names: names to the input graph elements (must be of Placeholder type)
147149
:param output_names: names to the output graph elements
148150
"""
@@ -179,7 +181,8 @@ def fromKeras(cls, model_or_file_path):
179181
"""
180182
Build a GraphFunction from a Keras model
181183
182-
:param model_or_file_path: KerasModel or str, either a Keras model or the file path name to one
184+
:param model_or_file_path: KerasModel or str, either a Keras model or the file path name
185+
to one
183186
"""
184187
if isinstance(model_or_file_path, KerasModel):
185188
model = model_or_file_path
@@ -214,7 +217,7 @@ def fromList(cls, functions):
214217
:param functions: a list of tuples (scope name, GraphFunction object).
215218
"""
216219
assert len(functions) >= 1, ("must provide at least one function", functions)
217-
if 1 == len(functions):
220+
if len(functions) == 1:
218221
return functions[0]
219222
# Check against each intermediary layer input output function pairs
220223
for (scope_in, gfn_in), (scope_out, gfn_out) in zip(functions[:-1], functions[1:]):
@@ -252,7 +255,8 @@ def fromList(cls, functions):
252255

253256
for idx, (scope, gfn) in enumerate(functions):
254257
# Give a scope to each function to avoid name conflict
255-
if scope is None or len(scope.strip()) == 0:
258+
if scope is None or len(scope.strip()) == 0: # pylint: disable=len-as-condition
259+
# TODO: refactor above and test: if not (scope and scope.strip())
256260
scope = 'GFN-BLK-{}'.format(idx)
257261
_msg = 'merge: stage {}, scope {}'.format(idx, scope)
258262
logger.info(_msg)

python/sparkdl/graph/input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class TFInputGraph(object):
7777
inference, i.e. the variables are converted to constants and operations like
7878
BatchNormalization_ are converted to be independent of input batch.
7979
80-
.. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
80+
.. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
8181
8282
:param input_tensor_name_from_signature: dict, signature key names mapped to tensor names.
8383
Please see the example above.

python/sparkdl/graph/pieces.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def buildSpImageConverter(channelOrder, img_dtype):
5555
elif img_dtype == 'float32':
5656
image_float = tf.decode_raw(image_buffer, tf.float32, name="decode_raw")
5757
else:
58-
raise ValueError(
59-
'unsupported image data type "%s", currently only know how to handle uint8 and float32' % img_dtype)
58+
raise ValueError('''unsupported image data type "%s", currently only know how to
59+
handle uint8 and float32''' % img_dtype)
6060
image_reshaped = tf.reshape(image_float, shape, name="reshaped")
6161
image_reshaped = imageIO.fixColorChannelOrdering(channelOrder, image_reshaped)
6262
image_input = tf.expand_dims(image_reshaped, 0, name="image_input")

python/sparkdl/graph/tensorframes_udf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import logging
1818

19-
import tensorframes as tfs
19+
import tensorframes as tfs # pylint: disable=import-error
2020

2121
import sparkdl.graph.utils as tfx
2222
from sparkdl.utils import jvmapi as JVMAPI
@@ -85,6 +85,8 @@ def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=Fal
8585
placeholder_names = []
8686
placeholder_shapes = []
8787
for node in graph.as_graph_def(add_shapes=True).node:
88+
# pylint: disable=len-as-condition
89+
# todo: refactor if not(node.input) and ...
8890
if len(node.input) == 0 and str(node.op) == 'Placeholder':
8991
tnsr_name = tfx.tensor_name(node.name, graph)
9092
tnsr = graph.get_tensor_by_name(tnsr_name)

python/sparkdl/graph/utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
#
1616

1717
import logging
18-
import six
1918

19+
import six
2020
import tensorflow as tf
2121

2222
logger = logging.getLogger('sparkdl')
@@ -74,7 +74,7 @@ def get_op(tfobj_or_name, graph):
7474
if not isinstance(name, six.string_types):
7575
raise TypeError('invalid op request for [type {}] {}'.format(type(name), name))
7676
_op_name = op_name(name, graph=None)
77-
op = graph.get_operation_by_name(_op_name)
77+
op = graph.get_operation_by_name(_op_name) # pylint: disable=invalid-name
7878
err_msg = 'cannot locate op {} in the current graph, got [type {}] {}'
7979
assert isinstance(op, tf.Operation), err_msg.format(_op_name, type(op), op)
8080
return op
@@ -190,9 +190,8 @@ def validated_input(tfobj_or_name, graph):
190190
"""
191191
graph = validated_graph(graph)
192192
name = op_name(tfobj_or_name, graph)
193-
op = graph.get_operation_by_name(name)
194-
assert 'Placeholder' == op.type, \
195-
('input must be Placeholder, but get', op.type)
193+
op = graph.get_operation_by_name(name) # pylint: disable=invalid-name
194+
assert 'Placeholder' == op.type, ('input must be Placeholder, but get', op.type)
196195
return name
197196

198197

@@ -223,7 +222,7 @@ def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False):
223222
sess.close()
224223

225224
if return_graph:
226-
g = tf.Graph()
225+
g = tf.Graph() # pylint: disable=invalid-name
227226
with g.as_default():
228227
tf.import_graph_def(gdef_frozen, name='')
229228
return g

python/sparkdl/image/imageIO.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
from pyspark import SparkContext
2626
from pyspark.ml.image import ImageSchema
2727
from pyspark.sql.functions import udf
28-
from pyspark.sql.types import (
29-
BinaryType, IntegerType, StringType, StructField, StructType)
28+
from pyspark.sql.types import BinaryType, StringType, StructField, StructType
3029

3130

3231
# ImageType represents supported OpenCV types
@@ -39,8 +38,7 @@
3938
# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
4039
_OcvType = namedtuple("OcvType", ["name", "ord", "nChannels", "dtype"])
4140

42-
43-
_supportedOcvTypes = (
41+
_SUPPORTED_OCV_TYPES = (
4442
_OcvType(name="CV_8UC1", ord=0, nChannels=1, dtype="uint8"),
4543
_OcvType(name="CV_32FC1", ord=5, nChannels=1, dtype="float32"),
4644
_OcvType(name="CV_8UC3", ord=16, nChannels=3, dtype="uint8"),
@@ -50,22 +48,22 @@
5048
)
5149

5250
# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
53-
_ocvTypesByName = {m.name: m for m in _supportedOcvTypes}
54-
_ocvTypesByOrdinal = {m.ord: m for m in _supportedOcvTypes}
51+
_OCV_TYPES_BY_NAME = {m.name: m for m in _SUPPORTED_OCV_TYPES}
52+
_OCV_TYPES_BY_ORDINAL = {m.ord: m for m in _SUPPORTED_OCV_TYPES}
5553

5654

57-
def imageTypeByOrdinal(ord):
58-
if not ord in _ocvTypesByOrdinal:
55+
def imageTypeByOrdinal(ordinal):
56+
if not ordinal in _OCV_TYPES_BY_ORDINAL:
5957
raise KeyError("unsupported image type with ordinal %d, supported OpenCV types = %s" % (
60-
ord, str(_supportedOcvTypes)))
61-
return _ocvTypesByOrdinal[ord]
58+
ordinal, str(_SUPPORTED_OCV_TYPES)))
59+
return _OCV_TYPES_BY_ORDINAL[ordinal]
6260

6361

6462
def imageTypeByName(name):
65-
if not name in _ocvTypesByName:
66-
raise KeyError("unsupported image type with name '%s', supported supported OpenCV types = %s" % (
67-
name, str(_supportedOcvTypes)))
68-
return _ocvTypesByName[name]
63+
if not name in _OCV_TYPES_BY_NAME:
64+
raise KeyError("unsupported image type with name '%s', supported OpenCV types = %s" % (
65+
name, str(_SUPPORTED_OCV_TYPES)))
66+
return _OCV_TYPES_BY_NAME[name]
6967

7068

7169
def imageArrayToStruct(imgArray, origin=""):
@@ -151,13 +149,13 @@ def fixColorChannelOrdering(currentOrder, imgAry):
151149
elif currentOrder == 'BGR':
152150
return imgAry
153151
elif currentOrder == 'L':
154-
if len(img.shape) != 1:
152+
if len(imgAry.shape) != 1:
155153
raise ValueError(
156-
"channel order suggests only one color channel but got shape " + str(img.shape))
154+
"channel order suggests only one color channel but got shape " + str(imgAry.shape))
157155
return imgAry
158156
else:
159157
raise ValueError(
160-
"Unexpected channel order, expected one of L,RGB,BGR but got " + currentChannelOrder)
158+
"Unexpected channel order, expected one of L,RGB,BGR but got " + currentOrder)
161159

162160

163161
def _reverseChannels(ary):
@@ -176,12 +174,12 @@ def createResizeImageUDF(size):
176174
if len(size) != 2:
177175
raise ValueError(
178176
"New image size should have format [height, width] but got {}".format(size))
179-
sz = (size[1], size[0])
177+
resize_sizes = (size[1], size[0])
180178

181179
def _resizeImageAsRow(imgAsRow):
182-
if (imgAsRow.height, imgAsRow.width) == sz:
180+
if (imgAsRow.height, imgAsRow.width) == resize_sizes:
183181
return imgAsRow
184-
imgAsPil = imageStructToPIL(imgAsRow).resize(sz)
182+
imgAsPil = imageStructToPIL(imgAsRow).resize(resize_sizes)
185183
# PIL is RGB based while image schema is BGR based => we need to flip the channels
186184
imgAsArray = _reverseChannels(np.asarray(imgAsPil))
187185
return imageArrayToStruct(imgAsArray, origin=imgAsRow.origin)
@@ -228,11 +226,12 @@ def _decode(raw_bytes):
228226

229227
def readImagesWithCustomFn(path, decode_f, numPartition=None):
230228
"""
231-
Read a directory of images (or a single image) into a DataFrame using a custom library to decode the images.
229+
Read a directory of images (or a single image) into a DataFrame using a custom library to
230+
decode the images.
232231
233232
:param path: str, file path.
234-
:param decode_f: function to decode the raw bytes into an array compatible with one of the supported OpenCv modes.
235-
see @imageIO.PIL_decode for an example.
233+
:param decode_f: function to decode the raw bytes into an array compatible with one of the
234+
supported OpenCv modes. see @imageIO.PIL_decode for an example.
236235
:param numPartition: [optional] int, number or partitions to use for reading files.
237236
:return: DataFrame with schema == ImageSchema.imageSchema.
238237
"""

0 commit comments

Comments
 (0)