Skip to content
This repository was archived by the owner on Jul 31, 2023. It is now read-only.

Commit dc8e657

Browse files
committed
Add routine to check TFRecord files.
Change-Id: I42e50880bfc04b1eea25b3f8ecb6be482f1200af
1 parent 7f370d7 commit dc8e657

22 files changed

+570
-80
lines changed

tfrutil/accessor.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@ def __init__(self, pandas_obj):
3434
self._df = pandas_obj
3535

3636
# pylint: disable=too-many-arguments
37-
def to_tfr(self,
38-
output_path: str,
39-
runner: str = "DirectRunner",
40-
job_label: str = "to-tfr",
41-
compression: Union[str, None] = "gzip",
42-
num_shards: int = 0) -> str:
37+
def to_tfr(
38+
self,
39+
output_dir: str,
40+
runner: str = "DirectRunner",
41+
job_label: str = "to-tfr",
42+
compression: Union[str, None] = "gzip",
43+
num_shards: int = 0) -> str:
4344
"""TFRUtil Pandas Accessor.
4445
4546
TFRUtil provides an easy interface to create image-based tensorflow records
@@ -48,16 +49,17 @@ def to_tfr(self,
4849
Usage:
4950
import tfrutil
5051
51-
df.tfrutil.to_tfr(runner="local",
52-
output_path="gcs://foo/bar/train",
53-
compression="gzip",
54-
num_shards=10,
55-
image_col="image",
56-
label_col="label)
52+
df.tfrutil.to_tfr(
53+
output_dir="gcs://foo/bar/train",
54+
runner="DirectRunner",
55+
compression="gzip",
56+
num_shards=10,
57+
image_col="image",
58+
label_col="label")
5759
5860
Args:
5961
runner: Beam runner. Can be "local" or "DataFlow"
60-
output_path: Local directory or GCS Location to save TFRecords to.
62+
output_dir: Local directory or GCS Location to save TFRecords to.
6163
job_label: User supplied description for the beam job name.
6264
compression: Can be "gzip" or None for no compression.
6365
num_shards: Number of shards to divide the TFRecords into. Default is
@@ -70,9 +72,9 @@ def to_tfr(self,
7072
"""
7173
print("Starting DataFlow Transform. This may take a while. Please wait.")
7274
client.create_tfrecords(self._df,
75+
output_dir=output_dir,
7376
runner=runner,
74-
output_path=output_path,
7577
job_label=job_label,
7678
compression=compression,
7779
num_shards=num_shards)
78-
print("TFRecords created. Output stored in {}".format(output_path))
80+
print("TFRecords created. Output stored in {}".format(output_dir))

tfrutil/accessor_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def setUp(self):
3232

3333
def test_accessor(self):
3434
"""Tests pandas accessor."""
35-
self.assertIsNone(self.test_df.tensorflow.to_tfr(runner="DirectRunner",
36-
output_path="/tmp/train"))
3735

36+
self.assertIsNone(self.test_df.tensorflow.to_tfr(runner="DirectRunner",
37+
output_dir="/tmp/train"))
3838

3939
if __name__ == "__main__":
4040
unittest.main()

tfrutil/beam_image.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,51 @@
2424
from PIL import Image
2525

2626

27-
def load(image_uri):
28-
"""Loads an image."""
27+
BASE64_ALTCHARS = b'-_'
2928

30-
try:
31-
with tf.io.gfile.GFile(image_uri, "rb") as f:
32-
return Image.open(f)
33-
except tf.python.framework.errors_impl.NotFoundError:
34-
raise OSError("File {} was not found.".format(image_uri))
3529

30+
def mode_to_channel(mode: str) -> int:
31+
"""Returns number of channels depending on PIL image `mode`s."""
3632

37-
def encode(image):
33+
return 1 if 'L' in mode else 3
34+
35+
36+
def channel_to_mode(channels: int) -> str:
37+
"""Returns PIL image mode depending on number of `channels`."""
38+
39+
return 'L' if channels == 1 else 'RGB'
40+
41+
42+
def encode(image: Image):
3843
"""Returns base64-encoded image data.
3944
4045
Args:
4146
image: PIL image.
4247
"""
4348

44-
return base64.b64encode(image.tobytes(), altchars=b"-_")
49+
return base64.b64encode(image.tobytes(), altchars=BASE64_ALTCHARS)
50+
51+
52+
def decode(b64_bytes, width, height, channels) -> Image:
53+
"""Decodes an image from base64-encoded data."""
54+
55+
image_bytes = base64.b64decode(b64_bytes, altchars=BASE64_ALTCHARS)
56+
mode = channel_to_mode(channels)
57+
return Image.frombytes(mode, (width, height), image_bytes)
58+
59+
60+
def load(image_uri):
61+
"""Loads an image."""
62+
63+
try:
64+
with tf.io.gfile.GFile(image_uri, "rb") as f:
65+
return Image.open(f)
66+
except tf.python.framework.errors_impl.NotFoundError:
67+
raise OSError("File {} was not found.".format(image_uri))
68+
4569

