diff --git a/CHANGES.md b/CHANGES.md index 21eaead52873..f3972906ff43 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -72,6 +72,10 @@ ## I/Os * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Upgraded GoogleAdsAPI to v19 for GoogleAdsIO (Java) ([#34497](https://github.com/apache/beam/pull/34497)). Changed PTransform method from version-specified (`v17()`) to `current()` for better backward compatibility in the future. +* Added support for writing to Pubsub with ordering keys (Java) ([#21162](https://github.com/apache/beam/issues/21162)) +* Support for streaming writes for AvroIO, ParquetIO, TextIO, TFRecordIO +* IOBase.Sink finalize_write has a new optional parameter w for the window ## New Features / Improvements diff --git a/sdks/python/apache_beam/examples/snippets/snippets.py b/sdks/python/apache_beam/examples/snippets/snippets.py index f6bf5e5d44ec..08a092355c1c 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets.py +++ b/sdks/python/apache_beam/examples/snippets/snippets.py @@ -698,7 +698,8 @@ def open_writer(self, access_token, uid): def pre_finalize(self, init_result, writer_results): pass - def finalize_write(self, access_token, table_names, pre_finalize_result): + def finalize_write( + self, access_token, table_names, pre_finalize_result, unused_window): for i, table_name in enumerate(table_names): self._simplekv.rename_table( access_token, table_name, self._final_table_name + str(i)) diff --git a/sdks/python/apache_beam/examples/unbounded_sinks/__init__.py b/sdks/python/apache_beam/examples/unbounded_sinks/__init__.py new file mode 100644 index 000000000000..cce3acad34a4 --- /dev/null +++ b/sdks/python/apache_beam/examples/unbounded_sinks/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sdks/python/apache_beam/examples/unbounded_sinks/generate_event.py b/sdks/python/apache_beam/examples/unbounded_sinks/generate_event.py new file mode 100644 index 000000000000..5b23e43fc368 --- /dev/null +++ b/sdks/python/apache_beam/examples/unbounded_sinks/generate_event.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datetime import datetime + +import pytz + +import apache_beam as beam +from apache_beam.testing.test_stream import TestStream + + +class GenerateEvent(beam.PTransform): + @staticmethod + def sample_data(): + return GenerateEvent() + + def expand(self, input): + elemlist = [{'age': 10}, {'age': 20}, {'age': 30}] + elem = elemlist + return ( + input + | TestStream().add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 1, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 2, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 3, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 4, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 6, + 0, tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 7, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 8, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 9, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 11, 0, + tzinfo=pytz.UTC).timestamp()). 
+ add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 12, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 13, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 14, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 16, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 17, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 18, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 19, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).advance_watermark_to( + datetime( + 2021, 3, 1, 0, 0, 25, 0, tzinfo=pytz.UTC). + timestamp()).advance_watermark_to_infinity()) diff --git a/sdks/python/apache_beam/examples/unbounded_sinks/test_write.py b/sdks/python/apache_beam/examples/unbounded_sinks/test_write.py new file mode 100644 index 000000000000..99eb30405bf4 --- /dev/null +++ b/sdks/python/apache_beam/examples/unbounded_sinks/test_write.py @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# python -m apache_beam.examples.unbounded_sinks.test_write +# This file contains multiple examples of writing unbounded PCollection to files + +import argparse +import json +import logging + +import pyarrow + +import apache_beam as beam +from apache_beam.examples.unbounded_sinks.generate_event import GenerateEvent +from apache_beam.io.fileio import WriteToFiles +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms.trigger import AccumulationMode +from apache_beam.transforms.trigger import AfterWatermark +from apache_beam.transforms.util import LogElements +from apache_beam.transforms.window import FixedWindows +from apache_beam.utils.timestamp import Duration + + +class CountEvents(beam.PTransform): + def expand(self, events): + return ( + events + | beam.WindowInto( + FixedWindows(5), + trigger=AfterWatermark(), + accumulation_mode=AccumulationMode.DISCARDING, + allowed_lateness=Duration(seconds=0)) + | beam.CombineGlobally( + beam.combiners.CountCombineFn()).without_defaults()) + + +def run(argv=None, save_main_session=True) -> PipelineResult: + """Main entry point; defines and runs the wordcount pipeline.""" + parser = argparse.ArgumentParser() + _, pipeline_args = parser.parse_known_args(argv) + + # We use the save_main_session 
option because one or more DoFn's in this + # workflow rely on global context (e.g., a module imported at module level). + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + + p = beam.Pipeline(options=pipeline_options) + + output = p | GenerateEvent.sample_data() + + #TextIO + output2 = output | 'TextIO WriteToText' >> beam.io.WriteToText( + file_path_prefix="__output__/ouput_WriteToText", + file_name_suffix=".txt", + #shard_name_template='-V-SSSSS-of-NNNNN', + num_shards=2, + triggering_frequency=5, + ) + _ = output2 | 'LogElements after WriteToText' >> LogElements( + prefix='after WriteToText ', with_window=True, level=logging.INFO) + + #FileIO + _ = ( + output + | 'FileIO window' >> beam.WindowInto( + FixedWindows(5), + trigger=AfterWatermark(), + accumulation_mode=AccumulationMode.DISCARDING, + allowed_lateness=Duration(seconds=0)) + | 'Serialize' >> beam.Map(json.dumps) + | 'FileIO WriteToFiles' >> + WriteToFiles(path="__output__/output_WriteToFiles")) + + #ParquetIO + pyschema = pyarrow.schema([('age', pyarrow.int64())]) + + output4a = output | 'WriteToParquet' >> beam.io.WriteToParquet( + file_path_prefix="__output__/output_parquet", + #shard_name_template='-V-SSSSS-of-NNNNN', + file_name_suffix=".parquet", + num_shards=2, + triggering_frequency=5, + schema=pyschema) + _ = output4a | 'LogElements after WriteToParquet' >> LogElements( + prefix='after WriteToParquet 4a ', with_window=True, level=logging.INFO) + + output4aw = ( + output + | 'ParquetIO window' >> beam.WindowInto( + FixedWindows(20), + trigger=AfterWatermark(), + accumulation_mode=AccumulationMode.DISCARDING, + allowed_lateness=Duration(seconds=0)) + | 'WriteToParquet windowed' >> beam.io.WriteToParquet( + file_path_prefix="__output__/output_parquet", + shard_name_template='-W-SSSSS-of-NNNNN', + file_name_suffix=".parquet", + num_shards=2, + schema=pyschema)) + _ = output4aw | 'LogElements after WriteToParquet 
windowed' >> LogElements( + prefix='after WriteToParquet 4aw ', with_window=True, level=logging.INFO) + + output4b = ( + output + | 'To PyArrow Table' >> + beam.Map(lambda x: pyarrow.Table.from_pylist([x], schema=pyschema)) + | 'WriteToParquetBatched to parquet' >> beam.io.WriteToParquetBatched( + file_path_prefix="__output__/output_parquet_batched", + shard_name_template='-V-SSSSS-of-NNNNN', + file_name_suffix=".parquet", + num_shards=2, + triggering_frequency=5, + schema=pyschema)) + _ = output4b | 'LogElements after WriteToParquetBatched' >> LogElements( + prefix='after WriteToParquetBatched 4b ', + with_window=True, + level=logging.INFO) + + #AvroIO + avroschema = { + 'name': 'dummy', # your supposed to be file name with .avro extension + 'type': 'record', # type of avro serilazation, there are more (see above + # docs) but as per me this will do most of the time + 'fields': [ # this defines actual keys & their types + {'name': 'age', 'type': 'int'}, + ], + } + output5 = output | 'WriteToAvro' >> beam.io.WriteToAvro( + file_path_prefix="__output__/output_avro", + #shard_name_template='-V-SSSSS-of-NNNNN', + file_name_suffix=".avro", + num_shards=2, + #triggering_frequency=5, + schema=avroschema) + _ = output5 | 'LogElements after WriteToAvro' >> LogElements( + prefix='after WriteToAvro 5 ', with_window=True, level=logging.INFO) + + #TFrecordIO + output6 = ( + output + | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8')) + | 'WriteToTFRecord' >> beam.io.WriteToTFRecord( + file_path_prefix="__output__/output_tfrecord", + #shard_name_template='-V-SSSSS-of-NNNNN', + file_name_suffix=".tfrecord", + num_shards=2, + triggering_frequency=5)) + _ = output6 | 'LogElements after WriteToTFRecord' >> LogElements( + prefix='after WriteToTFRecord 6 ', with_window=True, level=logging.INFO) + + # Execute the pipeline and return the result. 
+ result = p.run() + result.wait_until_finish() + return result + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/examples/unbounded_sinks/test_write_bounded.py b/sdks/python/apache_beam/examples/unbounded_sinks/test_write_bounded.py new file mode 100644 index 000000000000..7e24dc433de5 --- /dev/null +++ b/sdks/python/apache_beam/examples/unbounded_sinks/test_write_bounded.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +# python -m apache_beam.examples.unbounded_sinks.test_write + +import argparse +import json +import logging + +import pyarrow + +import apache_beam as beam +from apache_beam.io.fileio import WriteToFiles +from apache_beam.io.textio import WriteToText +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms.trigger import AccumulationMode +from apache_beam.transforms.trigger import AfterWatermark +from apache_beam.transforms.util import LogElements +from apache_beam.transforms.window import FixedWindows +from apache_beam.utils.timestamp import Duration + + +class CountEvents(beam.PTransform): + def expand(self, events): + return ( + events + | beam.WindowInto( + FixedWindows(5), + trigger=AfterWatermark(), + accumulation_mode=AccumulationMode.DISCARDING, + allowed_lateness=Duration(seconds=0)) + | beam.CombineGlobally( + beam.combiners.CountCombineFn()).without_defaults()) + + +def run(argv=None, save_main_session=True) -> PipelineResult: + """Main entry point; defines and runs the wordcount pipeline.""" + parser = argparse.ArgumentParser() + _, pipeline_args = parser.parse_known_args(argv) + + # We use the save_main_session option because one or more DoFn's in this + # workflow rely on global context (e.g., a module imported at module level). 
+ pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + + p = beam.Pipeline(options=pipeline_options) + + output = ( + p | beam.Create([{ + 'age': 10 + }, { + 'age': 20 + }, { + 'age': 30 + }]) + | beam.LogElements( + prefix='before write ', with_window=False, level=logging.INFO)) + #TextIO + output2 = output | 'Write to text' >> WriteToText( + file_path_prefix="__output_batch__/ouput_WriteToText", + file_name_suffix=".txt", + shard_name_template='-U-SSSSS-of-NNNNN') + _ = output2 | 'LogElements after WriteToText' >> LogElements( + prefix='after WriteToText ', with_window=False, level=logging.INFO) + + #FileIO + output3 = ( + output | 'Serialize' >> beam.Map(json.dumps) + | 'Write to files' >> + WriteToFiles(path="__output_batch__/output_WriteToFiles")) + _ = output3 | 'LogElements after WriteToFiles' >> LogElements( + prefix='after WriteToFiles ', with_window=False, level=logging.INFO) + + #ParquetIO + output4 = output | 'Write' >> beam.io.WriteToParquet( + file_path_prefix="__output_batch__/output_parquet", + schema=pyarrow.schema([('age', pyarrow.int64())])) + _ = output4 | 'LogElements after WriteToParquet' >> LogElements( + prefix='after WriteToParquet ', with_window=False, level=logging.INFO) + _ = output | 'Write parquet' >> beam.io.WriteToParquet( + file_path_prefix="__output_batch__/output_WriteToParquet", + schema=pyarrow.schema([('age', pyarrow.int64())]), + record_batch_size=10, + num_shards=0) + + # Execute the pipeline and return the result. 
+ result = p.run() + result.wait_until_finish() + return result + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py index 3438cb5d61fe..fac6ba57657b 100644 --- a/sdks/python/apache_beam/io/avroio.py +++ b/sdks/python/apache_beam/io/avroio.py @@ -354,8 +354,7 @@ def split_points_unclaimed(stop_position): while range_tracker.try_claim(next_block_start): block = next(blocks) next_block_start = block.offset + block.size - for record in block: - yield record + yield from block _create_avro_source = _FastAvroSource @@ -375,7 +374,8 @@ def __init__( num_shards=0, shard_name_template=None, mime_type='application/x-avro', - use_fastavro=True): + use_fastavro=True, + triggering_frequency=None): """Initialize a WriteToAvro transform. Args: @@ -393,25 +393,44 @@ def __init__( Constraining the number of shards is likely to reduce the performance of a pipeline. Setting this value is not recommended unless you require a specific number of output files. + In streaming if not set, the service will write a file per bundle. shard_name_template: A template string containing placeholders for - the shard number and shard count. When constructing a filename for a - particular shard number, the upper-case letters 'S' and 'N' are - replaced with the 0-padded shard number and shard count respectively. - This argument can be '' in which case it behaves as if num_shards was - set to 1 and only one file will be generated. The default pattern used - is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template. + the shard number and shard count. Currently only ``''``, + ``'-SSSSS-of-NNNNN'``, ``'-W-SSSSS-of-NNNNN'`` and + ``'-V-SSSSS-of-NNNNN'`` are patterns accepted by the service. + When constructing a filename for a particular shard number, the + upper-case letters ``S`` and ``N`` are replaced with the ``0``-padded + shard number and shard count respectively. 
This argument can be ``''`` + in which case it behaves as if num_shards was set to 1 and only one file + will be generated. The default pattern used is ``'-SSSSS-of-NNNNN'`` for + bounded PCollections and for ``'-W-SSSSS-of-NNNNN'`` unbounded + PCollections. + W is used for windowed shard naming and is replaced with + ``[window.start, window.end)`` + V is used for windowed shard naming and is replaced with + ``[window.start.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S"), + window.end.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S")`` mime_type: The MIME type to use for the produced files, if the filesystem supports specifying MIME types. use_fastavro (bool): This flag is left for API backwards compatibility and no longer has an effect. Do not use. + triggering_frequency: (int) Every triggering_frequency duration, a window + will be triggered and all bundles in the window will be written. + If set it overrides user windowing. Mandatory for GlobalWindow. Returns: A WriteToAvro transform usable for writing. 
""" self._schema = schema self._sink_provider = lambda avro_schema: _create_avro_sink( - file_path_prefix, avro_schema, codec, file_name_suffix, num_shards, - shard_name_template, mime_type) + file_path_prefix, + avro_schema, + codec, + file_name_suffix, + num_shards, + shard_name_template, + mime_type, + triggering_frequency) def expand(self, pcoll): if self._schema: @@ -428,6 +447,15 @@ def expand(self, pcoll): records = pcoll | beam.Map( beam_row_to_avro_dict(avro_schema, beam_schema)) self._sink = self._sink_provider(avro_schema) + if (not pcoll.is_bounded and self._sink.shard_name_template == + filebasedsink.DEFAULT_SHARD_NAME_TEMPLATE): + self._sink.shard_name_template = ( + filebasedsink.DEFAULT_WINDOW_SHARD_NAME_TEMPLATE) + self._sink.shard_name_format = self._sink._template_to_format( + self._sink.shard_name_template) + self._sink.shard_name_glob_format = self._sink._template_to_glob_format( + self._sink.shard_name_template) + return records | beam.io.iobase.Write(self._sink) def display_data(self): @@ -441,7 +469,8 @@ def _create_avro_sink( file_name_suffix, num_shards, shard_name_template, - mime_type): + mime_type, + triggering_frequency=60): if "class 'avro.schema" in str(type(schema)): raise ValueError( 'You are using Avro IO with fastavro (default with Beam on ' @@ -454,7 +483,8 @@ def _create_avro_sink( file_name_suffix, num_shards, shard_name_template, - mime_type) + mime_type, + triggering_frequency) class _BaseAvroSink(filebasedsink.FileBasedSink): @@ -467,7 +497,8 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type): + mime_type, + triggering_frequency): super().__init__( file_path_prefix, file_name_suffix=file_name_suffix, @@ -477,7 +508,8 @@ def __init__( mime_type=mime_type, # Compression happens at the block level using the supplied codec, and # not at the file level. 
- compression_type=CompressionTypes.UNCOMPRESSED) + compression_type=CompressionTypes.UNCOMPRESSED, + triggering_frequency=triggering_frequency) self._schema = schema self._codec = codec @@ -498,7 +530,8 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type): + mime_type, + triggering_frequency): super().__init__( file_path_prefix, schema, @@ -506,7 +539,8 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type) + mime_type, + triggering_frequency) self.file_handle = None def open(self, temp_path): diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index 633b1307eb45..a5af0c48db14 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -16,11 +16,15 @@ # # pytype: skip-file +import glob import json import logging import math import os +import pytz import pytest +import re +import shutil import tempfile import unittest from typing import List, Any @@ -47,14 +51,17 @@ from apache_beam.io.filesystems import FileSystems from apache_beam.options.pipeline_options import StandardOptions from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher from apache_beam.transforms.sql import SqlTransform from apache_beam.transforms.userstate import CombiningValueStateSpec +from apache_beam.transforms.util import LogElements from apache_beam.utils.timestamp import Timestamp from apache_beam.typehints import schemas +from datetime import datetime # Import snappy optionally; some tests will be skipped when import fails. 
try: @@ -625,6 +632,273 @@ def _write_data( return f.name +class GenerateEvent(beam.PTransform): + @staticmethod + def sample_data(): + return GenerateEvent() + + def expand(self, input): + elemlist = [{'age': 10}, {'age': 20}, {'age': 30}] + elem = elemlist + return ( + input + | TestStream().add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 1, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 2, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 3, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 4, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 6, + 0, tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 7, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 8, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 9, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 11, 0, + tzinfo=pytz.UTC).timestamp()). 
+ add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 12, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 13, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 14, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 16, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 17, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 18, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 19, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).advance_watermark_to( + datetime( + 2021, 3, 1, 0, 0, 25, 0, tzinfo=pytz.UTC). + timestamp()).advance_watermark_to_infinity()) + + +class WriteStreamingTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + if os.path.exists(self.tempdir): + shutil.rmtree(self.tempdir) + + def test_write_streaming_2_shards_default_shard_name_template( + self, num_shards=2): + with TestPipeline() as p: + output = ( + p + | GenerateEvent.sample_data() + | 'User windowing' >> beam.transforms.core.WindowInto( + beam.transforms.window.FixedWindows(60), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. 
+ DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + #AvroIO + avroschema = { + 'name': 'dummy', # your supposed to be file name with .avro extension + 'type': 'record', # type of avro serilazation, there are more (see + # above docs) + 'fields': [ # this defines actual keys & their types + {'name': 'age', 'type': 'int'}, + ], + } + output2 = output | 'WriteToAvro' >> beam.io.WriteToAvro( + file_path_prefix=self.tempdir + "/ouput_WriteToAvro", + file_name_suffix=".avro", + num_shards=num_shards, + schema=avroschema) + _ = output2 | 'LogElements after WriteToAvro' >> LogElements( + prefix='after WriteToAvro ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToAvro-[1614556800.0, 1614556805.0)-00000-of-00002.avro + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.avro$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToAvro*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template( + self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #AvroIO + avroschema = { + 'name': 'dummy', # your supposed to be file name with .avro extension + 'type': 'record', # type of avro serilazation + 'fields': [ # this defines actual keys & their types + {'name': 'age', 'type': 'int'}, + ], + } + output2 = output | 'WriteToAvro' >> beam.io.WriteToAvro( + 
file_path_prefix=self.tempdir + "/ouput_WriteToAvro", + file_name_suffix=".avro", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=60, + schema=avroschema) + _ = output2 | 'LogElements after WriteToAvro' >> LogElements( + prefix='after WriteToAvro ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToAvro-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.avro + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.avro$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToAvro*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template_5s_window( + self, + num_shards=2, + shard_name_template='-V-SSSSS-of-NNNNN', + triggering_frequency=5): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #AvroIO + avroschema = { + 'name': 'dummy', # your supposed to be file name with .avro extension + 'type': 'record', # type of avro serilazation + 'fields': [ # this defines actual keys & their types + {'name': 'age', 'type': 'int'}, + ], + } + output2 = output | 'WriteToAvro' >> beam.io.WriteToAvro( + file_path_prefix=self.tempdir + "/ouput_WriteToAvro", + file_name_suffix=".txt", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=triggering_frequency, + schema=avroschema) + _ = output2 | 'LogElements after 
WriteToAvro' >> LogElements( + prefix='after WriteToAvro ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToAvro-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.avro + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.txt$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToAvro*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + # for 5s window size, the input should be processed by 5 windows with + # 2 shards per window + self.assertEqual( + len(file_names), + 10, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/io/filebasedsink.py b/sdks/python/apache_beam/io/filebasedsink.py index 8bb0f7e2171e..aff681606d99 100644 --- a/sdks/python/apache_beam/io/filebasedsink.py +++ b/sdks/python/apache_beam/io/filebasedsink.py @@ -33,9 +33,12 @@ from apache_beam.options.value_provider import StaticValueProvider from apache_beam.options.value_provider import ValueProvider from apache_beam.options.value_provider import check_accessible +from apache_beam.transforms import window from apache_beam.transforms.display import DisplayDataItem DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN' +DEFAULT_WINDOW_SHARD_NAME_TEMPLATE = '-W-SSSSS-of-NNNNN' +DEFAULT_TRIGGERING_FREQUENCY = 0 __all__ = ['FileBasedSink'] @@ -71,7 +74,9 @@ def __init__( *, max_records_per_shard=None, max_bytes_per_shard=None, - skip_if_empty=False): + skip_if_empty=False, + 
convert_fn=None, + triggering_frequency=None): """ Raises: TypeError: if file path parameters are not a :class:`str` or @@ -98,6 +103,8 @@ def __init__( shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE elif shard_name_template == '': num_shards = 1 + if triggering_frequency is None: + triggering_frequency = DEFAULT_TRIGGERING_FREQUENCY if isinstance(file_path_prefix, str): file_path_prefix = StaticValueProvider(str, file_path_prefix) if isinstance(file_name_suffix, str): @@ -106,6 +113,7 @@ def __init__( self.file_name_suffix = file_name_suffix self.num_shards = num_shards self.coder = coder + self.shard_name_template = shard_name_template self.shard_name_format = self._template_to_format(shard_name_template) self.shard_name_glob_format = self._template_to_glob_format( shard_name_template) @@ -114,6 +122,8 @@ def __init__( self.max_records_per_shard = max_records_per_shard self.max_bytes_per_shard = max_bytes_per_shard self.skip_if_empty = skip_if_empty + self.convert_fn = convert_fn + self.triggering_frequency = triggering_frequency def display_data(self): return { @@ -202,25 +212,45 @@ def open_writer(self, init_result, uid): return FileBasedSinkWriter(self, writer_path) @check_accessible(['file_path_prefix', 'file_name_suffix']) - def _get_final_name(self, shard_num, num_shards): + def _get_final_name(self, shard_num, num_shards, w): + if w is None or isinstance(w, window.GlobalWindow): + window_utc = None + else: + window_utc = ( + '[' + w.start.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S") + ', ' + + w.end.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S") + ')') return ''.join([ self.file_path_prefix.get(), - self.shard_name_format % - dict(shard_num=shard_num, num_shards=num_shards), + self.shard_name_format % dict( + shard_num=shard_num, + num_shards=num_shards, + uuid=(uuid.uuid4()), + window=w, + window_utc=window_utc), self.file_name_suffix.get() ]) @check_accessible(['file_path_prefix', 'file_name_suffix']) - def _get_final_name_glob(self, num_shards): 
+ def _get_final_name_glob(self, num_shards, w): + if w is None or isinstance(w, window.GlobalWindow): + window_utc = None + else: + window_utc = ( + '[' + w.start.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S") + ', ' + + w.end.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S") + ')') return ''.join([ self.file_path_prefix.get(), - self.shard_name_glob_format % dict(num_shards=num_shards), + self.shard_name_glob_format % dict( + num_shards=num_shards, + uuid=(uuid.uuid4()), + window=w, + window_utc=window_utc), self.file_name_suffix.get() ]) - def pre_finalize(self, init_result, writer_results): + def pre_finalize(self, init_result, writer_results, window=None): num_shards = len(list(writer_results)) - dst_glob = self._get_final_name_glob(num_shards) + dst_glob = self._get_final_name_glob(num_shards, window) dst_glob_files = [ file_metadata.path for mr in FileSystems.match([dst_glob]) for file_metadata in mr.metadata_list @@ -233,7 +263,8 @@ def pre_finalize(self, init_result, writer_results): self.shard_name_glob_format) FileSystems.delete(dst_glob_files) - def _check_state_for_finalize_write(self, writer_results, num_shards): + def _check_state_for_finalize_write( + self, writer_results, num_shards, window=None): """Checks writer output files' states. 
Returns: @@ -248,7 +279,7 @@ def _check_state_for_finalize_write(self, writer_results, num_shards): return [], [], [], 0 src_glob = FileSystems.join(FileSystems.split(writer_results[0])[0], '*') - dst_glob = self._get_final_name_glob(num_shards) + dst_glob = self._get_final_name_glob(num_shards, window) src_glob_files = set( file_metadata.path for mr in FileSystems.match([src_glob]) for file_metadata in mr.metadata_list) @@ -261,7 +292,7 @@ def _check_state_for_finalize_write(self, writer_results, num_shards): delete_files = [] num_skipped = 0 for shard_num, src in enumerate(writer_results): - final_name = self._get_final_name(shard_num, num_shards) + final_name = self._get_final_name(shard_num, num_shards, window) dst = final_name src_exists = src in src_glob_files dst_exists = dst in dst_glob_files @@ -299,12 +330,12 @@ def _report_sink_lineage(self, dst_glob, dst_files): @check_accessible(['file_path_prefix']) def finalize_write( - self, init_result, writer_results, unused_pre_finalize_results): + self, init_result, writer_results, unused_pre_finalize_results, w=None): writer_results = sorted(writer_results) num_shards = len(writer_results) src_files, dst_files, delete_files, num_skipped = ( - self._check_state_for_finalize_write(writer_results, num_shards)) + self._check_state_for_finalize_write(writer_results, num_shards, w)) num_skipped += len(delete_files) FileSystems.delete(delete_files) num_shards_to_finalize = len(src_files) @@ -322,16 +353,8 @@ def finalize_write( ] if num_shards_to_finalize: - _LOGGER.info( - 'Starting finalize_write threads with num_shards: %d (skipped: %d), ' - 'batches: %d, num_threads: %d', - num_shards_to_finalize, - num_skipped, - len(source_file_batch), - num_threads) start_time = time.time() - # Use a thread pool for renaming operations. 
def _rename_batch(batch): """_rename_batch executes batch rename operations.""" source_files, destination_files = batch @@ -355,19 +378,35 @@ def _rename_batch(batch): _LOGGER.debug('Rename successful: %s -> %s', src, dst) return exceptions - exception_batches = util.run_using_threadpool( - _rename_batch, - list(zip(source_file_batch, destination_file_batch)), - num_threads) - - all_exceptions = [ - e for exception_batch in exception_batches for e in exception_batch - ] - if all_exceptions: - raise Exception( - 'Encountered exceptions in finalize_write: %s' % all_exceptions) - - yield from dst_files + if w is None or isinstance(w, window.GlobalWindow): + # bounded input + # Use a thread pool for renaming operations. + exception_batches = util.run_using_threadpool( + _rename_batch, + list(zip(source_file_batch, destination_file_batch)), + num_threads) + + all_exceptions = [ + e for exception_batch in exception_batches for e in exception_batch + ] + if all_exceptions: + raise Exception( + 'Encountered exceptions in finalize_write: %s' % all_exceptions) + + yield from dst_files + else: + # unbounded input + batch = list([src_files, dst_files]) + exception_batches = _rename_batch(batch) + + all_exceptions = [ + e for exception_batch in exception_batches for e in exception_batch + ] + if all_exceptions: + raise Exception( + 'Encountered exceptions in finalize_write: %s' % all_exceptions) + + yield from dst_files _LOGGER.info( 'Renamed %d shards in %.2f seconds.', @@ -385,13 +424,34 @@ def _rename_batch(batch): # This error is not serious, we simply log it. 
_LOGGER.info('Unable to delete file: %s', init_result) + @staticmethod + def _template_replace_window(shard_name_template): + match = re.search('W+', shard_name_template) + if match: + shard_name_template = shard_name_template.replace( + match.group(0), '%%(window)0%ds' % len(match.group(0))) + match = re.search('V+', shard_name_template) + if match: + shard_name_template = shard_name_template.replace( + match.group(0), '%%(window_utc)0%ds' % len(match.group(0))) + return shard_name_template + + @staticmethod + def _template_replace_uuid(shard_name_template): + match = re.search('U+', shard_name_template) + if match: + shard_name_template = shard_name_template.replace( + match.group(0), '%%(uuid)0%dd' % len(match.group(0))) + return FileBasedSink._template_replace_window(shard_name_template) + @staticmethod def _template_replace_num_shards(shard_name_template): match = re.search('N+', shard_name_template) if match: shard_name_template = shard_name_template.replace( match.group(0), '%%(num_shards)0%dd' % len(match.group(0))) - return shard_name_template + #return shard_name_template + return FileBasedSink._template_replace_uuid(shard_name_template) @staticmethod def _template_to_format(shard_name_template): diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 53215275e050..1d9838d645bd 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -42,6 +42,7 @@ from typing import Tuple from typing import Union +import apache_beam as beam from apache_beam import coders from apache_beam import pvalue from apache_beam.coders.coders import _MemoizingPickleCoder @@ -56,6 +57,7 @@ from apache_beam.transforms import core from apache_beam.transforms import ptransform from apache_beam.transforms import window +from apache_beam.transforms.core import DoFn from apache_beam.transforms.display import DisplayDataItem from apache_beam.transforms.display import HasDisplayData from apache_beam.utils import 
timestamp @@ -778,7 +780,7 @@ def open_writer(self, init_result, uid): """ raise NotImplementedError - def pre_finalize(self, init_result, writer_results): + def pre_finalize(self, init_result, writer_results, window=None): """Pre-finalization stage for sink. Called after all bundle writes are complete and before finalize_write. @@ -797,7 +799,8 @@ def pre_finalize(self, init_result, writer_results): """ raise NotImplementedError - def finalize_write(self, init_result, writer_results, pre_finalize_result): + def finalize_write( + self, init_result, writer_results, pre_finalize_result, w=None): """Finalizes the sink after all data is written to it. Given the result of initialization and an iterable of results from bundle @@ -830,6 +833,7 @@ def finalize_write(self, init_result, writer_results, pre_finalize_result): will only contain the result of a single successful write for a given bundle. pre_finalize_result: the result of ``pre_finalize()`` invocation. + w: DoFn window """ raise NotImplementedError @@ -1127,47 +1131,183 @@ def __init__(self, sink: Sink) -> None: self.sink = sink def expand(self, pcoll): - do_once = pcoll.pipeline | 'DoOnce' >> core.Create([None]) - init_result_coll = do_once | 'InitializeWrite' >> core.Map( - lambda _, sink: sink.initialize_write(), self.sink) + if (pcoll.is_bounded): + do_once = pcoll.pipeline | 'DoOnce' >> core.Create([None]) + init_result_coll = do_once | 'InitializeWrite' >> core.Map( + lambda _, sink: sink.initialize_write(), self.sink) if getattr(self.sink, 'num_shards', 0): min_shards = self.sink.num_shards - if min_shards == 1: - keyed_pcoll = pcoll | core.Map(lambda x: (None, x)) - else: - keyed_pcoll = pcoll | core.ParDo(_RoundRobinKeyFn(), count=min_shards) - write_result_coll = ( - keyed_pcoll - | core.WindowInto(window.GlobalWindows()) - | core.GroupByKey() - | 'WriteBundles' >> core.ParDo( - _WriteKeyedBundleDoFn(self.sink), AsSingleton(init_result_coll))) + + if (pcoll.is_bounded): + if min_shards == 1: + 
keyed_pcoll = pcoll | core.Map(lambda x: (None, x)) + else: + keyed_pcoll = pcoll | core.ParDo(_RoundRobinKeyFn(), count=min_shards) + write_result_coll = ( + keyed_pcoll + | core.WindowInto(window.GlobalWindows()) + | core.GroupByKey() + | 'WriteBundles' >> core.ParDo( + _WriteKeyedBundleDoFn(self.sink), AsSingleton(init_result_coll)) + ) + else: #unbounded PCollection needes to be written per window + if isinstance(pcoll.windowing.windowfn, window.GlobalWindows): + if (self.sink.triggering_frequency is None or + self.sink.triggering_frequency == 0): + raise ValueError( + 'To write a GlobalWindow PCollection, triggering_frequency must' + ' be set and be greater than 0') + widowed_pcoll = ( + pcoll #TODO GroupIntoBatches and trigger indef per freq + | core.WindowInto( + window.FixedWindows(self.sink.triggering_frequency), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. + DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + else: + #keep user windowing, unless triggering_frequency has been specified + if (self.sink.triggering_frequency is not None and + self.sink.triggering_frequency > 0): + widowed_pcoll = ( + pcoll #TODO GroupIntoBatches and trigger indef per freq + | core.WindowInto( + window.FixedWindows(self.sink.triggering_frequency), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. 
+ DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + else: #keep user windowing + widowed_pcoll = pcoll + if self.sink.convert_fn is not None: + widowed_pcoll = widowed_pcoll | core.ParDo(self.sink.convert_fn) + if min_shards == 1: + keyed_pcoll = widowed_pcoll | core.Map(lambda x: (None, x)) + else: + keyed_pcoll = widowed_pcoll | core.ParDo( + _RoundRobinKeyFn(), count=min_shards) + init_result_window_coll = ( + keyed_pcoll + | 'Pair init' >> core.Map(lambda x: (None, x)) + | 'Pair init gbk' >> core.GroupByKey() + | 'InitializeWindowedWrite' >> core.Map( + lambda _, sink: sink.initialize_write(), self.sink)) + + write_result_coll = ( + keyed_pcoll + | 'Group by random key' >> core.GroupByKey() + | 'WriteWindowedBundles' >> core.ParDo( + _WriteWindowedBundleDoFn(sink=self.sink, per_key=True), + AsSingleton(init_result_window_coll)) + | 'Pair' >> core.Map(lambda x: (None, x)) + | core.GroupByKey() + | 'Extract' >> core.Map(lambda x: x[1])) + pre_finalized_write_result_coll = ( + write_result_coll + | 'PreFinalize' >> core.ParDo( + _PreFinalizeWindowedBundleDoFn(self.sink), + AsSingleton(init_result_window_coll))) + finalized_write_result_coll = ( + pre_finalized_write_result_coll + | 'FinalizeWrite' >> core.FlatMap( + _finalize_write, + self.sink, + AsSingleton(init_result_window_coll), + AsSingleton(write_result_coll), + min_shards, + AsIter(pre_finalized_write_result_coll)).with_output_types(str)) + return finalized_write_result_coll else: + _LOGGER.info( + "*** WriteImpl min_shards undef so it's 1, and we write per Bundle") min_shards = 1 - write_result_coll = ( - pcoll - | core.WindowInto(window.GlobalWindows()) - | 'WriteBundles' >> core.ParDo( - _WriteBundleDoFn(self.sink), AsSingleton(init_result_coll)) - | 'Pair' >> core.Map(lambda x: (None, x)) - | core.GroupByKey() - | 'Extract' >> core.FlatMap(lambda x: x[1])) + if (pcoll.is_bounded): + write_result_coll = ( + pcoll + | core.WindowInto(window.GlobalWindows()) + | 'WriteBundles' 
>> core.ParDo( + _WriteBundleDoFn(self.sink), AsSingleton(init_result_coll)) + | 'Pair' >> core.Map(lambda x: (None, x)) + | core.GroupByKey() + | 'Extract' >> core.FlatMap(lambda x: x[1])) + else: #unbounded PCollection needes to be written per window + if isinstance(pcoll.windowing.windowfn, window.GlobalWindows): + if (self.sink.triggering_frequency is None or + self.sink.triggering_frequency == 0): + raise ValueError( + 'To write a GlobalWindow PCollection, triggering_frequency must' + ' be set and be greater than 0') + widowed_pcoll = ( + pcoll #TODO GroupIntoBatches and trigger indef per freq + | core.WindowInto( + window.FixedWindows(self.sink.triggering_frequency), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. + DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + else: + #keep user windowing, unless triggering_frequency has been specified + if (self.sink.triggering_frequency is not None and + self.sink.triggering_frequency > 0): + widowed_pcoll = ( + pcoll #TODO GroupIntoBatches and trigger indef per freq + | core.WindowInto( + window.FixedWindows(self.sink.triggering_frequency), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. 
+ DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + else: #keep user windowing + widowed_pcoll = pcoll + init_result_window_coll = ( + widowed_pcoll + | 'Pair init' >> core.Map(lambda x: (None, x)) + | 'Pair init gbk' >> core.GroupByKey() + | 'InitializeWindowedWrite' >> core.Map( + lambda _, sink: sink.initialize_write(), self.sink)) + if self.sink.convert_fn is not None: + widowed_pcoll = widowed_pcoll | core.ParDo(self.sink.convert_fn) + write_result_coll = ( + widowed_pcoll + | 'WriteWindowedBundles' >> core.ParDo( + _WriteWindowedBundleDoFn(self.sink), + AsSingleton(init_result_window_coll)) + | 'Pair' >> core.Map(lambda x: (None, x)) + | core.GroupByKey() + | 'Extract' >> core.Map(lambda x: x[1])) + pre_finalized_write_result_coll = ( + write_result_coll + | 'PreFinalize' >> core.ParDo( + _PreFinalizeWindowedBundleDoFn(self.sink), + AsSingleton(init_result_window_coll))) + finalized_write_result_coll = ( + pre_finalized_write_result_coll + | 'FinalizeWrite' >> core.FlatMap( + _finalize_write, + self.sink, + AsSingleton(init_result_window_coll), + AsSingleton(write_result_coll), + min_shards, + AsIter(pre_finalized_write_result_coll)).with_output_types(str)) + return finalized_write_result_coll # PreFinalize should run before FinalizeWrite, and the two should not be # fused. 
- pre_finalize_coll = ( - do_once - | 'PreFinalize' >> core.FlatMap( - _pre_finalize, - self.sink, - AsSingleton(init_result_coll), - AsIter(write_result_coll))) - return do_once | 'FinalizeWrite' >> core.FlatMap( - _finalize_write, - self.sink, - AsSingleton(init_result_coll), - AsIter(write_result_coll), - min_shards, - AsSingleton(pre_finalize_coll)).with_output_types(str) + if (pcoll.is_bounded): + pre_finalize_coll = ( + do_once + | 'PreFinalize' >> core.FlatMap( + _pre_finalize, + self.sink, + AsSingleton(init_result_coll), + AsIter(write_result_coll))) + return ( + do_once | 'FinalizeWrite' >> core.FlatMap( + _finalize_write, + self.sink, + AsSingleton(init_result_coll), + AsIter(write_result_coll), + min_shards, + AsSingleton(pre_finalize_coll)).with_output_types(str)) class _WriteBundleDoFn(core.DoFn): @@ -1199,6 +1339,94 @@ def finish_bundle(self): window.GlobalWindow().max_timestamp(), [window.GlobalWindow()]) +class _PreFinalizeWindowedBundleDoFn(core.DoFn): + """A DoFn for writing elements to an iobase.Writer. + Opens a writer at the first element and closes the writer at finish_bundle(). + """ + def __init__( + self, + sink, + destination_fn=None, + temp_directory=None, + ): + self.sink = sink + self._temp_directory = temp_directory + self.destination_fn = destination_fn + + def display_data(self): + return {'sink_dd': self.sink} + + def process( + self, + element, + init_result, + w=core.DoFn.WindowParam, + pane=core.DoFn.PaneInfoParam): + self.sink.pre_finalize( + init_result=init_result, writer_results=element, window=w) + yield element + + +class _WriteWindowedBundleDoFn(core.DoFn): + """A DoFn for writing elements to an iobase.Writer. + Opens a writer at the first element and closes the writer at finish_bundle(). 
+ """ + def __init__(self, sink, per_key=False): + self.sink = sink + self.per_key = per_key + + def display_data(self): + return {'sink_dd': self.sink} + + def start_bundle(self): + self.writer = {} + self.window = {} + self.init_result = {} + + def process( + self, + element, + init_result, + w=core.DoFn.WindowParam, + pane=core.DoFn.PaneInfoParam): + if self.per_key: + w_key = "%s_%s" % (w, element[0]) # key + else: + w_key = w + + if not w in self.writer: + # We ignore UUID collisions here since they are extremely rare. + self.window[w_key] = w + self.writer[w_key] = self.sink.open_writer( + init_result, '%s_%s' % (w_key, uuid.uuid4())) + self.init_result[w_key] = init_result + + if self.per_key: + for e in element[1]: # values + self.writer[w_key].write(e) # value + else: + self.writer[w_key].write(element) + if self.writer[w_key].at_capacity(): + yield self.writer[w_key].close() + self.writer[w_key] = None + + def finish_bundle(self): + for w_key, writer in self.writer.items(): + w = self.window[w_key] + if writer is not None: + closed = writer.temp_shard_path + try: + closed = writer.close() # TODO : improve sink closing for streaming + except ValueError as exp: + _LOGGER.info( + "*** _WriteWindowedBundleDoFn finish_bundle closed ERROR %s", exp) + yield WindowedValue( + closed, + timestamp=w.start, + windows=[w] # TODO(pabloem) HOW DO WE GET THE PANE + ) + + class _WriteKeyedBundleDoFn(core.DoFn): def __init__(self, sink): self.sink = sink @@ -1224,7 +1452,8 @@ def _finalize_write( init_result, write_results, min_shards, - pre_finalize_results): + pre_finalize_results, + w=DoFn.WindowParam): write_results = list(write_results) extra_shards = [] if len(write_results) < min_shards: @@ -1235,10 +1464,10 @@ def _finalize_write( writer = sink.open_writer(init_result, str(uuid.uuid4())) extra_shards.append(writer.close()) outputs = sink.finalize_write( - init_result, write_results + extra_shards, pre_finalize_results) + init_result, write_results + extra_shards, 
pre_finalize_results, w) + if outputs: - return ( - window.TimestampedValue(v, timestamp.MAX_TIMESTAMP) for v in outputs) + return (window.TimestampedValue(v, w.end) for v in outputs) class _RoundRobinKeyFn(core.DoFn): diff --git a/sdks/python/apache_beam/io/iobase_it_test.py b/sdks/python/apache_beam/io/iobase_it_test.py new file mode 100644 index 000000000000..168d94b41feb --- /dev/null +++ b/sdks/python/apache_beam/io/iobase_it_test.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# pytype: skip-file + +import logging +import unittest +import uuid + +import apache_beam as beam +from apache_beam.io.textio import WriteToText +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.transforms.window import FixedWindows + +# End-to-End tests for iobase +# Usage: +# cd sdks/python +# pip install build && python -m build --sdist +# DataflowRunner: +# python -m pytest -o log_cli=True -o log_level=Info \ +# apache_beam/io/iobase_it_test.py::IOBaseITTest \ +# --test-pipeline-options="--runner=TestDataflowRunner \ +# --project=apache-beam-testing --region=us-central1 \ +# --temp_location=gs://apache-beam-testing-temp/temp \ +# --sdk_location=dist/apache_beam-2.65.0.dev0.tar.gz" + + +class IOBaseITTest(unittest.TestCase): + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self.runner_name = type(self.test_pipeline.runner).__name__ + + def test_unbounded_pcoll_without_gloabl_window(self): + # https://github.com/apache/beam/issues/25598 + + args = self.test_pipeline.get_full_options_as_args(streaming=True, ) + + topic = 'projects/pubsub-public-data/topics/taxirides-realtime' + unique_id = str(uuid.uuid4()) + output_file = f'gs://apache-beam-testing-integration-testing/iobase/test-{unique_id}' # pylint: disable=line-too-long + + p = beam.Pipeline(argv=args) + # Read from Pub/Sub with fixed windowing + lines = ( + p + | "ReadFromPubSub" >> beam.io.ReadFromPubSub(topic=topic) + | "WindowInto" >> beam.WindowInto(FixedWindows(10))) + + # Write to text file + _ = lines | 'WriteToText' >> WriteToText(output_file) + + result = p.run() + result.wait_until_finish(duration=60 * 1000) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/io/parquetio.py b/sdks/python/apache_beam/io/parquetio.py index 48c51428c17d..89266888792d 100644 --- a/sdks/python/apache_beam/io/parquetio.py +++ b/sdks/python/apache_beam/io/parquetio.py @@ 
-48,6 +48,7 @@ from apache_beam.transforms import PTransform from apache_beam.transforms import window from apache_beam.typehints import schemas +from apache_beam.utils.windowed_value import WindowedValue try: import pyarrow as pa @@ -105,8 +106,10 @@ def __init__( self._buffer_size = record_batch_size self._record_batches = [] self._record_batches_byte_size = 0 + self._window = None - def process(self, row): + def process(self, row, w=DoFn.WindowParam, pane=DoFn.PaneInfoParam): + self._window = w if len(self._buffer[0]) >= self._buffer_size: self._flush_buffer() @@ -123,7 +126,17 @@ def finish_bundle(self): self._flush_buffer() if self._record_batches_byte_size > 0: table = self._create_table() - yield window.GlobalWindows.windowed_value_at_end_of_window(table) + if self._window is None or isinstance(self._window, window.GlobalWindow): + # bounded input + yield window.GlobalWindows.windowed_value_at_end_of_window(table) + else: + # unbounded input + yield WindowedValue( + table, + timestamp=self._window. + end, #or it could be max of timestamp of the rows processed + windows=[self._window] # TODO(pabloem) HOW DO WE GET THE PANE + ) def display_data(self): res = super().display_data() @@ -476,7 +489,9 @@ def __init__( file_name_suffix='', num_shards=0, shard_name_template=None, - mime_type='application/x-parquet'): + mime_type='application/x-parquet', + triggering_frequency=None, + ): """Initialize a WriteToParquet transform. Writes parquet files from a :class:`~apache_beam.pvalue.PCollection` of @@ -540,14 +555,26 @@ def __init__( the performance of a pipeline. Setting this value is not recommended unless you require a specific number of output files. shard_name_template: A template string containing placeholders for - the shard number and shard count. When constructing a filename for a - particular shard number, the upper-case letters 'S' and 'N' are - replaced with the 0-padded shard number and shard count respectively. 
- This argument can be '' in which case it behaves as if num_shards was - set to 1 and only one file will be generated. The default pattern used - is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template. + the shard number and shard count. Currently only ``''``, + ``'-SSSSS-of-NNNNN'``, ``'-W-SSSSS-of-NNNNN'`` and + ``'-V-SSSSS-of-NNNNN'`` are patterns accepted by the service. + When constructing a filename for a particular shard number, the + upper-case letters ``S`` and ``N`` are replaced with the ``0``-padded + shard number and shard count respectively. This argument can be ``''`` + in which case it behaves as if num_shards was set to 1 and only one file + will be generated. The default pattern used is ``'-SSSSS-of-NNNNN'`` for + bounded PCollections and for ``'-W-SSSSS-of-NNNNN'`` unbounded + PCollections. + W is used for windowed shard naming and is replaced with + ``[window.start, window.end)`` + V is used for windowed shard naming and is replaced with + ``[window.start.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S"), + window.end.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S")`` mime_type: The MIME type to use for the produced files, if the filesystem supports specifying MIME types. + triggering_frequency: (int) Every triggering_frequency duration, a window + will be triggered and all bundles in the window will be written. + If set it overrides user windowing. Mandatory for GlobalWindow. Returns: A WriteToParquet transform usable for writing. 
@@ -567,10 +594,20 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type + mime_type, + triggering_frequency ) def expand(self, pcoll): + if (not pcoll.is_bounded and self._sink.shard_name_template == + filebasedsink.DEFAULT_SHARD_NAME_TEMPLATE): + self._sink.shard_name_template = ( + filebasedsink.DEFAULT_WINDOW_SHARD_NAME_TEMPLATE) + self._sink.shard_name_format = self._sink._template_to_format( + self._sink.shard_name_template) + self._sink.shard_name_glob_format = self._sink._template_to_glob_format( + self._sink.shard_name_template) + if self._schema is None: try: beam_schema = schemas.schema_from_element_type(pcoll.element_type) @@ -583,7 +620,11 @@ def expand(self, pcoll): else: convert_fn = _RowDictionariesToArrowTable( self._schema, self._row_group_buffer_size, self._record_batch_size) - return pcoll | ParDo(convert_fn) | Write(self._sink) + if pcoll.is_bounded: + return pcoll | ParDo(convert_fn) | Write(self._sink) + else: + self._sink.convert_fn = convert_fn + return pcoll | Write(self._sink) def display_data(self): return { @@ -610,7 +651,7 @@ def __init__( num_shards=0, shard_name_template=None, mime_type='application/x-parquet', - ): + triggering_frequency=None): """Initialize a WriteToParquetBatched transform. Writes parquet files from a :class:`~apache_beam.pvalue.PCollection` of @@ -668,11 +709,21 @@ def __init__( the shard number and shard count. When constructing a filename for a particular shard number, the upper-case letters 'S' and 'N' are replaced with the 0-padded shard number and shard count respectively. + W is used for windowed shard naming and is replaced with + ``[window.start, window.end)`` + V is used for windowed shard naming and is replaced with + ``[window.start.to_utc_datetime().isoformat(), + window.end.to_utc_datetime().isoformat()`` This argument can be '' in which case it behaves as if num_shards was - set to 1 and only one file will be generated. 
The default pattern used - is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template. + set to 1 and only one file will be generated. + The default pattern used is '-SSSSS-of-NNNNN' if None is passed as the + shard_name_template and the PCollection is bounded. + The default pattern used is '-W-SSSSS-of-NNNNN' if None is passed as the + shard_name_template and the PCollection is unbounded. mime_type: The MIME type to use for the produced files, if the filesystem supports specifying MIME types. + triggering_frequency: (int) Every triggering_frequency duration, a window + will be triggered and all bundles in the window will be written. Returns: A WriteToParquetBatched transform usable for writing. @@ -688,10 +739,19 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type + mime_type, + triggering_frequency ) def expand(self, pcoll): + if (not pcoll.is_bounded and self._sink.shard_name_template == + filebasedsink.DEFAULT_SHARD_NAME_TEMPLATE): + self._sink.shard_name_template = ( + filebasedsink.DEFAULT_WINDOW_SHARD_NAME_TEMPLATE) + self._sink.shard_name_format = self._sink._template_to_format( + self._sink.shard_name_template) + self._sink.shard_name_glob_format = self._sink._template_to_glob_format( + self._sink.shard_name_template) return pcoll | Write(self._sink) def display_data(self): @@ -707,7 +767,8 @@ def _create_parquet_sink( file_name_suffix, num_shards, shard_name_template, - mime_type): + mime_type, + triggering_frequency=60): return \ _ParquetSink( file_path_prefix, @@ -718,7 +779,8 @@ def _create_parquet_sink( file_name_suffix, num_shards, shard_name_template, - mime_type + mime_type, + triggering_frequency ) @@ -734,7 +796,8 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type): + mime_type, + triggering_frequency): super().__init__( file_path_prefix, file_name_suffix=file_name_suffix, @@ -744,7 +807,8 @@ def __init__( mime_type=mime_type, # Compression happens at the block level using the 
supplied codec, and # not at the file level. - compression_type=CompressionTypes.UNCOMPRESSED) + compression_type=CompressionTypes.UNCOMPRESSED, + triggering_frequency=triggering_frequency) self._schema = schema self._codec = codec if ARROW_MAJOR_VERSION == 1 and self._codec.lower() == "lz4": diff --git a/sdks/python/apache_beam/io/parquetio_it_test.py b/sdks/python/apache_beam/io/parquetio_it_test.py index 052b54f3ebfb..5dd3eac63746 100644 --- a/sdks/python/apache_beam/io/parquetio_it_test.py +++ b/sdks/python/apache_beam/io/parquetio_it_test.py @@ -19,10 +19,14 @@ import logging import string import unittest +import uuid from collections import Counter +from datetime import datetime import pytest +import pytz +import apache_beam as beam from apache_beam import Create from apache_beam import DoFn from apache_beam import FlatMap @@ -37,6 +41,7 @@ from apache_beam.testing.util import BeamAssertException from apache_beam.transforms import CombineGlobally from apache_beam.transforms.combiners import Count +from apache_beam.transforms.periodicsequence import PeriodicImpulse try: import pyarrow as pa @@ -142,6 +147,42 @@ def get_int(self): return i +@unittest.skipIf(pa is None, "PyArrow is not installed.") +class WriteStreamingIT(unittest.TestCase): + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self.runner_name = type(self.test_pipeline.runner).__name__ + super().setUp() + + def test_write_streaming_2_shards_default_shard_name_template( + self, num_shards=2): + + args = self.test_pipeline.get_full_options_as_args(streaming=True, ) + + unique_id = str(uuid.uuid4()) + output_file = f'gs://apache-beam-testing-integration-testing/iobase/test-{unique_id}' # pylint: disable=line-too-long + p = beam.Pipeline(argv=args) + pyschema = pa.schema([('age', pa.int64())]) + + _ = ( + p + | "generate impulse" >> PeriodicImpulse( + start_timestamp=datetime(2021, 3, 1, 0, 0, 1, 0, + tzinfo=pytz.UTC).timestamp(), + stop_timestamp=datetime(2021, 3, 1, 
0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp(), + fire_interval=1) + | "generate data" >> beam.Map(lambda t: {'age': t * 10}) + | 'WriteToParquet' >> beam.io.WriteToParquet( + file_path_prefix=output_file, + file_name_suffix=".parquet", + num_shards=num_shards, + triggering_frequency=60, + schema=pyschema)) + result = p.run() + result.wait_until_finish(duration=600 * 1000) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/io/parquetio_test.py b/sdks/python/apache_beam/io/parquetio_test.py index fd19ec9520a9..3a2f7034f1a7 100644 --- a/sdks/python/apache_beam/io/parquetio_test.py +++ b/sdks/python/apache_beam/io/parquetio_test.py @@ -16,17 +16,21 @@ # # pytype: skip-file +import glob import json import logging import os +import re import shutil import tempfile import unittest +from datetime import datetime from tempfile import TemporaryDirectory import hamcrest as hc import pandas import pytest +import pytz from parameterized import param from parameterized import parameterized @@ -45,10 +49,12 @@ from apache_beam.io.parquetio import _create_parquet_sink from apache_beam.io.parquetio import _create_parquet_source from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher +from apache_beam.transforms.util import LogElements try: import pyarrow as pa @@ -656,6 +662,290 @@ def test_read_all_from_parquet_with_filename(self): equal_to(result)) +class GenerateEvent(beam.PTransform): + @staticmethod + def sample_data(): + return GenerateEvent() + + def expand(self, input): + elemlist = [{'age': 10}, {'age': 20}, {'age': 30}] + elem = elemlist + return ( + input + | TestStream().add_elements( + elements=elem, + 
event_timestamp=datetime( + 2021, 3, 1, 0, 0, 1, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 2, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 3, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 4, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 6, + 0, tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 7, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 8, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 9, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 11, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 12, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 13, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 14, 0, + tzinfo=pytz.UTC).timestamp()). 
+ advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 16, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 17, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 18, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 19, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).advance_watermark_to( + datetime( + 2021, 3, 1, 0, 0, 25, 0, tzinfo=pytz.UTC). + timestamp()).advance_watermark_to_infinity()) + + +class WriteStreamingTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + if os.path.exists(self.tempdir): + shutil.rmtree(self.tempdir) + + def test_write_streaming_2_shards_default_shard_name_template( + self, num_shards=2): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #ParquetIO + pyschema = pa.schema([('age', pa.int64())]) + output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet( + file_path_prefix=self.tempdir + "/ouput_WriteToParquet", + file_name_suffix=".parquet", + num_shards=num_shards, + triggering_frequency=60, + schema=pyschema) + _ = output2 | 'LogElements after WriteToParquet' >> LogElements( + prefix='after WriteToParquet ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToParquet-[1614556800.0, 1614556805.0)-00000-of-00002.parquet + # It captures: 
window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P<window_start>[\d\.]+), ' + r'(?P<window_end>[\d\.]+|Infinity)\)-' + r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.parquet$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template( + self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #ParquetIO + pyschema = pa.schema([('age', pa.int64())]) + output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet( + file_path_prefix=self.tempdir + "/ouput_WriteToParquet", + file_name_suffix=".parquet", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=60, + schema=pyschema) + _ = output2 | 'LogElements after WriteToParquet' >> LogElements( + prefix='after WriteToParquet ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToParquet-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.parquet + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.parquet$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + 
print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template_5s_window( + self, + num_shards=2, + shard_name_template='-V-SSSSS-of-NNNNN', + triggering_frequency=5): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #ParquetIO + pyschema = pa.schema([('age', pa.int64())]) + output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet( + file_path_prefix=self.tempdir + "/ouput_WriteToParquet", + file_name_suffix=".parquet", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=triggering_frequency, + schema=pyschema) + _ = output2 | 'LogElements after WriteToParquet' >> LogElements( + prefix='after WriteToParquet ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToParquet-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.parquet + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.parquet$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + # for 5s window size, the input should be processed by 5 windows with + # 2 shards per window + self.assertEqual( + len(file_names), + 10, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_undef_shards_default_shard_name_template_windowed_pcoll( # pylint: disable=line-too-long + self): + 
with TestPipeline() as p: + output = ( + p | GenerateEvent.sample_data() + | 'User windowing' >> beam.transforms.core.WindowInto( + beam.transforms.window.FixedWindows(10), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. + DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + #ParquetIO + pyschema = pa.schema([('age', pa.int64())]) + output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet( + file_path_prefix=self.tempdir + "/ouput_WriteToParquet", + file_name_suffix=".parquet", + num_shards=0, + schema=pyschema) + _ = output2 | 'LogElements after WriteToParquet' >> LogElements( + prefix='after WriteToParquet ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToParquet-[1614556800.0, 1614556805.0)-00000-of-00002.parquet + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.parquet$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertGreaterEqual( + len(file_names), + 1*3, #25s of data covered by 3 10s windows + "expected %d files, but got: %d" % (1*3, len(file_names))) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py index d817463cfef6..7fa31c3d1a00 100644 --- a/sdks/python/apache_beam/io/textio.py +++ b/sdks/python/apache_beam/io/textio.py @@ -451,7 +451,8 @@ def __init__( *, max_records_per_shard=None, max_bytes_per_shard=None, - skip_if_empty=False): 
+ skip_if_empty=False, + triggering_frequency=None): """Initialize a _TextSink. Args: @@ -468,13 +469,23 @@ def __init__( Constraining the number of shards is likely to reduce the performance of a pipeline. Setting this value is not recommended unless you require a specific number of output files. + In streaming if not set, the service will write a file per bundle. shard_name_template: A template string containing placeholders for - the shard number and shard count. When constructing a filename for a - particular shard number, the upper-case letters 'S' and 'N' are - replaced with the 0-padded shard number and shard count respectively. - This argument can be '' in which case it behaves as if num_shards was - set to 1 and only one file will be generated. The default pattern used - is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template. + the shard number and shard count. Currently only ``''``, + ``'-SSSSS-of-NNNNN'``, ``'-W-SSSSS-of-NNNNN'`` and + ``'-V-SSSSS-of-NNNNN'`` are patterns accepted by the service. + When constructing a filename for a particular shard number, the + upper-case letters ``S`` and ``N`` are replaced with the ``0``-padded + shard number and shard count respectively. This argument can be ``''`` + in which case it behaves as if num_shards was set to 1 and only one file + will be generated. The default pattern used is ``'-SSSSS-of-NNNNN'`` for + bounded PCollections and for ``'-W-SSSSS-of-NNNNN'`` unbounded + PCollections. + W is used for windowed shard naming and is replaced with + ``[window.start, window.end)`` + V is used for windowed shard naming and is replaced with + ``[window.start.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S"), + window.end.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S")`` coder: Coder used to encode each line. compression_type: Used to handle compressed output files. Typical value is CompressionTypes.AUTO, in which case the final file path's @@ -494,6 +505,10 @@ def __init__( to exceed this value. 
This also tracks the uncompressed, not compressed, size of the shard. skip_if_empty: Don't write any shards if the PCollection is empty. + triggering_frequency: (int) Every triggering_frequency duration, a window + will be triggered and all bundles in the window will be written. + If set it overrides user windowing. Mandatory for GlobalWindow. + Returns: A _TextSink object usable for writing. @@ -508,7 +523,8 @@ def __init__( compression_type=compression_type, max_records_per_shard=max_records_per_shard, max_bytes_per_shard=max_bytes_per_shard, - skip_if_empty=skip_if_empty) + skip_if_empty=skip_if_empty, + triggering_frequency=triggering_frequency) self._append_trailing_newlines = append_trailing_newlines self._header = header self._footer = footer @@ -833,7 +849,8 @@ def __init__( *, max_records_per_shard=None, max_bytes_per_shard=None, - skip_if_empty=False): + skip_if_empty=False, + triggering_frequency=None): r"""Initialize a :class:`WriteToText` transform. Args: @@ -852,13 +869,21 @@ def __init__( the performance of a pipeline. Setting this value is not recommended unless you require a specific number of output files. shard_name_template (str): A template string containing placeholders for - the shard number and shard count. Currently only ``''`` and - ``'-SSSSS-of-NNNNN'`` are patterns accepted by the service. + the shard number and shard count. Currently only ``''``, + ``'-SSSSS-of-NNNNN'``, ``'-W-SSSSS-of-NNNNN'`` and + ``'-V-SSSSS-of-NNNNN'`` are patterns accepted by the service. When constructing a filename for a particular shard number, the upper-case letters ``S`` and ``N`` are replaced with the ``0``-padded shard number and shard count respectively. This argument can be ``''`` in which case it behaves as if num_shards was set to 1 and only one file - will be generated. The default pattern used is ``'-SSSSS-of-NNNNN'``. + will be generated. 
The default pattern used is ``'-SSSSS-of-NNNNN'`` for + bounded PCollections and for ``'-W-SSSSS-of-NNNNN'`` unbounded + PCollections. + W is used for windowed shard naming and is replaced with + ``[window.start, window.end)`` + V is used for windowed shard naming and is replaced with + ``[window.start.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S"), + window.end.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S")`` coder (~apache_beam.coders.coders.Coder): Coder used to encode each line. compression_type (str): Used to handle compressed output files. Typical value is :class:`CompressionTypes.AUTO @@ -883,6 +908,8 @@ def __init__( skip_if_empty: Don't write any shards if the PCollection is empty. In case of an empty PCollection, this will still delete existing files having same file path and not create new ones. + triggering_frequency: (int) Every triggering_frequency duration, a window + will be triggered and all bundles in the window will be written. """ self._sink = _TextSink( @@ -897,9 +924,18 @@ def __init__( footer, max_records_per_shard=max_records_per_shard, max_bytes_per_shard=max_bytes_per_shard, - skip_if_empty=skip_if_empty) + skip_if_empty=skip_if_empty, + triggering_frequency=triggering_frequency) def expand(self, pcoll): + if (not pcoll.is_bounded and self._sink.shard_name_template == + filebasedsink.DEFAULT_SHARD_NAME_TEMPLATE): + self._sink.shard_name_template = ( + filebasedsink.DEFAULT_WINDOW_SHARD_NAME_TEMPLATE) + self._sink.shard_name_format = self._sink._template_to_format( + self._sink.shard_name_template) + self._sink.shard_name_glob_format = self._sink._template_to_glob_format( + self._sink.shard_name_template) return pcoll | Write(self._sink) diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py index 30ddc5d62e07..a1183adcc218 100644 --- a/sdks/python/apache_beam/io/textio_test.py +++ b/sdks/python/apache_beam/io/textio_test.py @@ -24,10 +24,14 @@ import logging import os import platform +import re 
import shutil import tempfile import unittest import zlib +from datetime import datetime + +import pytz import apache_beam as beam from apache_beam import coders @@ -45,11 +49,13 @@ from apache_beam.io.textio import WriteToText from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.test_stream import TestStream from apache_beam.testing.test_utils import TempDir from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.core import Create from apache_beam.transforms.userstate import CombiningValueStateSpec +from apache_beam.transforms.util import LogElements from apache_beam.utils.timestamp import Timestamp @@ -1849,6 +1855,406 @@ def check_types(element): _ = pcoll | beam.Map(check_types) +class GenerateEvent(beam.PTransform): + @staticmethod + def sample_data(): + return GenerateEvent() + + def expand(self, input): + elemlist = [{'age': 10}, {'age': 20}, {'age': 30}] + elem = elemlist + return ( + input + | TestStream().add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 1, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 2, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 3, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 4, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()). 
+ add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 6, + 0, tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 7, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 8, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 9, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 11, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 12, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 13, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 14, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 16, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 17, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 18, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 19, 0, + tzinfo=pytz.UTC).timestamp()). 
+ advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).advance_watermark_to( + datetime( + 2021, 3, 1, 0, 0, 25, 0, tzinfo=pytz.UTC). + timestamp()).advance_watermark_to_infinity()) + + +class WriteStreamingTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + if os.path.exists(self.tempdir): + shutil.rmtree(self.tempdir) + + def test_write_streaming_2_shards_default_shard_name_template( + self, num_shards=2): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #TextIO + output2 = output | 'TextIO WriteToText' >> beam.io.WriteToText( + file_path_prefix=self.tempdir + "/ouput_WriteToText", + file_name_suffix=".txt", + num_shards=num_shards, + triggering_frequency=60) + _ = output2 | 'LogElements after WriteToText' >> LogElements( + prefix='after WriteToText ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToText-[1614556800.0, 1614556805.0)-00000-of-00002.txt + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.txt$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToText*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_default_shard_name_template_windowed_pcoll( + self, num_shards=2): + with TestPipeline() as p: + output = ( + p | 
GenerateEvent.sample_data() + | 'User windowing' >> beam.transforms.core.WindowInto( + beam.transforms.window.FixedWindows(10), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. + DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + #TextIO + output2 = output | 'TextIO WriteToText' >> beam.io.WriteToText( + file_path_prefix=self.tempdir + "/ouput_WriteToText", + file_name_suffix=".txt", + num_shards=num_shards, + ) + _ = output2 | 'LogElements after WriteToText' >> LogElements( + prefix='after WriteToText ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToText-[1614556800.0, 1614556805.0)-00000-of-00002.txt + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.txt$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToText*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards*3, #25s of data covered by 3 10s windows + "expected %d files, but got: %d" % (num_shards*3, len(file_names))) + + def test_write_streaming_undef_shards_default_shard_name_template_windowed_pcoll( # pylint: disable=line-too-long + self): + with TestPipeline() as p: + output = ( + p | GenerateEvent.sample_data() + | 'User windowing' >> beam.transforms.core.WindowInto( + beam.transforms.window.FixedWindows(10), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. 
+ DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + #TextIO + output2 = output | 'TextIO WriteToText' >> beam.io.WriteToText( + file_path_prefix=self.tempdir + "/ouput_WriteToText", + file_name_suffix=".txt", + num_shards=0, + ) + _ = output2 | 'LogElements after WriteToText' >> LogElements( + prefix='after WriteToText ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToText-[1614556800.0, 1614556805.0)-00000-of-00002.txt + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.txt$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToText*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertGreaterEqual( + len(file_names), + 1*3, #25s of data covered by 3 10s windows + "expected %d files, but got: %d" % (1*3, len(file_names))) + + def test_write_streaming_undef_shards_default_shard_name_template_windowed_pcoll_and_trig_freq( # pylint: disable=line-too-long + self): + with TestPipeline() as p: + output = ( + p | GenerateEvent.sample_data() + | 'User windowing' >> beam.transforms.core.WindowInto( + beam.transforms.window.FixedWindows(60), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. 
+ DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + #TextIO + output2 = output | 'TextIO WriteToText' >> beam.io.WriteToText( + file_path_prefix=self.tempdir + "/ouput_WriteToText", + file_name_suffix=".txt", + num_shards=0, + triggering_frequency=10, + ) + _ = output2 | 'LogElements after WriteToText' >> LogElements( + prefix='after WriteToText ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToText-[1614556800.0, 1614556805.0)-00000-of-00002.txt + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.txt$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToText*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertGreaterEqual( + len(file_names), + 1*3, #25s of data covered by 3 10s windows + "expected %d files, but got: %d" % (1*3, len(file_names))) + + def test_write_streaming_undef_shards_default_shard_name_template_global_window_pcoll( # pylint: disable=line-too-long + self): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #TextIO + output2 = output | 'TextIO WriteToText' >> beam.io.WriteToText( + file_path_prefix=self.tempdir + "/ouput_WriteToText", + file_name_suffix=".txt", + num_shards=0, #0 means undef nb of shards, same as omitted/default + triggering_frequency=60, + ) + _ = output2 | 'LogElements after WriteToText' >> LogElements( + prefix='after WriteToText ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToText-[1614556800.0, 1614556805.0)-00000-of-00002.txt + # It captures: window_interval, 
shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.txt$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToText*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertGreaterEqual( + len(file_names), + 1*3, #25s of data covered by 3 10s windows + "expected %d files, but got: %d" % (1*3, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template( + self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #TextIO + output2 = output | 'TextIO WriteToText' >> beam.io.WriteToText( + file_path_prefix=self.tempdir + "/ouput_WriteToText", + file_name_suffix=".txt", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=60, + ) + _ = output2 | 'LogElements after WriteToText' >> LogElements( + prefix='after WriteToText ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToText-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.txt + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.txt$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToText*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + 
len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template_5s_window( + self, + num_shards=2, + shard_name_template='-V-SSSSS-of-NNNNN', + triggering_frequency=5): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #TextIO + output2 = output | 'TextIO WriteToText' >> beam.io.WriteToText( + file_path_prefix=self.tempdir + "/ouput_WriteToText", + file_name_suffix=".txt", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=triggering_frequency, + ) + _ = output2 | 'LogElements after WriteToText' >> LogElements( + prefix='after WriteToText ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToText-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.txt + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.txt$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToText*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + # for 5s window size, the input should be processed by 5 windows with + # 2 shards per window + self.assertEqual( + len(file_names), + 10, + "expected %d files, but got: %d" % (10, len(file_names))) + + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py index b911c64a1348..20617663e95c 100644 --- a/sdks/python/apache_beam/io/tfrecordio.py +++ 
b/sdks/python/apache_beam/io/tfrecordio.py @@ -290,7 +290,8 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - compression_type): + compression_type, + triggering_frequency=60): """Initialize a TFRecordSink. See WriteToTFRecord for details.""" super().__init__( @@ -300,7 +301,8 @@ def __init__( num_shards=num_shards, shard_name_template=shard_name_template, mime_type='application/octet-stream', - compression_type=compression_type) + compression_type=compression_type, + triggering_frequency=triggering_frequency) def write_encoded_record(self, file_handle, value): _TFRecordUtil.write_record(file_handle, value) @@ -315,7 +317,8 @@ def __init__( file_name_suffix='', num_shards=0, shard_name_template=None, - compression_type=CompressionTypes.AUTO): + compression_type=CompressionTypes.AUTO, + triggering_frequency=None): """Initialize WriteToTFRecord transform. Args: @@ -326,16 +329,29 @@ def __init__( file_name_suffix: Suffix for the files written. num_shards: The number of files (shards) used for output. If not set, the default value will be used. + In streaming if not set, the service will write a file per bundle. shard_name_template: A template string containing placeholders for - the shard number and shard count. When constructing a filename for a - particular shard number, the upper-case letters 'S' and 'N' are - replaced with the 0-padded shard number and shard count respectively. - This argument can be '' in which case it behaves as if num_shards was - set to 1 and only one file will be generated. The default pattern used - is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template. + the shard number and shard count. Currently only ``''``, + ``'-SSSSS-of-NNNNN'``, ``'-W-SSSSS-of-NNNNN'`` and + ``'-V-SSSSS-of-NNNNN'`` are patterns accepted by the service. + When constructing a filename for a particular shard number, the + upper-case letters ``S`` and ``N`` are replaced with the ``0``-padded + shard number and shard count respectively. 
This argument can be ``''`` + in which case it behaves as if num_shards was set to 1 and only one file + will be generated. The default pattern used is ``'-SSSSS-of-NNNNN'`` for + bounded PCollections and for ``'-W-SSSSS-of-NNNNN'`` unbounded + PCollections. + W is used for windowed shard naming and is replaced with + ``[window.start, window.end)`` + V is used for windowed shard naming and is replaced with + ``[window.start.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S"), + window.end.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S")`` compression_type: Used to handle compressed output files. Typical value is CompressionTypes.AUTO, in which case the file_path's extension will be used to detect the compression. + triggering_frequency: (int) Every triggering_frequency duration, a window + will be triggered and all bundles in the window will be written. + If set it overrides user windowing. Mandatory for GlobalWindow. Returns: A WriteToTFRecord transform object. @@ -347,7 +363,17 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - compression_type) + compression_type, + triggering_frequency) def expand(self, pcoll): + if (not pcoll.is_bounded and self._sink.shard_name_template == + filebasedsink.DEFAULT_SHARD_NAME_TEMPLATE): + self._sink.shard_name_template = ( + filebasedsink.DEFAULT_WINDOW_SHARD_NAME_TEMPLATE) + self._sink.shard_name_format = self._sink._template_to_format( + self._sink.shard_name_template) + self._sink.shard_name_glob_format = self._sink._template_to_glob_format( + self._sink.shard_name_template) + return pcoll | Write(self._sink) diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py index a867c0212ad3..6522ade36d80 100644 --- a/sdks/python/apache_beam/io/tfrecordio_test.py +++ b/sdks/python/apache_beam/io/tfrecordio_test.py @@ -21,15 +21,20 @@ import glob import gzip import io +import json import logging import os import pickle import random import re +import shutil +import 
tempfile import unittest import zlib +from datetime import datetime import crcmod +import pytz import apache_beam as beam from apache_beam import Create @@ -41,9 +46,11 @@ from apache_beam.io.tfrecordio import _TFRecordSink from apache_beam.io.tfrecordio import _TFRecordUtil from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.test_stream import TestStream from apache_beam.testing.test_utils import TempDir from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.transforms.util import LogElements try: import tensorflow.compat.v1 as tf # pylint: disable=import-error @@ -558,6 +565,258 @@ def test_end2end_read_write_read(self): assert_that(actual_data, equal_to(expected_data)) +class GenerateEvent(beam.PTransform): + @staticmethod + def sample_data(): + return GenerateEvent() + + def expand(self, input): + elemlist = [{'age': 10}, {'age': 20}, {'age': 30}] + elem = elemlist + return ( + input + | TestStream().add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 1, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 2, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 3, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 4, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()). 
+ add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 6, + 0, tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 7, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 8, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 9, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 11, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 12, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 13, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 14, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 16, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 17, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 18, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 19, 0, + tzinfo=pytz.UTC).timestamp()). 
+ advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).advance_watermark_to( + datetime( + 2021, 3, 1, 0, 0, 25, 0, tzinfo=pytz.UTC). + timestamp()).advance_watermark_to_infinity()) + + +class WriteStreamingTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + if os.path.exists(self.tempdir): + shutil.rmtree(self.tempdir) + + def test_write_streaming_2_shards_default_shard_name_template( + self, num_shards=2): + with TestPipeline() as p: + output = ( + p + | GenerateEvent.sample_data() + | 'User windowing' >> beam.transforms.core.WindowInto( + beam.transforms.window.FixedWindows(60), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. + DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0)) + | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8'))) + #TFrecordIO + output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord( + file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord", + file_name_suffix=".tfrecord", + num_shards=num_shards, + ) + _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements( + prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToTFRecord-[1614556800.0, 1614556805.0)-00000-of-00002.tfrecord + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.tfrecord$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToTFRecord*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if 
match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template( + self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'): + with TestPipeline() as p: + output = ( + p + | GenerateEvent.sample_data() + | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8'))) + #TFrecordIO + output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord( + file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord", + file_name_suffix=".tfrecord", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=60, + ) + _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements( + prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToTFRecord-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.tfrecord + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.tfrecord$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToTFRecord*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template_5s_window( + self, + num_shards=2, + shard_name_template='-V-SSSSS-of-NNNNN', + triggering_frequency=5): + with TestPipeline() as p: + output = ( + p 
+ | GenerateEvent.sample_data() + | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8'))) + #TFrecordIO + output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord( + file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord", + file_name_suffix=".tfrecord", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=triggering_frequency, + ) + _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements( + prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToTFRecord-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.tfrecord + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.tfrecord$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToTFRecord*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + # for 5s window size, the input should be processed by 5 windows with + # 2 shards per window + self.assertEqual( + len(file_names), + 10, + "expected %d files, but got: %d" % (10, len(file_names))) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/transforms/write_ptransform_test.py b/sdks/python/apache_beam/transforms/write_ptransform_test.py index ce402d8d3062..8525adb4b74a 100644 --- a/sdks/python/apache_beam/transforms/write_ptransform_test.py +++ b/sdks/python/apache_beam/transforms/write_ptransform_test.py @@ -45,7 +45,7 @@ def pre_finalize(self, init_result, writer_results): pass def finalize_write( - self,
init_result, writer_results, unused_pre_finalize_result): + self, init_result, writer_results, unused_pre_finalize_result, unused_w): self.init_result_at_finalize = init_result self.write_results_at_finalize = writer_results