@@ -19,20 +19,22 @@
 This file implements the full Beam pipeline for TFRecorder.
 """

-from typing import Any, Dict, Generator, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Union

 import functools
 import logging
 import os

 import apache_beam as beam
+from apache_beam import pvalue
 import pandas as pd
 import tensorflow_transform as tft
 from tensorflow_transform import beam as tft_beam

 from tfrecorder import beam_image
 from tfrecorder import common
 from tfrecorder import constants
+from tfrecorder import types


 def _get_job_name(job_label: str = None) -> str:
@@ -138,7 +140,7 @@ def _get_write_to_tfrecord(output_dir: str,
       num_shards=num_shards,
   )

-def _preprocessing_fn(inputs, integer_label: bool = False):
+def _preprocessing_fn(inputs: Dict[str, Any], integer_label: bool = False):
   """TensorFlow Transform preprocessing function."""

   outputs = inputs.copy()
@@ -166,7 +168,7 @@ def __init__(self):
   # pylint: disable=arguments-differ
   def process(
       self,
-      element: Dict[str, Any]
+      element: List[str],
   ) -> Generator[Dict[str, Any], None, None]:
     """Loads image and creates image features.

@@ -178,6 +180,43 @@ def process(
     yield element


+def get_split_counts(df: pd.DataFrame):
+  """Returns the number of rows for each data split type in the dataframe."""
+  assert constants.SPLIT_KEY in df.columns
+  return df[constants.SPLIT_KEY].value_counts().to_dict()
+
+
+def _transform_and_write_tfr(
+    dataset: pvalue.PCollection,
+    tfr_writer: Callable = None,
+    preprocessing_fn: Optional[Callable] = None,
+    transform_fn: Optional[types.TransformFn] = None,
+    label: str = 'data'):
+  """Applies TF Transform to a dataset and writes it out as TFRecords."""
+
+  dataset_metadata = (dataset, constants.RAW_METADATA)
+
+  if transform_fn:
+    transformed_dataset, transformed_metadata = (
+        (dataset_metadata, transform_fn)
+        | f'Transform{label}' >> tft_beam.TransformDataset())
+  else:
+    if not preprocessing_fn:
+      preprocessing_fn = lambda x: x
+    (transformed_dataset, transformed_metadata), transform_fn = (
+        dataset_metadata
+        | f'AnalyzeAndTransform{label}' >>
+        tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
+
+  transformed_data_coder = tft.coders.ExampleProtoCoder(
+      transformed_metadata.schema)
+  _ = (
+      transformed_dataset
+      | f'Encode{label}' >> beam.Map(transformed_data_coder.encode)
+      | f'Write{label}' >> tfr_writer(prefix=label.lower()))
+
+  return transform_fn
+


 # pylint: disable=too-many-arguments
 # pylint: disable=too-many-locals
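(Aside: the two branches of the new `_transform_and_write_tfr` helper map directly onto the two tft_beam primitives. `AnalyzeAndTransformDataset` fits a transform over a dataset and applies it in one pass, returning the fitted `transform_fn`; `TransformDataset` applies a previously fitted `transform_fn`. Below is a minimal, self-contained sketch of that pattern; the feature spec, `preprocessing_fn`, and temp directory are illustrative stand-ins, not TFRecorder's actual `RAW_METADATA` or `_preprocessing_fn`.

import apache_beam as beam
import tensorflow as tf
import tensorflow_transform as tft
from tensorflow_transform import beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils

# Hypothetical stand-in for TFRecorder's RAW_METADATA.
RAW_METADATA = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(
        {'x': tf.io.FixedLenFeature([], tf.float32)}))

def preprocessing_fn(inputs):
  # Scales using statistics computed over the analysis dataset only.
  return {'x_scaled': tft.scale_to_0_1(inputs['x'])}

with beam.Pipeline() as p, tft_beam.Context(temp_dir='/tmp/tft_tmp'):
  raw_train = p | 'CreateTrain' >> beam.Create([{'x': 1.0}, {'x': 3.0}])
  raw_val = p | 'CreateVal' >> beam.Create([{'x': 2.0}])

  # Fit AND apply the transform on the training split.
  (train_out, train_meta), transform_fn = (
      (raw_train, RAW_METADATA)
      | 'AnalyzeAndTransformTrain' >> tft_beam.AnalyzeAndTransformDataset(
          preprocessing_fn))

  # Apply the already-fitted transform_fn to the validation split, so all
  # statistics come from training data; this mirrors the helper's
  # transform_fn branch.
  val_out, val_meta = (
      ((raw_val, RAW_METADATA), transform_fn)
      | 'TransformVal' >> tft_beam.TransformDataset())

Returning `transform_fn` from the training call is what lets the validation and test splits reuse training statistics, and it is also what `WriteTransformFn` serializes at the end of the pipeline.)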
@@ -246,71 +285,49 @@ def build_pipeline(
       | 'ReadImage' >> beam.ParDo(extract_images_fn)
   )

-  # Split dataset into train and validation.
+  # Note: This will not always reflect the actual number of samples per
+  # split written as TFRecords. The succeeding `Partition` operation may
+  # mark additional samples from other splits as discarded. If a split has
+  # all of its samples discarded, the pipeline will still generate a
+  # TFRecord file for that split, albeit empty.
+  split_counts = get_split_counts(df)
+
+  # Require the training set to be available in the input data. The
+  # transform_fn and transformed_metadata will be generated from the
+  # training set and applied to the other datasets, if any.
+  assert 'TRAIN' in split_counts
+
   train_data, val_data, test_data, discard_data = (
       image_csv_data | 'SplitDataset' >> beam.Partition(
           _partition_fn, len(constants.SPLIT_VALUES))
   )

-  train_dataset = (train_data, constants.RAW_METADATA)
-  val_dataset = (val_data, constants.RAW_METADATA)
-  test_dataset = (test_data, constants.RAW_METADATA)
-
-  # TensorFlow Transform applied to all datasets.
   preprocessing_fn = functools.partial(
       _preprocessing_fn,
       integer_label=integer_label)
-  transformed_train_dataset, transform_fn = (
-      train_dataset
-      | 'AnalyzeAndTransformTrain' >> tft_beam.AnalyzeAndTransformDataset(
-          preprocessing_fn))
-
-  transformed_train_data, transformed_metadata = transformed_train_dataset
-  transformed_data_coder = tft.coders.ExampleProtoCoder(
-      transformed_metadata.schema)
-
-  transformed_val_data, _ = (
-      (val_dataset, transform_fn)
-      | 'TransformVal' >> tft_beam.TransformDataset()
-  )

-  transformed_test_data, _ = (
-      (test_dataset, transform_fn)
-      | 'TransformTest' >> tft_beam.TransformDataset()
-  )
+  tfr_writer = functools.partial(
+      _get_write_to_tfrecord, output_dir=job_dir, compress=compression,
+      num_shards=num_shards)
+  transform_fn = _transform_and_write_tfr(
+      train_data, tfr_writer, preprocessing_fn=preprocessing_fn,
+      label='Train')

-  # Sinks for TFRecords and metadata.
-  tfr_writer = functools.partial(_get_write_to_tfrecord,
-                                 output_dir=job_dir,
-                                 compress=compression,
-                                 num_shards=num_shards)
+  if 'VALIDATION' in split_counts:
+    _transform_and_write_tfr(
+        val_data, tfr_writer, transform_fn=transform_fn, label='Validation')

-  _ = (
-      transformed_train_data
-      | 'EncodeTrainData' >> beam.Map(transformed_data_coder.encode)
-      | 'WriteTrainData' >> tfr_writer(prefix='train'))
-
-  _ = (
-      transformed_val_data
-      | 'EncodeValData' >> beam.Map(transformed_data_coder.encode)
-      | 'WriteValData' >> tfr_writer(prefix='val'))
-
-  _ = (
-      transformed_test_data
-      | 'EncodeTestData' >> beam.Map(transformed_data_coder.encode)
-      | 'WriteTestData' >> tfr_writer(prefix='test'))
+  if 'TEST' in split_counts:
+    _transform_and_write_tfr(
+        test_data, tfr_writer, transform_fn=transform_fn, label='Test')

   _ = (
       discard_data
-      | 'DiscardDataWriter' >> beam.io.WriteToText(
+      | 'WriteDiscardedData' >> beam.io.WriteToText(
           os.path.join(job_dir, 'discarded-data')))

-  # Output transform function and metadata
+  # Note: `transform_fn` already contains the transformed metadata.
   _ = (transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
       job_dir))

-  # Output metadata schema
-  _ = (transformed_metadata | 'WriteMetadata' >> tft_beam.WriteMetadata(
-      job_dir, pipeline=p))
-
   return p
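(Aside: the gating on `split_counts` above is driven by the new `get_split_counts`, which simply tallies the dataframe's split column. A toy illustration, assuming `constants.SPLIT_KEY` names a `'split'` column; the real key is defined in `tfrecorder.constants`:

import pandas as pd

df = pd.DataFrame({'split': ['TRAIN', 'TRAIN', 'VALIDATION', 'DISCARD']})
df['split'].value_counts().to_dict()
# -> {'TRAIN': 2, 'VALIDATION': 1, 'DISCARD': 1}  (tie order may vary)

Since the 'VALIDATION' and 'TEST' branches now run only when those splits appear in the input data, absent splits no longer produce TFRecord output at all; the empty-file caveat in the inline comment applies only to splits that are present but fully discarded by `Partition`.)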