Skip to content

Commit 379ba0e

Browse files
authored
Merge pull request #12912 from [BEAM-10938] Adds support for writing a footer to Python WriteToText
* [BEAM-10938] Adds support for writing a footer to Python WriteToText * Fixing formatting
1 parent c352d45 commit 379ba0e

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

sdks/python/apache_beam/io/textio.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ def __init__(self,
350350
shard_name_template=None,
351351
coder=coders.ToBytesCoder(), # type: coders.Coder
352352
compression_type=CompressionTypes.AUTO,
353-
header=None):
353+
header=None,
354+
footer=None):
354355
"""Initialize a _TextSink.
355356
356357
Args:
@@ -382,6 +383,8 @@ def __init__(self,
382383
compression.
383384
header: String to write at beginning of file as a header. If not None and
384385
append_trailing_newlines is set, '\n' will be added.
386+
footer: String to write at the end of file as a footer. If not None and
387+
append_trailing_newlines is set, '\n' will be added.
385388
386389
Returns:
387390
A _TextSink object usable for writing.
@@ -396,6 +399,7 @@ def __init__(self,
396399
compression_type=compression_type)
397400
self._append_trailing_newlines = append_trailing_newlines
398401
self._header = header
402+
self._footer = footer
399403

400404
def open(self, temp_path):
401405
file_handle = super(_TextSink, self).open(temp_path)
@@ -405,6 +409,13 @@ def open(self, temp_path):
405409
file_handle.write(b'\n')
406410
return file_handle
407411

412+
def close(self, file_handle):
413+
if self._footer is not None:
414+
file_handle.write(coders.ToBytesCoder().encode(self._footer))
415+
if self._append_trailing_newlines:
416+
file_handle.write(b'\n')
417+
super(_TextSink, self).close(file_handle)
418+
408419
def display_data(self):
409420
dd_parent = super(_TextSink, self).display_data()
410421
dd_parent['append_newline'] = DisplayDataItem(
@@ -588,7 +599,8 @@ def __init__(
588599
shard_name_template=None, # type: Optional[str]
589600
coder=coders.ToBytesCoder(), # type: coders.Coder
590601
compression_type=CompressionTypes.AUTO,
591-
header=None):
602+
header=None,
603+
footer=None):
592604
r"""Initialize a :class:`WriteToText` transform.
593605
594606
Args:
@@ -624,6 +636,9 @@ def __init__(
624636
header (str): String to write at beginning of file as a header.
625637
If not :data:`None` and **append_trailing_newlines** is set, ``\n`` will
626638
be added.
639+
footer (str): String to write at the end of file as a footer.
640+
If not :data:`None` and **append_trailing_newlines** is set, ``\n`` will
641+
be added.
627642
"""
628643

629644
self._sink = _TextSink(
@@ -634,7 +649,8 @@ def __init__(
634649
shard_name_template,
635650
coder,
636651
compression_type,
637-
header)
652+
header,
653+
footer)
638654

639655
def expand(self, pcoll):
640656
return pcoll | Write(self._sink)

sdks/python/apache_beam/io/textio_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,14 @@ def test_write_text_file_with_header(self):
11251125
with open(self.path, 'rb') as f:
11261126
self.assertEqual(f.read().splitlines(), header.splitlines() + self.lines)
11271127

1128+
def test_write_text_file_with_footer(self):
1129+
footer = b'footer1\nfooter2'
1130+
sink = TextSink(self.path, footer=footer)
1131+
self._write_lines(sink, self.lines)
1132+
1133+
with open(self.path, 'rb') as f:
1134+
self.assertEqual(f.read().splitlines(), self.lines + footer.splitlines())
1135+
11281136
def test_write_text_file_empty_with_header(self):
11291137
header = b'header1\nheader2'
11301138
sink = TextSink(self.path, header=header)
@@ -1203,6 +1211,22 @@ def test_write_pipeline_header(self):
12031211
self.assertEqual(read_result[0], header_text.encode('utf-8'))
12041212
self.assertEqual(sorted(read_result[1:]), sorted(self.lines))
12051213

1214+
def test_write_pipeline_footer(self):
1215+
with TestPipeline() as pipeline:
1216+
footer_text = 'footer'
1217+
pcoll = pipeline | beam.core.Create(self.lines)
1218+
pcoll | 'Write' >> WriteToText( # pylint: disable=expression-not-assigned
1219+
self.path,
1220+
footer=footer_text)
1221+
1222+
read_result = []
1223+
for file_name in glob.glob(self.path + '*'):
1224+
with open(file_name, 'rb') as f:
1225+
read_result.extend(f.read().splitlines())
1226+
1227+
self.assertEqual(sorted(read_result[:-1]), sorted(self.lines))
1228+
self.assertEqual(read_result[-1], footer_text.encode('utf-8'))
1229+
12061230

12071231
if __name__ == '__main__':
12081232
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)