Skip to content

Commit d26dbac

Browse files
authored
Add support for PROTO format in YAML Pub/Sub transform (#36185)
* Add support for PROTO format in YAML Pub/Sub transform
* Remove unused import of schema_utils in yaml_io.py and update YamlPubSubTest to use named_fields_to_schema for RowCoder.
* Rename test_rw_proto to test_write_proto and add test_read_proto for PROTO format handling in YamlPubSubTest.
* lints
1 parent df255a3 commit d26dbac

File tree

2 files changed: +56 additions, −2 deletions

sdks/python/apache_beam/yaml/yaml_io.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
import apache_beam as beam
3636
import apache_beam.io as beam_io
3737
from apache_beam import coders
38+
from apache_beam.coders.row_coder import RowCoder
3839
from apache_beam.io import ReadFromBigQuery
3940
from apache_beam.io import ReadFromTFRecord
4041
from apache_beam.io import WriteToBigQuery
@@ -247,6 +248,10 @@ def _validate_schema():
247248
beam_schema,
248249
lambda record: covert_to_row(
249250
fastavro.schemaless_reader(io.BytesIO(record), schema))) # type: ignore[call-arg]
251+
elif format == 'PROTO':
252+
_validate_schema()
253+
beam_schema = json_utils.json_schema_to_beam_schema(schema)
254+
return beam_schema, RowCoder(beam_schema).decode
250255
else:
251256
raise ValueError(f'Unknown format: {format}')
252257

@@ -291,6 +296,8 @@ def formatter(row):
291296
return buffer.read()
292297

293298
return formatter
299+
elif format == 'PROTO':
300+
return RowCoder(beam_schema).encode
294301
else:
295302
raise ValueError(f'Unknown format: {format}')
296303

@@ -416,7 +423,7 @@ def write_to_pubsub(
416423
417424
Args:
418425
topic: Cloud Pub/Sub topic in the form "/topics/<project>/<topic>".
419-
format: How to format the message payload. Currently suported
426+
format: How to format the message payload. Currently supported
420427
formats are
421428
422429
- RAW: Expects a message with a single field (excluding
@@ -426,6 +433,8 @@ def write_to_pubsub(
426433
from the input PCollection schema.
427434
- JSON: Formats records with a given JSON schema, which may be inferred
428435
from the input PCollection schema.
436+
- PROTO: Encodes records with a given Protobuf schema, which may be
437+
inferred from the input PCollection schema.
429438
430439
schema: Schema specification for the given format.
431440
attributes: List of attribute keys whose values will be pulled out as
@@ -633,7 +642,7 @@ def read_from_tfrecord(
633642
compression_type (CompressionTypes): Used to handle compressed input files.
634643
Default value is CompressionTypes.AUTO, in which case the file_path's
635644
extension will be used to detect the compression.
636-
validate (bool): Boolean flag to verify that the files exist during the
645+
validate (bool): Boolean flag to verify that the files exist during the
637646
pipeline creation time.
638647
"""
639648
return ReadFromTFRecord(

sdks/python/apache_beam/yaml/yaml_io_test.py

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,10 +24,12 @@
2424
import mock
2525

2626
import apache_beam as beam
27+
from apache_beam.coders.row_coder import RowCoder
2728
from apache_beam.io.gcp.pubsub import PubsubMessage
2829
from apache_beam.testing.util import AssertThat
2930
from apache_beam.testing.util import assert_that
3031
from apache_beam.testing.util import equal_to
32+
from apache_beam.typehints import schemas as schema_utils
3133
from apache_beam.yaml.yaml_transform import YamlTransform
3234

3335

@@ -491,6 +493,49 @@ def test_write_json(self):
491493
attributes_map: other
492494
'''))
493495

496+
def test_write_proto(self):
497+
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
498+
pickle_library='cloudpickle')) as p:
499+
data = [beam.Row(label='37a', rank=1), beam.Row(label='389a', rank=2)]
500+
coder = RowCoder(
501+
schema_utils.named_fields_to_schema([('label', str), ('rank', int)]))
502+
expected_messages = [PubsubMessage(coder.encode(r), {}) for r in data]
503+
with mock.patch('apache_beam.io.WriteToPubSub',
504+
FakeWriteToPubSub(topic='my_topic',
505+
messages=expected_messages)):
506+
_ = (
507+
p | beam.Create(data) | YamlTransform(
508+
'''
509+
type: WriteToPubSub
510+
config:
511+
topic: my_topic
512+
format: PROTO
513+
'''))
514+
515+
def test_read_proto(self):
516+
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
517+
pickle_library='cloudpickle')) as p:
518+
data = [beam.Row(label='37a', rank=1), beam.Row(label='389a', rank=2)]
519+
coder = RowCoder(
520+
schema_utils.named_fields_to_schema([('label', str), ('rank', int)]))
521+
expected_messages = [PubsubMessage(coder.encode(r), {}) for r in data]
522+
with mock.patch('apache_beam.io.ReadFromPubSub',
523+
FakeReadFromPubSub(topic='my_topic',
524+
messages=expected_messages)):
525+
result = p | YamlTransform(
526+
'''
527+
type: ReadFromPubSub
528+
config:
529+
topic: my_topic
530+
format: PROTO
531+
schema:
532+
type: object
533+
properties:
534+
label: {type: string}
535+
rank: {type: integer}
536+
''')
537+
assert_that(result, equal_to(data))
538+
494539

495540
if __name__ == '__main__':
496541
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments (0)