Skip to content
This repository was archived by the owner on Jul 31, 2023. It is now read-only.

Commit cbba97c

Browse files
committed
Adding metrics and return results.
Change-Id: Icc1efe26b51f7e3780d9724f488af1f730d0d15b
1 parent a16f7e2 commit cbba97c

File tree

11 files changed

+250
-46
lines changed

11 files changed

+250
-46
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ pylint >= 2.5.3
99
fire >= 0.3.1
1010
jupyter >= 1.0.0
1111
tensorflow >= 2.2.0
12-
gcsfs >= 0.6.2
12+
pyarrow < 0.17

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
"pylint >= 2.5.3",
3030
"fire >= 0.3.1",
3131
"tensorflow >= 2.2.0",
32-
"gcsfs >= 0.6.2",
3332
"pyarrow < 0.17",
3433
]
3534

tfrutil/accessor.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
"""
2323
from typing import Any, Dict, Optional, Union
2424
import pandas as pd
25+
from IPython.core import display
2526

2627
from tfrutil import client
28+
from tfrutil import constants
2729

2830

2931
@pd.api.extensions.register_dataframe_accessor('tensorflow')
@@ -43,7 +45,7 @@ def to_tfr(
4345
dataflow_options: Union[Dict[str, Any], None] = None,
4446
job_label: str = 'to-tfr',
4547
compression: Optional[str] = 'gzip',
46-
num_shards: int = 0):
48+
num_shards: int = 0) -> Dict[str, Any]:
4749
"""TFRUtil Pandas Accessor.
4850
4951
TFRUtil provides an easy interface to create image-based tensorflow records
@@ -68,8 +70,14 @@ def to_tfr(
6870
compression: Can be 'gzip' or None for no compression.
6971
num_shards: Number of shards to divide the TFRecords into. Default is
7072
0 = no sharding.
73+
Returns:
74+
job_results: A dictionary of job results.
7175
"""
72-
client.create_tfrecords(
76+
display.display(
77+
display.HTML(
78+
'<b>Logging output to /tmp/{} </b>'.format(constants.LOGFILE)))
79+
80+
r = client.create_tfrecords(
7381
self._df,
7482
output_dir=output_dir,
7583
runner=runner,
@@ -79,4 +87,4 @@ def to_tfr(
7987
job_label=job_label,
8088
compression=compression,
8189
num_shards=num_shards)
82-
#TODO (mikebernico) Add notebook output for user.
90+
return r

tfrutil/accessor_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@ def setUp(self):
3636
def test_accessor(self):
3737
"""Tests pandas accessor."""
3838

39-
self.assertIsNone(self.test_df.tensorflow.to_tfr(
40-
runner='DirectRunner', output_dir=self.output_dir))
39+
r = self.test_df.tensorflow.to_tfr(
40+
runner='DirectRunner', output_dir=self.output_dir)
41+
self.assertTrue('metrics' in r)
4142

4243
if __name__ == '__main__':
4344
unittest.main()

tfrutil/beam_image.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Any, Dict, Generator, Tuple
2121

2222
import apache_beam as beam
23+
from apache_beam.metrics import Metrics
2324
import tensorflow as tf
2425
from PIL import Image
2526

@@ -76,11 +77,14 @@ def __init__(self, image_key: str):
7677
"""Constructor."""
7778
super().__init__()
7879
self.image_key = image_key
80+
self.image_good_counter = Metrics.counter(self.__class__, 'image_good')
81+
self.image_bad_counter = Metrics.counter(self.__class__, 'image_bad')
82+
7983

8084
# pylint: disable=unused-argument
8185
def process(
8286
self,
83-
element: Dict,
87+
element: Dict[str, Any],
8488
*args: Tuple[Any, ...],
8589
**kwargs: Dict) -> Generator[Dict[str, Any], None, None]:
8690
"""Loads image and creates image features.
@@ -96,11 +100,13 @@ def process(
96100
d['image'] = encode(image)
97101
d['image_width'], d['image_height'] = image.size
98102
d['image_channels'] = mode_to_channel(image.mode)
103+
self.image_good_counter.inc()
99104

100105
# pylint: disable=broad-except
101106
except Exception as e:
102107
logging.warning('Could not load image: %s', image_uri)
103108
logging.error('Exception was: %s', str(e))
109+
self.image_bad_counter.inc()
104110

105111
element.update(d)
106112
yield element

tfrutil/beam_pipeline.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
This file implements the full beam pipeline for TFRUtil.
2020
"""
2121

22+
from typing import Any, Dict, Generator, Union
23+
2224
import functools
2325
import logging
2426
import os
25-
from typing import Any, Dict, Union
2627

2728
import apache_beam as beam
2829
import pandas as pd
@@ -159,6 +160,34 @@ def _preprocessing_fn(inputs, integer_label: bool = False):
159160
return outputs
160161

161162

163+
# pylint: disable=abstract-method
class ToCSVRows(beam.DoFn):
  """Converts a row of cell values into a single CSV-formatted string.

  Each incoming element is a sequence of cell values (one DataFrame row,
  as produced by ``df.values.tolist()``). The DoFn joins the values with
  commas and counts processed rows in the Beam ``row_count`` counter so
  the client can report how many rows went through the pipeline.
  """

  def __init__(self):
    """Constructor."""
    super().__init__()
    # Incremented once per processed row; queried after the run for metrics.
    self.row_count = beam.metrics.Metrics.counter(self.__class__, 'row_count')

  # pylint: disable=unused-argument
  # pylint: disable=arguments-differ
  def process(
      self,
      element: Any
  ) -> Generator[str, None, None]:
    """Serializes one row of values to a CSV line.

    Args:
      element: A sequence of row values (e.g. one DataFrame row).

    Yields:
      The comma-joined string form of the row, e.g. ``'a,1,2'``.
    """
    csv_row = ','.join([str(item) for item in element])
    self.row_count.inc()
    yield csv_row
188+
189+
190+
162191
# pylint: disable=too-many-arguments
163192
# pylint: disable=too-many-locals
164193
def build_pipeline(
@@ -210,15 +239,15 @@ def build_pipeline(
210239
constants.IMAGE_CSV_METADATA.schema)
211240

212241
extract_images_fn = beam_image.ExtractImagesDoFn(constants.IMAGE_URI_KEY)
242+
flatten_rows = ToCSVRows()
213243

214244
# Each element in the image_csv_data PCollection will be a dict
215245
# including the image_csv_columns and the image features created from
216246
# extract_images_fn.
217247
image_csv_data = (
218248
p
219249
| 'ReadFromDataFrame' >> beam.Create(df.values.tolist())
220-
| 'ToCSVRows' >> beam.Map(
221-
lambda x: ','.join([str(item) for item in x]))
250+
| 'ToCSVRows' >> beam.ParDo(flatten_rows)
222251
| 'DecodeCSV' >> beam.Map(converter.decode)
223252
| 'ReadImage' >> beam.ParDo(extract_images_fn)
224253
)

tfrutil/client.py

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
import os
2424
from typing import Any, Dict, Union, Optional, Sequence
2525

26+
import apache_beam as beam
2627
import pandas as pd
2728
import tensorflow as tf
2829

29-
# from tfrutil import common
30+
from tfrutil import common
3031
from tfrutil import constants
3132
from tfrutil import beam_pipeline
3233

@@ -118,6 +119,40 @@ def to_dataframe(
118119

119120
return df
120121

122+
123+
def _get_beam_metric(
124+
metric_filter: beam.metrics.MetricsFilter,
125+
result: beam.runners.runner.PipelineResult,
126+
metric_type: str = 'counters') -> Optional[int]:
127+
"""Queries a beam pipeline result for a specificed metric.
128+
129+
Args:
130+
metric_filter: an instance of apache_beam.metrics.MetricsFilter()
131+
metric_type: A metric type (counters, distributions, etc.)
132+
133+
Returns:
134+
Counter value or None
135+
"""
136+
query_result = result.metrics().query(metric_filter)
137+
result_val = None
138+
if query_result[metric_type]:
139+
result_val = query_result[metric_type][0].result
140+
return result_val
141+
def _configure_logging(logfile: str) -> None:
  """Configures logging options.

  Routes all log output (including TensorFlow's) to `logfile` instead of
  the console.

  Args:
    logfile: Path of the file log records are written to.
  """
  # Remove default handlers that TF set for us and log to a file instead.
  logger = logging.getLogger('')
  logger.handlers = []
  handler = logging.FileHandler(logfile)
  logger.addHandler(handler)
  logger.setLevel(constants.LOGLEVEL)
  # This disables annoying Tensorflow and TFX info/warning messages on the
  # console: clear TF's own handlers and let records propagate up to the
  # root logger's file handler. NOTE: attaching `handler` directly to this
  # logger as well would write every TF record to the file twice (once
  # here, once after propagation to root).
  tf_logger = logging.getLogger('tensorflow')
  tf_logger.handlers = []
121156
# pylint: disable=too-many-arguments
122157
# pylint: disable=too-many-locals
123158

@@ -132,7 +167,7 @@ def create_tfrecords(
132167
dataflow_options: Optional[Dict[str, Any]] = None,
133168
job_label: str = 'create-tfrecords',
134169
compression: Optional[str] = 'gzip',
135-
num_shards: int = 0):
170+
num_shards: int = 0) -> Dict[str, Any]:
136171
"""Generates TFRecord files from given input data.
137172
138173
TFRUtil provides an easy interface to create image-based tensorflow records
@@ -161,18 +196,21 @@ def create_tfrecords(
161196
num_shards: Number of shards to divide the TFRecords into. Default is
162197
0 = no sharding.
163198
199+
Returns:
200+
job_results: Dict
201+
job_id: DataFlow Job ID or 'DirectRunner'
202+
metrics: (optional) Beam metrics. Only used for DirectRunner
203+
dataflow_url: (optional) Job URL for DataFlowRunner
164204
"""
165205

166206
df = to_dataframe(input_data, header, names)
167207

168208
_validate_data(df)
169209
_validate_runner(df, runner, project, region)
170-
#os.makedirs(output_dir, exist_ok=True)
171-
#TODO (mikebernico) this doesn't work with GCS locations...
210+
172211
logfile = os.path.join('/tmp', constants.LOGFILE)
173-
logging.basicConfig(filename=logfile, level=constants.LOGLEVEL)
174-
# This disables annoying Tensorflow and TFX info/warning messages.
175-
logging.getLogger('tensorflow').setLevel(logging.ERROR)
212+
_configure_logging(logfile)
213+
176214

177215
integer_label = pd.api.types.is_integer_dtype(df[constants.LABEL_KEY])
178216
p = beam_pipeline.build_pipeline(
@@ -187,15 +225,55 @@ def create_tfrecords(
187225
dataflow_options=dataflow_options,
188226
integer_label=integer_label)
189227

190-
# TODO(mikbernico) Handle this async for the DataFlow case.
191228
result = p.run()
192-
result.wait_until_finish()
193-
# TODO(mikebernico) Add metrics here.
229+
230+
if runner == 'DirectRunner':
231+
logging.info("Using DirectRunner. TFRUtil will block until job completes.")
232+
result.wait_until_finish()
233+
234+
row_count_filter = beam.metrics.MetricsFilter().with_name('row_count')
235+
good_image_filter = beam.metrics.MetricsFilter().with_name('image_good')
236+
bad_image_filter = beam.metrics.MetricsFilter().with_name('image_bad')
237+
238+
row_count = _get_beam_metric(row_count_filter, result)
239+
good_image_count = _get_beam_metric(good_image_filter, result)
240+
bad_image_count = _get_beam_metric(bad_image_filter, result)
241+
242+
# TODO(mikebernico): Profile metric impact with larger dataset.
243+
metrics = {
244+
'rows': row_count,
245+
'good_images': good_image_count,
246+
'bad_images': bad_image_count,
247+
}
248+
249+
job_result = {
250+
'job_id': 'DirectRunner',
251+
'metrics': metrics
252+
}
253+
logging.info("Job Complete.")
254+
255+
else:
256+
logging.info("Using DataFlow Runner.")
257+
# Construct DataFlow URL
258+
259+
job_id = result.job_id()
260+
261+
url = (
262+
constants.CONSOLE_DATAFLOW_URI +
263+
region +
264+
'/' +
265+
job_id +
266+
'?project=' +
267+
project)
268+
job_result = {
269+
'job_id': job_id,
270+
'dataflow_url': url
271+
}
272+
194273
logging.shutdown()
195274

196-
# FIXME: Issue where GCSFS is not picking up the `logfile` even if it exists.
197-
if os.path.exists(logfile):
198-
pass
199-
# common.copy_to_gcs(logfile,
200-
# os.path.join(output_dir, constants.LOGFILE),
201-
# recursive=False)
275+
if runner == 'DataFlowRunner':
276+
# if this is a dataflow job, copy the logfile to gcs
277+
common.copy_logfile_to_gcs(logfile, output_dir)
278+
279+
return job_result

tfrutil/client_test.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616

1717
"""Tests for client."""
18-
18+
import os
1919
from typing import List
2020

2121
import csv
@@ -35,14 +35,43 @@ class ClientTest(unittest.TestCase):
3535

3636
def setUp(self):
3737
self.test_df = test_utils.get_test_df()
38+
self.test_region = 'us-central1'
39+
self.test_project = 'foo'
3840

39-
def test_create_tfrecords(self):
40-
"""Tests `create_tfrecords` valid case."""
41-
42-
self.assertIsNone(client.create_tfrecords(
41+
@mock.patch('tfrutil.client.beam_pipeline')
42+
def test_create_tfrecords_direct_runner(self, mock_beam):
43+
"""Tests `create_tfrecords` Direct case."""
44+
mock_beam.build_pipeline().run().wait_until_finished.return_value = {
45+
'rows':6}
46+
r = client.create_tfrecords(
4347
self.test_df,
4448
runner='DirectRunner',
45-
output_dir='/tmp/train'))
49+
output_dir='/tmp/direct_runner')
50+
self.assertTrue('metrics' in r)
51+
52+
@mock.patch('tfrutil.client.beam_pipeline')
53+
def test_create_tfrecords_dataflow_runner(self, mock_beam):
54+
"""Tests `create_tfrecords` DataFlow case."""
55+
mock_beam.build_pipeline().run().job_id.return_value = 'foo_id'
56+
57+
df2 = self.test_df.copy()
58+
df2[constants.IMAGE_URI_KEY] = 'gs://' + df2[constants.IMAGE_URI_KEY]
59+
60+
outdir = '/tmp/dataflow_runner'
61+
62+
expected = {
63+
'job_id': 'foo_id',
64+
'dataflow_url': 'https://console.cloud.google.com/dataflow/jobs/' +
65+
'us-central1/foo_id?project=foo'}
66+
67+
os.makedirs(outdir, exist_ok=True)
68+
r = client.create_tfrecords(
69+
df2,
70+
runner='DataFlowRunner',
71+
output_dir=outdir,
72+
region=self.test_region,
73+
project=self.test_project)
74+
self.assertEqual(r, expected)
4675

4776

4877
# pylint: disable=protected-access

0 commit comments

Comments
 (0)