diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py
index 3438cb5d61fe..553b6c741f3d 100644
--- a/sdks/python/apache_beam/io/avroio.py
+++ b/sdks/python/apache_beam/io/avroio.py
@@ -543,7 +543,7 @@ def avro_union_type_to_beam_type(union_type: List) -> schema_pb2.FieldType:
   """convert an avro union type to a beam type
 
   if the union type is a nullable, and it is a nullable union of an avro
-  primitive with a corresponding beam primitive then create a nullable beam
+  type with a corresponding beam type then create a nullable beam
   field of the corresponding beam type, otherwise return an Any type.
 
   Args:
@@ -554,11 +554,10 @@ def avro_union_type_to_beam_type(union_type: List) -> schema_pb2.FieldType:
   """
   if len(union_type) == 2 and "null" in union_type:
     for avro_type in union_type:
-      if avro_type in AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES:
-        return schema_pb2.FieldType(
-            atomic_type=AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES[avro_type],
-            nullable=True)
-    return schemas.typing_to_runner_api(Any)
+      if avro_type != "null":
+        beam_type = avro_type_to_beam_type(avro_type)
+        beam_type.nullable = True
+        return beam_type
   return schemas.typing_to_runner_api(Any)
 
 
diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py
index 633b1307eb45..6dd9e620c665 100644
--- a/sdks/python/apache_beam/io/avroio_test.py
+++ b/sdks/python/apache_beam/io/avroio_test.py
@@ -206,6 +206,54 @@ def test_avro_union_type_to_beam_type_with_string_long(self):
     expected_beam_type = schemas.typing_to_runner_api(Any)
     hc.assert_that(beam_type, hc.equal_to(expected_beam_type))
 
+  def test_avro_union_type_to_beam_type_with_record_and_null(self):
+    record_type = {
+        'type': 'record',
+        'name': 'TestRecord',
+        'fields': [{
+            'name': 'field1', 'type': 'string'
+        }, {
+            'name': 'field2', 'type': 'int'
+        }]
+    }
+    union_type = [record_type, 'null']
+    beam_type = avro_union_type_to_beam_type(union_type)
+    expected_beam_type = schema_pb2.FieldType(
+        row_type=schema_pb2.RowType(
+            schema=schema_pb2.Schema(
+                fields=[
+                    schemas.schema_field(
+                        'field1',
+                        schema_pb2.FieldType(atomic_type=schema_pb2.STRING)),
+                    schemas.schema_field(
+                        'field2',
+                        schema_pb2.FieldType(atomic_type=schema_pb2.INT32))
+                ])),
+        nullable=True)
+    hc.assert_that(beam_type, hc.equal_to(expected_beam_type))
+
+  def test_avro_union_type_to_beam_type_with_nullable_annotated_string(self):
+    annotated_string_type = {"avro.java.string": "String", "type": "string"}
+    union_type = ['null', annotated_string_type]
+
+    beam_type = avro_union_type_to_beam_type(union_type)
+
+    expected_beam_type = schema_pb2.FieldType(
+        atomic_type=schema_pb2.STRING, nullable=True)
+    hc.assert_that(beam_type, hc.equal_to(expected_beam_type))
+
+  def test_avro_union_type_to_beam_type_with_only_null(self):
+    union_type = ['null']
+    beam_type = avro_union_type_to_beam_type(union_type)
+    expected_beam_type = schemas.typing_to_runner_api(Any)
+    hc.assert_that(beam_type, hc.equal_to(expected_beam_type))
+
+  def test_avro_union_type_to_beam_type_with_multiple_types(self):
+    union_type = ['null', 'string', 'int']
+    beam_type = avro_union_type_to_beam_type(union_type)
+    expected_beam_type = schemas.typing_to_runner_api(Any)
+    hc.assert_that(beam_type, hc.equal_to(expected_beam_type))
+
   def test_avro_schema_to_beam_and_back(self):
     avro_schema = fastavro.parse_schema(json.loads(self.SCHEMA_STRING))
     beam_schema = avro_schema_to_beam_schema(avro_schema)