Skip to content
This repository was archived by the owner on Jul 31, 2023. It is now read-only.

Commit fd18181

Browse files
klmilam authored and mbernico committed
Use proper capitalization for Dataflow, add CLI for beam_pipeline_test
Change-Id: I10253fa57237eac6394bb13aa0920fbb804d423e
1 parent 2e68669 commit fd18181

File tree

4 files changed

+41
-37
lines changed

4 files changed

+41
-37
lines changed

tfrutil/beam_pipeline.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
"""TFRUtil Beam Pipeline.
1818
19-
This file implements the full beam pipeline for TFRUtil.
19+
This file implements the full Beam pipeline for TFRUtil.
2020
"""
2121

2222
from typing import Any, Dict, Generator, Union
@@ -94,7 +94,7 @@ def _get_pipeline_options(
9494
options_dict['project'] = project
9595
if region:
9696
options_dict['region'] = region
97-
if runner == 'DataFlowRunner':
97+
if runner == 'DataflowRunner':
9898
options_dict['setup_file'] = _get_setup_py_filepath()
9999
if dataflow_options:
100100
options_dict.update(dataflow_options)
@@ -204,15 +204,15 @@ def build_pipeline(
204204
"""Runs TFRUtil Beam Pipeline.
205205
206206
Args:
207-
df: Pandas Dataframe
207+
df: Pandas DataFrame
208208
job_label: User description for the beam job.
209-
runner: Beam Runner: (e.g. DataFlowRunner, DirectRunner).
210-
project: GCP project ID (if DataFlowRunner)
211-
region: GCP compute region (if DataFlowRunner)
209+
runner: Beam Runner: (e.g. DataflowRunner, DirectRunner).
210+
project: GCP project ID (if DataflowRunner)
211+
region: GCP compute region (if DataflowRunner)
212212
output_dir: GCS or Local Path for output.
213213
compression: gzip or None.
214214
num_shards: Number of shards.
215-
dataflow_options: DataFlow Runner Options (optional)
215+
dataflow_options: Dataflow Runner Options (optional)
216216
integer_label: Flags if label is already an integer.
217217
218218
Returns:

tfrutil/beam_pipeline_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,7 @@ def test_get_setup_py_filepath(self):
8383
filepath = beam_pipeline._get_setup_py_filepath()
8484
self.assertTrue(os.path.isfile(filepath))
8585
self.assertTrue(os.path.isabs(filepath))
86+
87+
88+
if __name__ == '__main__':
89+
unittest.main()

tfrutil/client.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,20 @@
3333

3434

3535
def _validate_data(df):
36-
""" Verify required image csv columsn exist in data."""
36+
""" Verifies required image csv columsn exist in data."""
3737
if constants.IMAGE_URI_KEY not in df.columns:
3838
# or label_col not in df.columns:
3939
raise AttributeError(
40-
'Dataframe must contain image_uri column {}.')
40+
'DataFrame must contain image_uri column {}.')
4141
if constants.LABEL_KEY not in df.columns:
4242
raise AttributeError(
43-
'Dataframe must contain label column.')
43+
'DataFrame must contain label column.')
4444
if constants.SPLIT_KEY not in df.columns:
4545
raise AttributeError(
46-
'Dataframe must contain split column.')
46+
'DataFrame must contain split column.')
4747
if list(df.columns) != constants.IMAGE_CSV_COLUMNS:
4848
raise AttributeError(
49-
'Dataframe column order must be {}'.format(
49+
'DataFrame column order must be {}'.format(
5050
constants.IMAGE_CSV_COLUMNS))
5151

5252

@@ -56,18 +56,18 @@ def _validate_runner(
5656
project: str,
5757
region: str):
5858
"""Validates an appropriate beam runner is chosen."""
59-
if runner not in ['DataFlowRunner', 'DirectRunner']:
59+
if runner not in ['DataflowRunner', 'DirectRunner']:
6060
raise AttributeError('Runner {} is not supported.'.format(runner))
6161

6262
# gcs_path is a bool, true if all image paths start with gs://
6363
gcs_path = df[constants.IMAGE_URI_KEY].str.startswith('gs://').all()
64-
if (runner == 'DataFlowRunner') & (not gcs_path):
65-
raise AttributeError('DataFlowRunner requires GCS image locations.')
64+
if (runner == 'DataflowRunner') & (not gcs_path):
65+
raise AttributeError('DataflowRunner requires GCS image locations.')
6666

67-
if (runner == 'DataFlowRunner') & (
67+
if (runner == 'DataflowRunner') & (
6868
any(not v for v in [project, region])):
6969
raise AttributeError(
70-
'DataFlowRunner requires valid `project` and `region` to be specified.'
70+
'DataflowRunner requires valid `project` and `region` to be specified.'
7171
'The `project` is {} and `region` is {}'.format(project, region))
7272

7373
# def read_image_directory(dirpath) -> pd.DataFrame:
@@ -188,19 +188,19 @@ def create_tfrecords(
188188
a Pandas DataFrame.
189189
If 'infer' (default), header is taken from the first line of a CSV
190190
runner: Beam runner. Can be 'DirectRunner' or 'DataFlowRunner'
191-
project: GCP project name (Required if DataFlowRunner)
192-
region: GCP region name (Required if DataFlowRunner)
193-
dataflow_options: Options dict for dataflow runner
194-
job_label: User supplied description for the beam job name.
191+
project: GCP project name (Required if DataflowRunner)
192+
region: GCP region name (Required if DataflowRunner)
193+
dataflow_options: Options dict for DataflowRunner
194+
job_label: User supplied description for the Beam job name.
195195
compression: Can be 'gzip' or None for no compression.
196196
num_shards: Number of shards to divide the TFRecords into. Default is
197197
0 = no sharding.
198198
199199
Returns:
200200
job_results: Dict
201-
job_id: DataFlow Job ID or 'DirectRunner'
201+
job_id: Dataflow Job ID or 'DirectRunner'
202202
metrics: (optional) Beam metrics. Only used for DirectRunner
203-
dataflow_url: (optional) Job URL for DataFlowRunner
203+
dataflow_url: (optional) Job URL for DataflowRunner
204204
"""
205205

206206
df = to_dataframe(input_data, header, names)
@@ -253,8 +253,8 @@ def create_tfrecords(
253253
logging.info("Job Complete.")
254254

255255
else:
256-
logging.info("Using DataFlow Runner.")
257-
# Construct DataFlow URL
256+
logging.info("Using Dataflow Runner.")
257+
# Construct Dataflow URL
258258

259259
job_id = result.job_id()
260260

@@ -272,8 +272,8 @@ def create_tfrecords(
272272

273273
logging.shutdown()
274274

275-
if runner == 'DataFlowRunner':
276-
# if this is a dataflow job, copy the logfile to gcs
275+
if runner == 'DataflowRunner':
276+
# if this is a Dataflow job, copy the logfile to GCS
277277
common.copy_logfile_to_gcs(logfile, output_dir)
278278

279279
return job_result

tfrutil/client_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_create_tfrecords_direct_runner(self, mock_beam):
5151

5252
@mock.patch('tfrutil.client.beam_pipeline')
5353
def test_create_tfrecords_dataflow_runner(self, mock_beam):
54-
"""Tests `create_tfrecords` DataFlow case."""
54+
"""Tests `create_tfrecords` Dataflow case."""
5555
mock_beam.build_pipeline().run().job_id.return_value = 'foo_id'
5656

5757
df2 = self.test_df.copy()
@@ -67,7 +67,7 @@ def test_create_tfrecords_dataflow_runner(self, mock_beam):
6767
os.makedirs(outdir, exist_ok=True)
6868
r = client.create_tfrecords(
6969
df2,
70-
runner='DataFlowRunner',
70+
runner='DataflowRunner',
7171
output_dir=outdir,
7272
region=self.test_region,
7373
project=self.test_project)
@@ -112,7 +112,7 @@ def test_missing_split(self):
112112
client._validate_data(df2)
113113

114114
def test_columns_out_of_order(self):
115-
"""Tests validating column order wrong."""
115+
"""Tests validating wrong column order."""
116116
with self.assertRaises(AttributeError):
117117
df2 = self.test_df.copy()
118118
cols = ['image_uri', 'split', 'label']
@@ -137,38 +137,38 @@ def test_invalid_runner(self):
137137
region=self.test_region)
138138

139139
def test_local_path_with_dataflow_runner(self):
140-
"""Tests DataFlowRunner conflict with local path."""
140+
"""Tests DataflowRunner conflict with local path."""
141141
with self.assertRaises(AttributeError):
142142
client._validate_runner(
143143
self.df_test,
144-
runner='DataFlowRunner',
144+
runner='DataflowRunner',
145145
project=self.test_project,
146146
region=self.test_region)
147147

148148
def test_gcs_path_with_dataflow_runner(self):
149-
"""Tests DataFlowRunner with gcs path."""
149+
"""Tests DataflowRunner with GCS path."""
150150
df2 = self.test_df.copy()
151151
df2[constants.IMAGE_URI_KEY] = 'gs://' + df2[constants.IMAGE_URI_KEY]
152152
self.assertIsNone(
153153
client._validate_runner(
154154
df2,
155-
runner='DataFlowRunner',
155+
runner='DataflowRunner',
156156
project=self.test_project,
157157
region=self.test_region))
158158

159159
def test_gcs_path_with_dataflow_runner_missing_param(self):
160-
"""Tests DataFlowRunner with missing required parameter."""
160+
"""Tests DataflowRunner with missing required parameter."""
161161
df2 = self.test_df.copy()
162162
df2[constants.IMAGE_URI_KEY] = 'gs://' + df2[constants.IMAGE_URI_KEY]
163163
for p, r in [
164164
(None, self.test_region), (self.test_project, None), (None, None)]:
165165
with self.assertRaises(AttributeError) as context:
166166
client._validate_runner(
167167
df2,
168-
runner='DataFlowRunner',
168+
runner='DataflowRunner',
169169
project=p,
170170
region=r)
171-
self.assertTrue('DataFlowRunner requires valid `project` and `region`'
171+
self.assertTrue('DataflowRunner requires valid `project` and `region`'
172172
in repr(context.exception))
173173

174174

0 commit comments

Comments (0)