 import os
 from typing import Any, Dict, Union, Optional, Sequence
 
+import apache_beam as beam
 import pandas as pd
 import tensorflow as tf
 
-# from tfrutil import common
+from tfrutil import common
 from tfrutil import constants
 from tfrutil import beam_pipeline
 
@@ -118,6 +119,41 @@ def to_dataframe(
 
   return df
 
+
+def _get_beam_metric(
+    metric_filter: beam.metrics.MetricsFilter,
+    result: beam.runners.runner.PipelineResult,
+    metric_type: str = 'counters') -> Optional[int]:
+  """Queries a Beam pipeline result for a specified metric.
+
+  Args:
+    metric_filter: an instance of apache_beam.metrics.MetricsFilter().
+    result: the PipelineResult to query for the metric.
+    metric_type: a metric type ('counters', 'distributions', etc.).
+
+  Returns:
+    The metric value, or None if no matching metric is found.
+  """
+  query_result = result.metrics().query(metric_filter)
+  result_val = None
+  if query_result[metric_type]:
+    result_val = query_result[metric_type][0].result
+  return result_val
+
+
+def _configure_logging(logfile: str):
+  """Configures logging options."""
+  # Remove the default handlers that TF set for us.
+  logger = logging.getLogger('')
+  logger.handlers = []
+  handler = logging.FileHandler(logfile)
+  logger.addHandler(handler)
+  logger.setLevel(constants.LOGLEVEL)
+  # This disables annoying TensorFlow and TFX info/warning messages on the console.
+  tf_logger = logging.getLogger('tensorflow')
+  tf_logger.handlers = []
+  tf_logger.addHandler(handler)
+
 # pylint: disable=too-many-arguments
 # pylint: disable=too-many-locals
 
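As an aside on how `_get_beam_metric` gets used: the standalone sketch below is not part of this commit (the 'tfrutil' namespace and 'row_count' counter name are illustrative) and shows a counter being populated and then queried from a DirectRunner result in the same way.

import apache_beam as beam
from apache_beam.metrics import Metrics
from apache_beam.metrics.metric import MetricsFilter

def count_row(row):
  # Increment a counter; the runner aggregates it into the pipeline result.
  Metrics.counter('tfrutil', 'row_count').inc()
  return row

p = beam.Pipeline('DirectRunner')
_ = p | beam.Create([1, 2, 3]) | beam.Map(count_row)
result = p.run()
result.wait_until_finish()

# query() returns {'counters': [...], 'distributions': [...], 'gauges': [...]};
# each entry's .result holds the aggregated value.
query = result.metrics().query(MetricsFilter().with_name('row_count'))
print(query['counters'][0].result if query['counters'] else None)  # -> 3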
@@ -132,7 +168,7 @@ def create_tfrecords(
     dataflow_options: Optional[Dict[str, Any]] = None,
     job_label: str = 'create-tfrecords',
     compression: Optional[str] = 'gzip',
-    num_shards: int = 0):
+    num_shards: int = 0) -> Dict[str, Any]:
   """Generates TFRecord files from given input data.
 
   TFRUtil provides an easy interface to create image-based TensorFlow records
@@ -161,18 +197,21 @@ def create_tfrecords(
     num_shards: Number of shards to divide the TFRecords into. Default is
       0 = no sharding.
 
+  Returns:
+    A dict containing:
+      job_id: The DataFlow job ID, or 'DirectRunner' when running locally.
+      metrics: (optional) Beam metrics; only populated for DirectRunner.
+      dataflow_url: (optional) Console URL for the DataFlow job.
   """
 
   df = to_dataframe(input_data, header, names)
 
   _validate_data(df)
   _validate_runner(df, runner, project, region)
-  #os.makedirs(output_dir, exist_ok=True)
-  #TODO (mikebernico) this doesn't work with GCS locations...
+
   logfile = os.path.join('/tmp', constants.LOGFILE)
-  logging.basicConfig(filename=logfile, level=constants.LOGLEVEL)
-  # This disables annoying Tensorflow and TFX info/warning messages.
-  logging.getLogger('tensorflow').setLevel(logging.ERROR)
+  _configure_logging(logfile)
+
 
   integer_label = pd.api.types.is_integer_dtype(df[constants.LABEL_KEY])
   p = beam_pipeline.build_pipeline(
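For context on the new return value documented above, a hypothetical caller might consume it as below. The `from tfrutil import client` import path and the 'data.csv' input are assumptions for illustration, not taken from this commit.

from tfrutil import client  # assumed location of create_tfrecords

job_result = client.create_tfrecords(
    input_data='data.csv',        # placeholder CSV in the expected schema
    output_dir='/tmp/tfrecords',
    runner='DirectRunner')

if 'metrics' in job_result:
  # DirectRunner blocks until completion, so metrics are ready immediately.
  print(job_result['metrics']['rows'], job_result['metrics']['good_images'])
else:
  # DataFlowRunner returns asynchronously; follow the job in the console.
  print(job_result['dataflow_url'])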
@@ -187,15 +226,55 @@ def create_tfrecords(
       dataflow_options=dataflow_options,
       integer_label=integer_label)
 
-  # TODO(mikbernico) Handle this async for the DataFlow case.
   result = p.run()
-  result.wait_until_finish()
-  # TODO(mikebernico) Add metrics here.
+
+  if runner == 'DirectRunner':
+    logging.info('Using DirectRunner. TFRUtil will block until the job completes.')
+    result.wait_until_finish()
+
+    row_count_filter = beam.metrics.MetricsFilter().with_name('row_count')
+    good_image_filter = beam.metrics.MetricsFilter().with_name('image_good')
+    bad_image_filter = beam.metrics.MetricsFilter().with_name('image_bad')
+
+    row_count = _get_beam_metric(row_count_filter, result)
+    good_image_count = _get_beam_metric(good_image_filter, result)
+    bad_image_count = _get_beam_metric(bad_image_filter, result)
+
+    # TODO(mikebernico): Profile metric impact with a larger dataset.
+    metrics = {
+        'rows': row_count,
+        'good_images': good_image_count,
+        'bad_images': bad_image_count,
+    }
+
+    job_result = {
+        'job_id': 'DirectRunner',
+        'metrics': metrics
+    }
+    logging.info('Job complete.')
+
+  else:
+    logging.info('Using DataFlow Runner.')
+    # Construct the DataFlow console URL.
+
+    job_id = result.job_id()
+
+    url = (
+        constants.CONSOLE_DATAFLOW_URI +
+        region +
+        '/' +
+        job_id +
+        '?project=' +
+        project)
+    job_result = {
+        'job_id': job_id,
+        'dataflow_url': url
+    }
+
   logging.shutdown()
 
-  # FIXME: Issue where GCSFS is not picking up the `logfile` even if it exists.
-  if os.path.exists(logfile):
-    pass
-    # common.copy_to_gcs(logfile,
-    #                    os.path.join(output_dir, constants.LOGFILE),
-    #                    recursive=False)
+  if runner == 'DataFlowRunner':
+    # If this is a DataFlow job, copy the logfile to GCS.
+    common.copy_logfile_to_gcs(logfile, output_dir)
+
+  return job_result
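`common.copy_logfile_to_gcs` lives in tfrutil/common.py and is not shown in this diff. Purely as a sketch of the idea rather than the repo's actual implementation, a GCS-aware copy can be written with tf.io.gfile, which accepts both local paths and gs:// URIs:

import os
import tensorflow as tf

def copy_logfile_to_gcs(logfile: str, output_dir: str) -> None:
  """Hypothetical sketch: copies a local logfile into output_dir."""
  destination = os.path.join(output_dir, os.path.basename(logfile))
  # tf.io.gfile.copy handles local filesystems and GCS transparently.
  tf.io.gfile.copy(logfile, destination, overwrite=True)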