
Commit b035336

Removed tfrutil_path option
Change-Id: I3affb81cb5fe4d8af837f570cbeda8a4f7292d64
1 parent 0bf9263 commit b035336

File tree

4 files changed (+19, -34 lines)

tfrutil/accessor.py

Lines changed: 0 additions & 3 deletions

@@ -40,7 +40,6 @@ def to_tfr(
     runner: str = 'DirectRunner',
     project: Optional[str] = None,
     region: Optional[str] = None,
-    tfrutil_path: Optional[str] = None,
     dataflow_options: Union[Dict[str, Any], None] = None,
     job_label: str = 'to-tfr',
     compression: Optional[str] = 'gzip',
@@ -64,7 +63,6 @@ def to_tfr(
       runner: Beam runner. Can be DirectRunner or DataFlowRunner.
       project: GCP project name (Required if DataFlowRunner).
       region: GCP region name (Required if DataFlowRunner).
-      tfrutil_path: Path to tfrutil source (Required if DataFlowRunner).
       dataflow_options: Optional dictionary containing DataFlow options.
       job_label: User supplied description for the beam job name.
       compression: Can be 'gzip' or None for no compression.
@@ -77,7 +75,6 @@ def to_tfr(
         runner=runner,
         project=project,
         region=region,
-        tfrutil_path=tfrutil_path,
         dataflow_options=dataflow_options,
         job_label=job_label,
         compression=compression,
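
For orientation, a minimal usage sketch of the accessor after this change. The `.tensorflow` accessor namespace, the input CSV, and the output_dir value are assumptions made for illustration; only the absence of tfrutil_path reflects this commit.

# Hedged usage sketch (not part of the commit).
import pandas as pd

import tfrutil.accessor  # assumed to register the DataFrame accessor on import

df = pd.read_csv('gs://my-bucket/images.csv')  # hypothetical image manifest
df.tensorflow.to_tfr(
    output_dir='gs://my-bucket/tfrecords/',
    runner='DataFlowRunner',
    project='my-gcp-project',  # required when runner is DataFlowRunner
    region='us-central1',      # required when runner is DataFlowRunner
    compression='gzip')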

tfrutil/beam_pipeline.py

Lines changed: 2 additions & 6 deletions

@@ -65,7 +65,6 @@ def _get_pipeline_options(
     job_dir: str,
     project: str,
     region: str,
-    tfrutil_path: str,
     dataflow_options: Union[Dict[str, Any], None]
 ) -> beam.pipeline.PipelineOptions:
   """Returns Beam pipeline options."""
@@ -84,8 +83,8 @@ def _get_pipeline_options(
     options_dict['project'] = project
   if region:
     options_dict['region'] = region
-  if tfrutil_path:
-    options_dict['setup_file'] = os.path.join(tfrutil_path, 'setup.py')
+  if runner == 'DataFlowRunner':
+    options_dict['setup_file'] = os.path.join('..', 'setup.py')
   if dataflow_options:
     options_dict.update(dataflow_options)

@@ -158,7 +157,6 @@ def build_pipeline(
     runner: str,
     project: str,
     region: str,
-    tfrutil_path: str,
     output_dir: str,
     compression: str,
     num_shards: int,
@@ -172,7 +170,6 @@ def build_pipeline(
     runner: Beam Runner: (e.g. DataFlowRunner, DirectRunner).
     project: GCP project ID (if DataFlowRunner)
     region: GCP compute region (if DataFlowRunner)
-    tfrutil_path: Path for TFRUtil source (required for DataFlowRunner)
     output_dir: GCS or Local Path for output.
     compression: gzip or None.
     num_shards: Number of shards.
@@ -193,7 +190,6 @@ def build_pipeline(
       job_dir,
       project,
       region,
-      tfrutil_path,
       dataflow_options)

   #with beam.Pipeline(runner, options=options) as p:
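
The practical effect of the `_get_pipeline_options` change is that the Dataflow setup_file is now derived from the runner instead of a caller-supplied tfrutil_path. Below is a minimal standalone sketch of that logic; the function name and the project/region keys around it are assumptions, and only the setup_file branch mirrors this diff.

import os

from apache_beam.options.pipeline_options import PipelineOptions


def sketch_pipeline_options(runner, project, region, dataflow_options=None):
  """Builds Beam pipeline options roughly the way this commit leaves them."""
  options_dict = {}
  if project:
    options_dict['project'] = project
  if region:
    options_dict['region'] = region
  # setup.py is now located relative to the package rather than through a
  # user-supplied tfrutil_path argument.
  if runner == 'DataFlowRunner':
    options_dict['setup_file'] = os.path.join('..', 'setup.py')
  if dataflow_options:
    options_dict.update(dataflow_options)
  return PipelineOptions(**options_dict)

Note that '../setup.py' is a relative path, so this assumes the pipeline is launched from a working directory one level below setup.py.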

tfrutil/client.py

Lines changed: 6 additions & 11 deletions

@@ -53,8 +53,7 @@ def _validate_runner(
     df: pd.DataFrame,
     runner: str,
     project: str,
-    region: str,
-    tfrutil_path: str):
+    region: str):
   """Validates an appropriate beam runner is chosen."""
   if runner not in ['DataFlowRunner', 'DirectRunner']:
     raise AttributeError('Runner {} is not supported.'.format(runner))
@@ -65,11 +64,10 @@ def _validate_runner(
     raise AttributeError('DataFlowRunner requires GCS image locations.')

   if (runner == 'DataFlowRunner') & (
-      any(not v for v in [project, region, tfrutil_path])):
-    raise AttributeError('DataFlowRunner requires project region and '
-                         'tfrutil_path to be set however project is {} '
-                         'region is {} and tfrutil_path is {}'.format(
-                             project, region, tfrutil_path))
+      any(not v for v in [project, region])):
+    raise AttributeError('DataFlowRunner requires project and region to be '
+                         'set, however project is {} and region is {}'.format(
+                             project, region))

 # def read_image_directory(dirpath) -> pd.DataFrame:
 #   """Reads image data from a directory into a Pandas DataFrame."""
@@ -131,7 +129,6 @@ def create_tfrecords(
     runner: str = 'DirectRunner',
     project: Optional[str] = None,
     region: Optional[str] = None,
-    tfrutil_path: Optional[str] = None,
     dataflow_options: Optional[Dict[str, Any]] = None,
     job_label: str = 'create-tfrecords',
     compression: Optional[str] = 'gzip',
@@ -158,7 +155,6 @@ def create_tfrecords(
     runner: Beam runner. Can be 'DirectRunner' or 'DataFlowRunner'
     project: GCP project name (Required if DataFlowRunner)
     region: GCP region name (Required if DataFlowRunner)
-    tfrutil_path: Path to TFRutil source (Required if DataFlowRunner)
     dataflow_options: Options dict for dataflow runner
     job_label: User supplied description for the beam job name.
     compression: Can be 'gzip' or None for no compression.
@@ -170,7 +166,7 @@ def create_tfrecords(
   df = to_dataframe(input_data, header, names)

   _validate_data(df)
-  _validate_runner(df, runner, project, region, tfrutil_path)
+  _validate_runner(df, runner, project, region)
   #os.makedirs(output_dir, exist_ok=True)
   #TODO (mikebernico) this doesn't work with GCS locations...
   logfile = os.path.join('/tmp', constants.LOGFILE)
@@ -185,7 +181,6 @@ def create_tfrecords(
       runner=runner,
       project=project,
       region=region,
-      tfrutil_path=tfrutil_path,
       output_dir=output_dir,
       compression=compression,
       num_shards=num_shards,
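
With tfrutil_path gone, a misconfigured DataFlowRunner call now fails with an error that names only project and region. Below is a small hedged sketch of the calling pattern; the DataFrame columns and literal values are placeholders, not taken from this commit.

import pandas as pd

from tfrutil import client

# Hypothetical manifest; the real column names come from the project's
# constants and test utilities, so treat these as placeholders.
df = pd.DataFrame({'image_uri': ['gs://my-bucket/img0.jpg'],
                   'label': ['cat'],
                   'split': ['TRAIN']})

# DirectRunner does not require project or region.
client._validate_runner(df, runner='DirectRunner', project=None, region=None)

# DataFlowRunner without a project now raises an AttributeError mentioning
# only project and region.
try:
  client._validate_runner(df, runner='DataFlowRunner', project=None,
                          region='us-central1')
except AttributeError as err:
  print(err)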

tfrutil/client_test.py

Lines changed: 11 additions & 14 deletions

@@ -52,6 +52,8 @@ class InputValidationTest(unittest.TestCase):

   def setUp(self):
     self.test_df = test_utils.get_test_df()
+    self.test_region = 'us-central1'
+    self.test_project = 'foo'

   def test_valid_dataframe(self):
     """Tests valid DataFrame input."""
@@ -93,29 +95,26 @@ def test_valid_runner(self):
     self.assertIsNone(client._validate_runner(
         self.test_df,
         runner='DirectRunner',
-        project='foo',
-        region='us-central1',
-        tfrutil_path='foo/'))
+        project=self.test_project,
+        region=self.test_region))

   def test_invalid_runner(self):
     """Tests invalid runner."""
     with self.assertRaises(AttributeError):
       client._validate_runner(
           self.test_df,
           runner='FooRunner',
-          project='foo',
-          region='us-central1',
-          tfrutil_path='foo/')
+          project=self.test_project,
+          region=self.test_region)

   def test_local_path_with_dataflow_runner(self):
     """Tests DataFlowRunner conflict with local path."""
     with self.assertRaises(AttributeError):
       client._validate_runner(
           self.df_test,
           runner='DataFlowRunner',
-          project='foo',
-          region='us-central1',
-          tfrutil_path='foo/')
+          project=self.test_project,
+          region=self.test_region)

   def test_gcs_path_with_dataflow_runner(self):
     """Tests DataFlowRunner with gcs path."""
@@ -125,9 +124,8 @@ def test_gcs_path_with_dataflow_runner(self):
       client._validate_runner(
           df2,
           runner='DataFlowRunner',
-          project='foo',
-          region='us-central1',
-          tfrutil_path='foo/'))
+          project=self.test_project,
+          region=self.test_region))

   def test_gcs_path_with_dataflow_runner_missing_param(self):
     """Tests DataFlowRunner with missing required parameter."""
@@ -138,8 +136,7 @@ def test_gcs_path_with_dataflow_runner_missing_param(self):
           df2,
           runner='DataFlowRunner',
           project=None,
-          region='us-central1',
-          tfrutil_path='foo/')
+          region=self.test_region)


 def _make_csv_tempfile(data: List[List[str]]) -> tempfile.NamedTemporaryFile:
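
A follow-up test written against the reduced signature would reuse the same fixtures. The sketch below is hypothetical and not part of this commit; the tfrutil.test_utils import path is assumed from the test module above.

import unittest

from tfrutil import client
from tfrutil import test_utils  # assumed module path for the test helpers


class MissingRegionTest(unittest.TestCase):
  """Hypothetical example test mirroring the fixtures in InputValidationTest."""

  def setUp(self):
    self.test_df = test_utils.get_test_df()
    self.test_project = 'foo'

  def test_dataflow_runner_missing_region(self):
    """DataFlowRunner with region unset should fail validation."""
    with self.assertRaises(AttributeError):
      client._validate_runner(
          self.test_df,
          runner='DataFlowRunner',
          project=self.test_project,
          region=None)


if __name__ == '__main__':
  unittest.main()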
