Skip to content
This repository was archived by the owner on Jul 31, 2023. It is now read-only.

Commit 0bf9263

Browse files
committed
Code cleanup: convert " -> '; #pylint -> # pylint
Change-Id: I84654365ff986df294419ca45952c3969608e607
1 parent 684d2db commit 0bf9263

File tree

11 files changed

+156
-151
lines changed

11 files changed

+156
-151
lines changed

tfrutil/accessor.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from tfrutil import client
2727

2828

29-
@pd.api.extensions.register_dataframe_accessor("tensorflow")
29+
@pd.api.extensions.register_dataframe_accessor('tensorflow')
3030
class TFRUtilAccessor:
3131
"""DataFrame Accessor class for TFRUtil."""
3232

@@ -37,13 +37,13 @@ def __init__(self, pandas_obj):
3737
def to_tfr(
3838
self,
3939
output_dir: str,
40-
runner: str = "DirectRunner",
40+
runner: str = 'DirectRunner',
4141
project: Optional[str] = None,
4242
region: Optional[str] = None,
4343
tfrutil_path: Optional[str] = None,
4444
dataflow_options: Union[Dict[str, Any], None] = None,
45-
job_label: str = "to-tfr",
46-
compression: Optional[str] = "gzip",
45+
job_label: str = 'to-tfr',
46+
compression: Optional[str] = 'gzip',
4747
num_shards: int = 0):
4848
"""TFRUtil Pandas Accessor.
4949
@@ -54,9 +54,9 @@ def to_tfr(
5454
import tfrutil
5555
5656
df.tensorflow.to_tfr(
57-
output_dir="gcs://foo/bar/train",
58-
runner="DirectRunner",
59-
compression="gzip",
57+
output_dir='gs://foo/bar/train',
58+
runner='DirectRunner',
59+
compression='gzip',
6060
num_shards=10)
6161
6262
Args:
@@ -67,10 +67,9 @@ def to_tfr(
6767
tfrutil_path: Path to tfrutil source (required if runner is DataflowRunner).
6868
dataflow_options: Optional dictionary containing DataFlow options.
6969
job_label: User supplied description for the beam job name.
70-
compression: Can be "gzip" or None for no compression.
70+
compression: Can be 'gzip' or None for no compression.
7171
num_shards: Number of shards to divide the TFRecords into. Default is
7272
0 = no sharding.
73-
7473
"""
7574
client.create_tfrecords(
7675
self._df,

tfrutil/accessor_test.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
"""Tests for pandas accessor."""
1818

19+
import os
1920
import unittest
2021

2122
# pylint: disable=unused-import
22-
from tfrutil import accessor
23-
from tfrutil import constants
23+
24+
import tfrutil
2425
from tfrutil import test_utils
2526

2627

@@ -29,12 +30,14 @@ class DataFrameAccessor(unittest.TestCase):
2930

3031
def setUp(self):
3132
self.test_df = test_utils.get_test_df()
33+
self.output_dir = '/tmp/train'
34+
os.makedirs(self.output_dir, exist_ok=True)
3235

3336
def test_accessor(self):
3437
"""Tests pandas accessor."""
3538

36-
self.assertIsNone(self.test_df.tensorflow.to_tfr(runner="DirectRunner",
37-
output_dir="/tmp/train"))
39+
self.assertIsNone(self.test_df.tensorflow.to_tfr(
40+
runner='DirectRunner', output_dir=self.output_dir))
3841

39-
if __name__ == "__main__":
42+
if __name__ == '__main__':
4043
unittest.main()

tfrutil/beam_image.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ def load(image_uri):
6161
"""Loads an image."""
6262

6363
try:
64-
with tf.io.gfile.GFile(image_uri, "rb") as f:
64+
with tf.io.gfile.GFile(image_uri, 'rb') as f:
6565
return Image.open(f)
6666
except tf.python.framework.errors_impl.NotFoundError:
67-
raise OSError("File {} was not found.".format(image_uri))
67+
raise OSError('File {} was not found.'.format(image_uri))
6868

6969

7070
# pylint: disable=abstract-method
@@ -93,14 +93,14 @@ def process(
9393
try:
9494
image_uri = element[self.image_key]
9595
image = load(image_uri)
96-
d["image"] = encode(image)
97-
d["image_width"], d["image_height"] = image.size
98-
d["image_channels"] = mode_to_channel(image.mode)
96+
d['image'] = encode(image)
97+
d['image_width'], d['image_height'] = image.size
98+
d['image_channels'] = mode_to_channel(image.mode)
9999

100-
#pylint: disable=broad-except
100+
# pylint: disable=broad-except
101101
except Exception as e:
102-
logging.warning("Could not load image: %s", image_uri)
103-
logging.error("Exception was: %s", str(e))
102+
logging.warning('Could not load image: %s', image_uri)
103+
logging.error('Exception was: %s', str(e))
104104

105105
element.update(d)
106106
yield element

tfrutil/beam_image_test.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ class BeamImageTests(unittest.TestCase):
3737
def setUp(self):
3838
self.pipeline = test_utils.get_test_pipeline()
3939
self.df = test_utils.get_test_df()
40-
self.image_file = "tfrutil/test_data/images/cat/cat-640x853-1.jpg"
40+
self.image_file = 'tfrutil/test_data/images/cat/cat-640x853-1.jpg'
4141

4242
def test_load(self):
4343
"""Tests the image loading function."""
4444
img = beam_image.load(self.image_file)
4545
self.assertIsInstance(img, PIL.JpegImagePlugin.JpegImageFile)
4646

4747
def test_file_not_found_load(self):
48-
"""Test loading an image that doesn"t exist."""
48+
"""Test loading an image that doesn't exist."""
4949
with self.assertRaises(OSError):
50-
_ = beam_image.load("tfrutil/test_data/images/cat/food.jpg")
50+
_ = beam_image.load('tfrutil/test_data/images/cat/food.jpg')
5151

5252
def test_mode_to_channel(self):
5353
"""Tests `mode_to_channel`."""
@@ -91,11 +91,11 @@ def test_extract_image_dofn(self):
9191

9292
data = (
9393
p
94-
| "ReadFromDataFrame" >> beam.Create(self.df.values.tolist())
95-
| "FlattenDataFrame" >> beam.Map(
96-
lambda x: ",".join([str(item) for item in x]))
97-
| "DecodeCSV" >> beam.Map(converter.decode)
98-
| "ExtractImage" >> beam.ParDo(extract_images_fn)
94+
| 'ReadFromDataFrame' >> beam.Create(self.df.values.tolist())
95+
| 'FlattenDataFrame' >> beam.Map(
96+
lambda x: ','.join([str(item) for item in x]))
97+
| 'DecodeCSV' >> beam.Map(converter.decode)
98+
| 'ExtractImage' >> beam.ParDo(extract_images_fn)
9999
)
100100

101101
def key_matcher(expected_keys):
@@ -108,10 +108,10 @@ def _equal(actual):
108108
actual_keys = set(element.keys())
109109
if actual_keys != expected_keys_:
110110
raise util.BeamAssertException(
111-
"PCollection key match failed. Actual ({}) vs. expected ({})"
111+
'PCollection key match failed. Actual ({}) vs. expected ({})'
112112
.format(actual_keys, expected_keys_))
113113
return _equal
114114

115-
expected_keys = ["image_uri", "label", "split", "image",
116-
"image_height", "image_width", "image_channels"]
115+
expected_keys = ['image_uri', 'label', 'split', 'image',
116+
'image_height', 'image_width', 'image_channels']
117117
util.assert_that(data, key_matcher(expected_keys))

tfrutil/beam_pipeline.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ def _get_job_name(job_label: str = None) -> str:
4545
ensure uniqueness.
4646
"""
4747

48-
job_name = "tfrutil-" + common.get_timestamp()
48+
job_name = 'tfrutil-' + common.get_timestamp()
4949
if job_label:
50-
job_label = job_label.replace("_", "-")
51-
job_name += "-" + job_label
50+
job_label = job_label.replace('_', '-')
51+
job_name += '-' + job_label
5252

5353
return job_name
5454

@@ -71,13 +71,13 @@ def _get_pipeline_options(
7171
"""Returns Beam pipeline options."""
7272

7373
options_dict = {
74-
"runner": runner,
75-
"staging_location": os.path.join(job_dir, "staging"),
76-
"temp_location": os.path.join(job_dir, "tmp"),
77-
"job_name": job_name,
78-
"teardown_policy": "TEARDOWN_ALWAYS",
79-
"save_main_session": True,
80-
"pipeline_type_check": False,
74+
'runner': runner,
75+
'staging_location': os.path.join(job_dir, 'staging'),
76+
'temp_location': os.path.join(job_dir, 'tmp'),
77+
'job_name': job_name,
78+
'teardown_policy': 'TEARDOWN_ALWAYS',
79+
'save_main_session': True,
80+
'pipeline_type_check': False,
8181
}
8282

8383
if project:
@@ -97,7 +97,7 @@ def _partition_fn(
9797
unused_num_partitions: int = -1) -> int:
9898
"""Returns index used to partition an element from a PCollection."""
9999
del unused_num_partitions
100-
dataset_type = element[constants.SPLIT_KEY].decode("utf-8")
100+
dataset_type = element[constants.SPLIT_KEY].decode('utf-8')
101101
try:
102102
index = constants.SPLIT_VALUES.index(dataset_type)
103103
except ValueError as e:
@@ -198,7 +198,7 @@ def build_pipeline(
198198

199199
#with beam.Pipeline(runner, options=options) as p:
200200
p = beam.Pipeline(options=options)
201-
with tft_beam.Context(temp_dir=os.path.join(job_dir, "tft_tmp")):
201+
with tft_beam.Context(temp_dir=os.path.join(job_dir, 'tft_tmp')):
202202

203203
converter = tft.coders.CsvCoder(constants.IMAGE_CSV_COLUMNS,
204204
constants.IMAGE_CSV_METADATA.schema)
@@ -210,11 +210,11 @@ def build_pipeline(
210210
# extract_images_fn.
211211
image_csv_data = (
212212
p
213-
| "ReadFromDataFrame" >> beam.Create(df.values.tolist())
214-
| "ToCSVRows" >> beam.Map(
215-
lambda x: ",".join([str(item) for item in x]))
216-
| "DecodeCSV" >> beam.Map(converter.decode)
217-
| "ReadImage" >> beam.ParDo(extract_images_fn)
213+
| 'ReadFromDataFrame' >> beam.Create(df.values.tolist())
214+
| 'ToCSVRows' >> beam.Map(
215+
lambda x: ','.join([str(item) for item in x]))
216+
| 'DecodeCSV' >> beam.Map(converter.decode)
217+
| 'ReadImage' >> beam.ParDo(extract_images_fn)
218218
)
219219

220220
# Split dataset into train and validation.
@@ -259,22 +259,22 @@ def build_pipeline(
259259
_ = (
260260
transformed_train_data
261261
| 'EncodeTrainData' >> beam.Map(transformed_data_coder.encode)
262-
| 'WriteTrainData' >> tfr_writer(prefix="train"))
262+
| 'WriteTrainData' >> tfr_writer(prefix='train'))
263263

264264
_ = (
265265
transformed_val_data
266266
| 'EncodeValData' >> beam.Map(transformed_data_coder.encode)
267-
| 'WriteValData' >> tfr_writer(prefix="val"))
267+
| 'WriteValData' >> tfr_writer(prefix='val'))
268268

269269
_ = (
270270
transformed_test_data
271271
| 'EncodeTestData' >> beam.Map(transformed_data_coder.encode)
272-
| 'WriteTestData' >> tfr_writer(prefix="test"))
272+
| 'WriteTestData' >> tfr_writer(prefix='test'))
273273

274274
_ = (
275275
discard_data
276276
| 'DiscardDataWriter' >> beam.io.WriteToText(
277-
os.path.join(job_dir, "discarded-data")))
277+
os.path.join(job_dir, 'discarded-data')))
278278

279279
# Output transform function and metadata
280280
_ = (transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(

tfrutil/beam_pipeline_test.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,29 @@
2525
from tfrutil import beam_pipeline
2626

2727

28-
#pylint: disable=protected-access
28+
# pylint: disable=protected-access
2929

3030
class BeamPipelineTests(unittest.TestCase):
3131
"""Tests for beam_pipeline.py."""
3232

3333
def test_processing_fn_with_int_label(self):
34-
"Test preprocessing fn with integer label."
34+
'Test preprocessing fn with integer label.'
3535
element = {
36-
"split": "TRAIN",
37-
"image_uri": "gs://foo/bar.jpg",
38-
"label": 1}
36+
'split': 'TRAIN',
37+
'image_uri': 'gs://foo/bar.jpg',
38+
'label': 1}
3939
result = beam_pipeline._preprocessing_fn(element, integer_label=True)
4040
self.assertEqual(element, result)
4141

42-
@mock.patch("tfrutil.beam_pipeline.tft")
42+
@mock.patch('tfrutil.beam_pipeline.tft')
4343
def test_processing_fn_with_string_label(self, mock_transform):
44-
"Test preprocessing fn with string label."
44+
'Test preprocessing fn with string label.'
4545
mock_transform.compute_and_apply_vocabulary.return_value = tf.constant(
4646
0, dtype=tf.int64)
4747
element = {
48-
"split": "TRAIN",
49-
"image_uri": "gs://foo/bar.jpg",
50-
"label": tf.constant("cat", dtype=tf.string)}
48+
'split': 'TRAIN',
49+
'image_uri': 'gs://foo/bar.jpg',
50+
'label': tf.constant('cat', dtype=tf.string)}
5151
result = beam_pipeline._preprocessing_fn(element, integer_label=False)
5252
result['label'] = result['label'].numpy()
5353
self.assertEqual(0, result['label'])
@@ -65,14 +65,14 @@ def test_partition_fn(self):
6565
"""Test the partition function."""
6666

6767
test_data = {
68-
"split": "update_me",
69-
"image_uri": "gs://foo/bar0.jpg",
70-
"label": 1}
68+
'split': 'update_me',
69+
'image_uri': 'gs://foo/bar0.jpg',
70+
'label': 1}
7171

72-
for i, part in enumerate(["TRAIN", "VALIDATION", "TEST", "FOO"]):
73-
test_data['split'] = part.encode("utf-8")
72+
for i, part in enumerate(['TRAIN', 'VALIDATION', 'TEST', 'FOO']):
73+
test_data['split'] = part.encode('utf-8')
7474
index = beam_pipeline._partition_fn(test_data)
7575

7676
self.assertEqual(
7777
index, i,
78-
"{} should be index {} but was index {}".format(part, i, index))
78+
'{} should be index {} but was index {}'.format(part, i, index))

0 commit comments

Comments
 (0)