-
Notifications
You must be signed in to change notification settings - Fork 1.7k
AVRO-1759: [java] Automatic union types for sealed classes #3436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
fe2ddea
be55e25
479223a
65a6c97
b3a334b
59be88e
03e41e4
91d1ac4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,6 +38,7 @@ | |
|
|
||
| import java.io.IOException; | ||
| import java.lang.annotation.Annotation; | ||
| import java.lang.reflect.AnnotatedElement; | ||
| import java.lang.reflect.Constructor; | ||
| import java.lang.reflect.Field; | ||
| import java.lang.reflect.GenericArrayType; | ||
|
|
@@ -69,6 +70,24 @@ public class ReflectData extends SpecificData { | |
|
|
||
| private static final String STRING_OUTER_PARENT_REFERENCE = "this$0"; | ||
|
|
||
| private static final Method IS_SEALED_METHOD; | ||
| private static final Method GET_PERMITTED_SUBCLASSES_METHOD; | ||
|
|
||
| static { | ||
| Class<? extends Class> classClass = SpecificData.class.getClass(); | ||
| Method isSealed; | ||
| Method getPermittedSubclasses; | ||
| try { | ||
| isSealed = classClass.getMethod("isSealed"); | ||
| getPermittedSubclasses = classClass.getMethod("getPermittedSubclasses"); | ||
| } catch (NoSuchMethodException e) { | ||
ashley-taylor marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| isSealed = null; | ||
| getPermittedSubclasses = null; | ||
| } | ||
| IS_SEALED_METHOD = isSealed; | ||
| GET_PERMITTED_SUBCLASSES_METHOD = getPermittedSubclasses; | ||
| } | ||
|
|
||
| // holds a wrapper so null entries will have a cached value | ||
| private final ConcurrentMap<Schema, CustomEncodingWrapper> encoderCache = new ConcurrentHashMap<>(); | ||
|
|
||
|
|
@@ -705,7 +724,7 @@ protected Schema createSchema(Type type, Map<String, Schema> names) { | |
| String space = c.getPackage() == null ? "" : c.getPackage().getName(); | ||
| if (c.getEnclosingClass() != null) // nested class | ||
| space = c.getEnclosingClass().getName().replace('$', '.'); | ||
| Union union = c.getAnnotation(Union.class); | ||
| Class[] union = getUnion(c); | ||
| if (union != null) { // union annotated | ||
| return getAnnotatedUnion(union, names); | ||
| } else if (isStringable(c)) { // Stringable | ||
|
|
@@ -811,10 +830,29 @@ private void setElement(Schema schema, Type element) { | |
| schema.addProp(ELEMENT_PROP, c.getName()); | ||
| } | ||
|
|
||
| private Class[] getUnion(AnnotatedElement element) { | ||
| Union union = element.getAnnotation(Union.class); | ||
| if (union != null) { | ||
| return union.value(); | ||
| } | ||
|
|
||
| if (element instanceof Class) { | ||
| // automatic sealed class polymorphic | ||
| try { | ||
| if (IS_SEALED_METHOD != null && Boolean.TRUE.equals(IS_SEALED_METHOD.invoke(element))) { | ||
| return (Class<?>[]) GET_PERMITTED_SUBCLASSES_METHOD.invoke(element); | ||
|
||
| } | ||
| } catch (ReflectiveOperationException e) { | ||
| throw new AvroRuntimeException(e); | ||
| } | ||
| } | ||
| return null; | ||
| } | ||
|
|
||
| // construct a schema from a union annotation | ||
| private Schema getAnnotatedUnion(Union union, Map<String, Schema> names) { | ||
| private Schema getAnnotatedUnion(Class[] union, Map<String, Schema> names) { | ||
| List<Schema> branches = new ArrayList<>(); | ||
| for (Class branch : union.value()) | ||
| for (Class branch : union) | ||
| branches.add(createSchema(branch, names)); | ||
| return Schema.createUnion(branches); | ||
| } | ||
|
|
@@ -881,7 +919,7 @@ protected Schema createFieldSchema(Field field, Map<String, Schema> names) { | |
|
|
||
| Union union = field.getAnnotation(Union.class); | ||
| if (union != null) | ||
| return getAnnotatedUnion(union, names); | ||
| return getAnnotatedUnion(union.value(), names); | ||
|
|
||
| Schema schema = createSchema(field.getGenericType(), names); | ||
| if (field.isAnnotationPresent(Stringable.class)) { // Stringable | ||
|
|
@@ -928,7 +966,7 @@ private Message getMessage(Method method, Protocol protocol, Map<String, Schema> | |
| if (annotation instanceof AvroSchema) // explicit schema | ||
| paramSchema = new Schema.Parser().parse(((AvroSchema) annotation).value()); | ||
| else if (annotation instanceof Union) // union | ||
| paramSchema = getAnnotatedUnion(((Union) annotation), names); | ||
| paramSchema = getAnnotatedUnion(((Union) annotation).value(), names); | ||
| else if (annotation instanceof Nullable) // nullable | ||
| paramSchema = makeNullable(paramSchema); | ||
| } | ||
|
|
@@ -940,7 +978,7 @@ else if (annotation instanceof Nullable) // nullable | |
| Type genericReturnType = method.getGenericReturnType(); | ||
| Type returnType = genericTypeMap.getOrDefault(genericReturnType, genericReturnType); | ||
| Union union = method.getAnnotation(Union.class); | ||
| Schema response = union == null ? getSchema(returnType, names) : getAnnotatedUnion(union, names); | ||
| Schema response = union == null ? getSchema(returnType, names) : getAnnotatedUnion(union.value(), names); | ||
| if (method.isAnnotationPresent(Nullable.class)) // nullable | ||
| response = makeNullable(response); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * https://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.avro.reflect; | ||
|
|
||
| import static org.junit.Assert.assertEquals; | ||
|
|
||
| import java.io.ByteArrayInputStream; | ||
| import java.io.ByteArrayOutputStream; | ||
| import java.io.IOException; | ||
| import java.io.UncheckedIOException; | ||
| import java.util.ArrayList; | ||
| import java.util.Arrays; | ||
| import java.util.List; | ||
| import java.util.Objects; | ||
|
|
||
| import org.apache.avro.Schema; | ||
| import org.apache.avro.file.DataFileStream; | ||
| import org.apache.avro.file.DataFileWriter; | ||
| import org.apache.avro.io.DatumReader; | ||
| import org.junit.Test; | ||
|
|
||
| public class TestPolymorphicEncoding { | ||
|
|
||
| @Test | ||
| public void testPolymorphicEncoding() throws IOException { | ||
| List<Animal> expected = Arrays.asList(new Cat("Green"), new Dog(5)); | ||
| byte[] encoded = write(Animal.class, expected); | ||
| List<Animal> decoded = read(encoded); | ||
|
|
||
| assertEquals(expected, decoded); | ||
| } | ||
|
|
||
| private <T> List<T> read(byte[] toDecode) throws IOException { | ||
| DatumReader<T> datumReader = new ReflectDatumReader<>(); | ||
| try (DataFileStream<T> dataFileReader = new DataFileStream<>(new ByteArrayInputStream(toDecode, 0, toDecode.length), | ||
| datumReader);) { | ||
ashley-taylor marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| List<T> toReturn = new ArrayList<>(); | ||
| while (dataFileReader.hasNext()) { | ||
| toReturn.add(dataFileReader.next()); | ||
| } | ||
| return toReturn; | ||
| } | ||
| } | ||
|
|
||
| private <T> byte[] write(Class<?> type, List<T> custom) { | ||
| Schema schema = ReflectData.get().getSchema(type); | ||
| ReflectDatumWriter<T> datumWriter = new ReflectDatumWriter<>(); | ||
| try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); | ||
| DataFileWriter<T> writer = new DataFileWriter<>(datumWriter)) { | ||
| writer.create(schema, baos); | ||
| for (T c : custom) { | ||
| writer.append(c); | ||
| } | ||
| writer.flush(); | ||
| return baos.toByteArray(); | ||
| } catch (IOException e) { | ||
| throw new UncheckedIOException(e); | ||
| } | ||
| } | ||
|
|
||
| public static sealed interface Animal permits Cat,Dog { | ||
| } | ||
|
|
||
| public static final class Dog implements Animal { | ||
|
|
||
| private int size; | ||
|
|
||
| public Dog() { | ||
| } | ||
|
|
||
| public Dog(int size) { | ||
| this.size = size; | ||
| } | ||
|
|
||
| public int getSize() { | ||
| return size; | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| return Objects.hash(size); | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object obj) { | ||
| if (this == obj) | ||
| return true; | ||
| if (obj == null) | ||
| return false; | ||
| if (getClass() != obj.getClass()) | ||
| return false; | ||
| Dog other = (Dog) obj; | ||
| return size == other.size; | ||
| } | ||
|
|
||
| } | ||
|
|
||
| public static final class Cat implements Animal { | ||
|
|
||
| private String color; | ||
|
|
||
| public Cat() { | ||
| } | ||
|
|
||
| public Cat(String color) { | ||
| super(); | ||
| this.color = color; | ||
| } | ||
|
|
||
| public String getColor() { | ||
| return color; | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| return Objects.hash(color); | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object obj) { | ||
| if (this == obj) | ||
| return true; | ||
| if (obj == null) | ||
| return false; | ||
| if (getClass() != obj.getClass()) | ||
| return false; | ||
| Cat other = (Cat) obj; | ||
| return Objects.equals(color, other.color); | ||
| } | ||
|
|
||
| } | ||
|
|
||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.