70+
# pylint: disable=abstract-method
4671

47-
#pylint: disable=abstract-method
4872
class ExtractImagesDoFn(beam.DoFn):
4973
"""Adds image to PCollection."""
5074

@@ -53,7 +77,7 @@ def __init__(self, image_key: str):
5377
super().__init__()
5478
self.image_key = image_key
5579

56-
#pylint: disable=unused-argument
80+
# pylint: disable=unused-argument
5781
def process(
5882
self,
5983
element: Dict,
@@ -70,8 +94,8 @@ def process(
7094
image_uri = element[self.image_key]
7195
image = load(image_uri)
7296
d["image"] = encode(image)
73-
d["image_height"], d["image_width"] = image.size
74-
d["image_channels"] = 1 if "L" in image.mode else 3
97+
d["image_width"], d["image_height"] = image.size
98+
d["image_channels"] = mode_to_channel(image.mode)
7599

76100
#pylint: disable=broad-except
77101
except Exception as e:

tfrutil/beam_image_test.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121

2222
import apache_beam as beam
2323
from apache_beam.testing import util
24+
import numpy as np
2425
import PIL
26+
from PIL import Image
2527
import tensorflow_transform as tft
2628

2729
from tfrutil import beam_image
@@ -35,26 +37,51 @@ class BeamImageTests(unittest.TestCase):
3537
def setUp(self):
3638
self.pipeline = test_utils.get_test_pipeline()
3739
self.df = test_utils.get_test_df()
40+
self.image_file = "tfrutil/test_data/images/cat/cat-640x853-1.jpg"
3841

3942
def test_load(self):
4043
"""Tests the image loading function."""
41-
img = beam_image.load("tfrutil/test_data/images/cat/cat-640x853-1.jpg")
44+
img = beam_image.load(self.image_file)
4245
self.assertIsInstance(img, PIL.JpegImagePlugin.JpegImageFile)
4346

4447
def test_file_not_found_load(self):
4548
"""Test loading an image that doesn"t exist."""
4649
with self.assertRaises(OSError):
4750
_ = beam_image.load("tfrutil/test_data/images/cat/food.jpg")
4851

52+
def test_mode_to_channel(self):
53+
"""Tests `mode_to_channel`."""
54+
55+
actual = [beam_image.mode_to_channel(mode)
56+
for mode in ('L', 'RGB', 'whatever')]
57+
self.assertEqual(actual, [1, 3, 3])
58+
59+
def test_channel_to_mode(self):
60+
"""Tests `channel_to_mode`."""
61+
62+
actual = [beam_image.channel_to_mode(channel) for channel in (1, 2, 3)]
63+
self.assertEqual(actual, ['L', 'RGB', 'RGB'])
64+
4965
def test_base64_encode(self):
5066
"""Tests encode function."""
51-
img = beam_image.load("tfrutil/test_data/images/cat/cat-640x853-1.jpg")
67+
img = beam_image.load(self.image_file)
5268
enc = beam_image.encode(img)
53-
decode = base64.b64decode(enc, altchars=b"-_")
69+
decode = base64.b64decode(enc, altchars=beam_image.BASE64_ALTCHARS)
5470
self.assertEqual(img.tobytes(), decode)
5571

72+
def test_base64_decode(self):
73+
"""Tests `decode` function."""
74+
75+
image = Image.open(self.image_file)
76+
b64_bytes = base64.b64encode(
77+
image.tobytes(), altchars=beam_image.BASE64_ALTCHARS)
78+
width, height = image.size
79+
actual = beam_image.decode(b64_bytes, width, height, 3)
80+
np.testing.assert_array_equal(np.asarray(actual), np.asarray(image))
81+
5682
def test_extract_image_dofn(self):
5783
"""Tests ExtractImageDoFn."""
84+
5885
with self.pipeline as p:
5986

6087
converter = tft.coders.CsvCoder(constants.IMAGE_CSV_COLUMNS,
@@ -76,9 +103,13 @@ def key_matcher(expected_keys):
76103
def _equal(actual):
77104
""" _equal raises a BeamAssertException when an element in the
78105
PCollection doesn't contain the image extraction keys."""
106+
expected_keys_ = set(expected_keys)
79107
for element in actual:
80-
if set(element.keys()) != set(expected_keys):
81-
raise util.BeamAssertException("PCollection key match failed.")
108+
actual_keys = set(element.keys())
109+
if actual_keys != expected_keys_:
110+
raise util.BeamAssertException(
111+
"PCollection key match failed. Actual ({}) vs. expected ({})"
112+
.format(actual_keys, expected_keys_))
82113
return _equal
83114

84115
expected_keys = ["image_uri", "label", "split", "image",

tfrutil/beam_pipeline.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
1919
This file implements the full beam pipeline for TFRUtil.
2020
"""
21-
import datetime
21+
2222
import functools
2323
import logging
2424
import os
@@ -30,6 +30,7 @@
3030
from tensorflow_transform import beam as tft_beam
3131

3232
from tfrutil import beam_image
33+
from tfrutil import common
3334
from tfrutil import constants
3435

3536

@@ -44,7 +45,7 @@ def _get_job_name(job_label: str = None) -> str:
4445
insure uniqueness.
4546
"""
4647

47-
job_name = "tfrutil-" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
48+
job_name = "tfrutil-" + common.get_timestamp()
4849
if job_label:
4950
job_label = job_label.replace("_", "-")
5051
job_name += "-" + job_label
@@ -89,7 +90,7 @@ def _partition_fn(
8990
index = constants.DISCARD_INDEX
9091
return index
9192

92-
def _get_write_to_tfrecord(output_path: str,
93+
def _get_write_to_tfrecord(output_dir: str,
9394
prefix: str,
9495
compress: bool = True,
9596
num_shards: int = 0) \
@@ -99,13 +100,13 @@ def _get_write_to_tfrecord(output_path: str,
99100
This configures a Beam sink to output TFRecord files.
100101
101102
Args:
102-
output_path: Directory to output TFRecord files.
103+
output_dir: Directory to output TFRecord files.
103104
prefix: TFRecord file prefix.
104105
compress: If True, GZip compress TFRecord files.
105106
num_shards: Number of file shards to split the TFRecord data.
106107
"""
107108

108-
path = os.path.join(output_path, prefix)
109+
path = os.path.join(output_dir, prefix)
109110
suffix = '.tfrecord'
110111
if compress:
111112
compression_type = 'gzip'
@@ -135,28 +136,29 @@ def _preprocessing_fn(inputs, integer_label: bool = False):
135136

136137
# pylint: disable=too-many-arguments
137138
# pylint: disable=too-many-locals
138-
def run_pipeline(df: pd.DataFrame,
139-
job_label: str,
140-
runner: str,
141-
output_path: str,
142-
compression: str,
143-
num_shards: int,
144-
integer_label: bool = False):
139+
def run_pipeline(
140+
df: pd.DataFrame,
141+
job_label: str,
142+
runner: str,
143+
output_dir: str,
144+
compression: str,
145+
num_shards: int,
146+
integer_label: bool = False):
145147
"""Runs TFRUtil Beam Pipeline.
146148
147149
Args:
148150
df: Pandas Dataframe
149151
job_label: User description for the beam job.
150152
runner: Beam Runner: (e.g. DataFlowRunner, DirectRunner).
151-
output_path: GCS or Local Path for output.
153+
output_dir: GCS or Local Path for output.
152154
compression: gzip or None.
153155
num_shards: Number of shards.
154156
155157
Note: These inputs must be validated upstream (by client.create_tfrecord())
156158
"""
157159

158160
job_name = _get_job_name(job_label)
159-
job_dir = _get_job_dir(output_path, job_name)
161+
job_dir = _get_job_dir(output_dir, job_name)
160162
popts = {} # TODO(mikebernico): consider how/if to pass pipeline options.
161163
options = _get_pipeline_options(job_name, job_dir, **popts)
162164

@@ -215,7 +217,7 @@ def run_pipeline(df: pd.DataFrame,
215217

216218
# Sinks for TFRecords and metadata.
217219
tfr_writer = functools.partial(_get_write_to_tfrecord,
218-
output_path=job_dir,
220+
output_dir=job_dir,
219221
compress=compression,
220222
num_shards=num_shards)
221223

@@ -239,7 +241,6 @@ def run_pipeline(df: pd.DataFrame,
239241
| 'DiscardDataWriter' >> beam.io.WriteToText(
240242
os.path.join(job_dir, "discarded-data")))
241243

242-
243244
# Output transform function and metadata
244245
_ = (transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
245246
job_dir))

tfrutil/beam_pipeline_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_processing_fn_with_string_label(self, mock_transform):
5555
def test_write_to_tfrecord(self):
5656
"""Test _write_to_tfrecord() fn."""
5757
tfr_writer = beam_pipeline._get_write_to_tfrecord(
58-
output_path='tmp',
58+
output_dir='tmp',
5959
prefix='foo',
6060
compress=True,
6161
num_shards=2)

0 commit comments

Comments
 (0)