diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java index 5ccfe39b92af..f35782c2b9a2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java @@ -19,7 +19,6 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; @@ -32,13 +31,10 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** A {@link SchemaProvider} for AutoValue classes. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) public class AutoValueSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on AutoValue getters. */ @VisibleForTesting @@ -49,7 +45,11 @@ public static class AbstractGetterTypeSupplier implements FieldValueTypeSupplier public List get(TypeDescriptor typeDescriptor) { // If the generated class is passed in, we want to look at the base class to find the getters. 
- TypeDescriptor targetTypeDescriptor = AutoValueUtils.getBaseAutoValueClass(typeDescriptor); + TypeDescriptor targetTypeDescriptor = + Preconditions.checkNotNull( + AutoValueUtils.getBaseAutoValueClass(typeDescriptor), + "unable to determine base AutoValue class for type %s", + typeDescriptor); List methods = ReflectUtils.getMethods(targetTypeDescriptor.getRawType()).stream() @@ -62,9 +62,9 @@ public List get(TypeDescriptor typeDescriptor) { .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(typeDescriptor, methods.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); return types; } @@ -89,8 +89,8 @@ private static void validateFieldNumbers(List types) } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return JavaBeanUtils.getGetters( targetTypeDescriptor, schema, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java index 8725833bc1da..6e244fefb263 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java @@ -20,6 +20,9 @@ import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.initialization.qual.NotOnlyInitialized; +import org.checkerframework.checker.initialization.qual.UnknownInitialization; +import
org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -32,24 +35,25 @@ * significant for larger schemas) on each lookup. This wrapper caches the value returned by the * inner factory, so the schema comparison only need happen on the first lookup. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) -public class CachingFactory implements Factory { +public class CachingFactory implements Factory { private transient @Nullable ConcurrentHashMap, CreatedT> cache = null; - private final Factory innerFactory; + private final @NotOnlyInitialized Factory innerFactory; - public CachingFactory(Factory innerFactory) { + public CachingFactory(@UnknownInitialization Factory innerFactory) { this.innerFactory = innerFactory; } - @Override - public CreatedT create(TypeDescriptor typeDescriptor, Schema schema) { + private ConcurrentHashMap, CreatedT> getCache() { if (cache == null) { cache = new ConcurrentHashMap<>(); } + return cache; + } + + @Override + public CreatedT create(TypeDescriptor typeDescriptor, Schema schema) { + ConcurrentHashMap, CreatedT> cache = getCache(); CreatedT cached = cache.get(typeDescriptor); if (cached != null) { return cached; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java index fb98db8e8343..63ab56dc7609 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java @@ -19,6 +19,7 @@ import java.io.Serializable; import org.apache.beam.sdk.annotations.Internal; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -29,7 +30,7 @@ *

Implementations of this interface are generated at runtime to map object fields to Row fields. */ @Internal -public interface FieldValueGetter extends Serializable { +public interface FieldValueGetter extends Serializable { @Nullable ValueT get(ObjectT object); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index 750709192c08..43aac6a5e20c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -27,7 +27,9 @@ import java.util.Arrays; import java.util.Collections; import java.util.Map; +import java.util.Optional; import java.util.stream.Stream; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; @@ -40,10 +42,7 @@ /** Represents type information for a Java type that will be used to infer a Schema type. */ @AutoValue -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@Internal public abstract class FieldValueTypeInformation implements Serializable { /** Optionally returns the field index. 
*/ public abstract @Nullable Integer getNumber(); @@ -125,8 +124,13 @@ public static FieldValueTypeInformation forOneOf( .build(); } - public static FieldValueTypeInformation forField(Field field, int index) { - TypeDescriptor type = TypeDescriptor.of(field.getGenericType()); + public static FieldValueTypeInformation forField( + @Nullable TypeDescriptor typeDescriptor, Field field, int index) { + TypeDescriptor type = + Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(field.getGenericType())) + // fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(field.getGenericType())); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(field.getName(), field)) .setNumber(getNumberOverride(index, field)) @@ -134,9 +138,9 @@ public static FieldValueTypeInformation forField(Field field, int index) { .setType(type) .setRawType(type.getRawType()) .setField(field) - .setElementType(getIterableComponentType(field)) - .setMapKeyType(getMapKeyType(field)) - .setMapValueType(getMapValueType(field)) + .setElementType(getIterableComponentType(type)) + .setMapKeyType(getMapKeyType(type)) + .setMapValueType(getMapValueType(type)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(field)) .build(); @@ -185,6 +189,11 @@ public static String getNameOverride( } public static FieldValueTypeInformation forGetter(Method method, int index) { + return forGetter(null, method, index); + } + + public static FieldValueTypeInformation forGetter( + @Nullable TypeDescriptor typeDescriptor, Method method, int index) { String name; if (method.getName().startsWith("get")) { name = ReflectUtils.stripPrefix(method.getName(), "get"); @@ -194,7 +203,12 @@ public static FieldValueTypeInformation forGetter(Method method, int index) { throw new RuntimeException("Getter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericReturnType()); + TypeDescriptor type = + 
Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(method.getGenericReturnType())) + // fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(method.getGenericReturnType())); + boolean nullable = hasNullableReturnType(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(name, method)) @@ -253,10 +267,20 @@ private static boolean isNullableAnnotation(Annotation annotation) { } public static FieldValueTypeInformation forSetter(Method method) { - return forSetter(method, "set"); + return forSetter(null, method); } public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) { + return forSetter(null, method, setterPrefix); + } + + public static FieldValueTypeInformation forSetter( + @Nullable TypeDescriptor typeDescriptor, Method method) { + return forSetter(typeDescriptor, method, "set"); + } + + public static FieldValueTypeInformation forSetter( + @Nullable TypeDescriptor typeDescriptor, Method method, String setterPrefix) { String name; if (method.getName().startsWith(setterPrefix)) { name = ReflectUtils.stripPrefix(method.getName(), setterPrefix); @@ -264,7 +288,11 @@ public static FieldValueTypeInformation forSetter(Method method, String setterPr throw new RuntimeException("Setter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericParameterTypes()[0]); + TypeDescriptor type = + Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(method.getGenericParameterTypes()[0])) + // fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(method.getGenericParameterTypes()[0])); boolean nullable = hasSingleNullableParameter(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(name) @@ -283,10 +311,6 @@ public FieldValueTypeInformation withName(String name) { return toBuilder().setName(name).build(); } - private static FieldValueTypeInformation 
getIterableComponentType(Field field) { - return getIterableComponentType(TypeDescriptor.of(field.getGenericType())); - } - static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) { // TODO: Figure out nullable elements. TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType); @@ -306,23 +330,13 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { .build(); } - // If the Field is a map type, returns the key type, otherwise returns a null reference. - - private static @Nullable FieldValueTypeInformation getMapKeyType(Field field) { - return getMapKeyType(TypeDescriptor.of(field.getGenericType())); - } - + // If the type is a map type, returns the key type, otherwise returns a null reference. private static @Nullable FieldValueTypeInformation getMapKeyType( TypeDescriptor typeDescriptor) { return getMapType(typeDescriptor, 0); } - // If the Field is a map type, returns the value type, otherwise returns a null reference. - - private static @Nullable FieldValueTypeInformation getMapValueType(Field field) { - return getMapType(TypeDescriptor.of(field.getGenericType()), 1); - } - + // If the type is a map type, returns the value type, otherwise returns a null reference. private static @Nullable FieldValueTypeInformation getMapValueType( TypeDescriptor typeDescriptor) { return getMapType(typeDescriptor, 1); @@ -330,10 +344,9 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { // If the Field is a map type, returns the key or value type (0 is key type, 1 is value). // Otherwise returns a null reference. 
- @SuppressWarnings("unchecked") private static @Nullable FieldValueTypeInformation getMapType( TypeDescriptor valueType, int index) { - TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); + TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); if (mapType == null) { return null; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java index ce5be71933b8..4e431bb45207 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java @@ -17,13 +17,12 @@ */ package org.apache.beam.sdk.schemas; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; - import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; @@ -32,10 +31,13 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.initialization.qual.NotOnlyInitialized; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** 
@@ -46,10 +48,7 @@ * methods which receive {@link TypeDescriptor}s instead of ordinary {@link Class}es as * arguments, which permits to support generic type signatures during schema inference */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) @Deprecated public abstract class GetterBasedSchemaProvider implements SchemaProvider { @@ -67,9 +66,9 @@ public abstract class GetterBasedSchemaProvider implements SchemaProvider { * override it if you want to use the richer type signature contained in the {@link * TypeDescriptor} not subject to the type erasure. */ - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { - return fieldValueGetters(targetTypeDescriptor.getRawType(), schema); + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { + return (List) fieldValueGetters(targetTypeDescriptor.getRawType(), schema); } /** @@ -112,9 +111,10 @@ public SchemaUserTypeCreator schemaTypeCreator( return schemaTypeCreator(targetTypeDescriptor.getRawType(), schema); } - private class ToRowWithValueGetters implements SerializableFunction { + private class ToRowWithValueGetters + implements SerializableFunction { private final Schema schema; - private final Factory> getterFactory; + private final Factory>> getterFactory; public ToRowWithValueGetters(Schema schema) { this.schema = schema; @@ -122,7 +122,12 @@ public ToRowWithValueGetters(Schema schema) { // schema, return a caching factory that caches the first value seen for each class. This // prevents having to lookup the getter list each time createGetters is called. 
this.getterFactory = - RowValueGettersFactory.of(GetterBasedSchemaProvider.this::fieldValueGetters); + RowValueGettersFactory.of( + (Factory>>) + (typeDescriptor, schema1) -> + (List) + GetterBasedSchemaProvider.this.fieldValueGetters( + typeDescriptor, schema1)); } @Override @@ -160,13 +165,15 @@ public SerializableFunction toRowFunction(TypeDescriptor typeDesc // important to capture the schema once here, so all invocations of the toRowFunction see the // same version of the schema. If schemaFor were to be called inside the lambda below, different // workers would see different versions of the schema. - Schema schema = schemaFor(typeDescriptor); + @NonNull + Schema schema = + Verify.verifyNotNull( + schemaFor(typeDescriptor), "can't create a ToRowFunction with null schema"); return new ToRowWithValueGetters<>(schema); } @Override - @SuppressWarnings("unchecked") public SerializableFunction fromRowFunction(TypeDescriptor typeDescriptor) { return new FromRowUsingCreator<>(typeDescriptor, this); } @@ -181,23 +188,27 @@ public boolean equals(@Nullable Object obj) { return obj != null && this.getClass() == obj.getClass(); } - private static class RowValueGettersFactory implements Factory> { - private final Factory> gettersFactory; - private final Factory> cachingGettersFactory; + private static class RowValueGettersFactory + implements Factory>> { + private final Factory>> gettersFactory; + private final @NotOnlyInitialized Factory>> + cachingGettersFactory; - static Factory> of(Factory> gettersFactory) { - return new RowValueGettersFactory(gettersFactory).cachingGettersFactory; + static Factory>> of( + Factory>> gettersFactory) { + return new RowValueGettersFactory<>(gettersFactory).cachingGettersFactory; } - RowValueGettersFactory(Factory> gettersFactory) { + RowValueGettersFactory(Factory>> gettersFactory) { this.gettersFactory = gettersFactory; this.cachingGettersFactory = new CachingFactory<>(this); } @Override - public List create(TypeDescriptor typeDescriptor, 
Schema schema) { - List getters = gettersFactory.create(typeDescriptor, schema); - List rowGetters = new ArrayList<>(getters.size()); + public List> create( + TypeDescriptor typeDescriptor, Schema schema) { + List> getters = gettersFactory.create(typeDescriptor, schema); + List> rowGetters = new ArrayList<>(getters.size()); for (int i = 0; i < getters.size(); i++) { rowGetters.add(rowValueGetter(getters.get(i), schema.getField(i).getType())); } @@ -209,71 +220,80 @@ static boolean needsConversion(FieldType type) { return typeName.equals(TypeName.ROW) || typeName.isLogicalType() || ((typeName.equals(TypeName.ARRAY) || typeName.equals(TypeName.ITERABLE)) - && needsConversion(type.getCollectionElementType())) + && needsConversion(Verify.verifyNotNull(type.getCollectionElementType()))) || (typeName.equals(TypeName.MAP) - && (needsConversion(type.getMapKeyType()) - || needsConversion(type.getMapValueType()))); + && (needsConversion(Verify.verifyNotNull(type.getMapKeyType())) + || needsConversion(Verify.verifyNotNull(type.getMapValueType())))); } - FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { + FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { TypeName typeName = type.getTypeName(); if (!needsConversion(type)) { return base; } if (typeName.equals(TypeName.ROW)) { - return new GetRow(base, type.getRowSchema(), cachingGettersFactory); + return new GetRow(base, Verify.verifyNotNull(type.getRowSchema()), cachingGettersFactory); } else if (typeName.equals(TypeName.ARRAY)) { - FieldType elementType = type.getCollectionElementType(); + FieldType elementType = Verify.verifyNotNull(type.getCollectionElementType()); return elementType.getTypeName().equals(TypeName.ROW) ? 
new GetEagerCollection(base, converter(elementType)) : new GetCollection(base, converter(elementType)); } else if (typeName.equals(TypeName.ITERABLE)) { - return new GetIterable(base, converter(type.getCollectionElementType())); + return new GetIterable( + base, converter(Verify.verifyNotNull(type.getCollectionElementType()))); } else if (typeName.equals(TypeName.MAP)) { - return new GetMap(base, converter(type.getMapKeyType()), converter(type.getMapValueType())); + return new GetMap( + base, + converter(Verify.verifyNotNull(type.getMapKeyType())), + converter(Verify.verifyNotNull(type.getMapValueType()))); } else if (type.isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = type.getLogicalType(OneOfType.class); Schema oneOfSchema = oneOfType.getOneOfSchema(); Map values = oneOfType.getCaseEnumType().getValuesMap(); - Map converters = Maps.newHashMapWithExpectedSize(values.size()); + Map> converters = + Maps.newHashMapWithExpectedSize(values.size()); for (Map.Entry kv : values.entrySet()) { FieldType fieldType = oneOfSchema.getField(kv.getKey()).getType(); - FieldValueGetter converter = converter(fieldType); + FieldValueGetter converter = converter(fieldType); converters.put(kv.getValue(), converter); } return new GetOneOf(base, converters, oneOfType); } else if (typeName.isLogicalType()) { - return new GetLogicalInputType(base, type.getLogicalType()); + return new GetLogicalInputType(base, Verify.verifyNotNull(type.getLogicalType())); } return base; } - FieldValueGetter converter(FieldType type) { + FieldValueGetter converter(FieldType type) { return rowValueGetter(IDENTITY, type); } - static class GetRow extends Converter { + static class GetRow + extends Converter { final Schema schema; - final Factory> factory; + final Factory>> factory; - GetRow(FieldValueGetter getter, Schema schema, Factory> factory) { + GetRow( + FieldValueGetter getter, + Schema schema, + Factory>> factory) { super(getter); this.schema = schema; this.factory = factory; } @Override 
- Object convert(Object value) { + Object convert(V value) { return Row.withSchema(schema).withFieldValueGetters(factory, value); } } - static class GetEagerCollection extends Converter { + static class GetEagerCollection extends Converter { final FieldValueGetter converter; - GetEagerCollection(FieldValueGetter getter, FieldValueGetter converter) { + GetEagerCollection(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @@ -288,15 +308,16 @@ Object convert(Collection collection) { } } - static class GetCollection extends Converter { + static class GetCollection extends Converter { final FieldValueGetter converter; - GetCollection(FieldValueGetter getter, FieldValueGetter converter) { + GetCollection(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @Override + @SuppressWarnings({"nullness"}) Object convert(Collection collection) { if (collection instanceof List) { // For performance reasons if the input is a list, make sure that we produce a list. 
@@ -309,45 +330,51 @@ Object convert(Collection collection) { } } - static class GetIterable extends Converter { + static class GetIterable extends Converter { final FieldValueGetter converter; - GetIterable(FieldValueGetter getter, FieldValueGetter converter) { + GetIterable(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @Override + @SuppressWarnings({"nullness"}) Object convert(Iterable value) { return Iterables.transform(value, converter::get); } } - static class GetMap extends Converter> { - final FieldValueGetter keyConverter; - final FieldValueGetter valueConverter; + static class GetMap + extends Converter> { + final FieldValueGetter<@NonNull K1, K2> keyConverter; + final FieldValueGetter<@NonNull V1, V2> valueConverter; GetMap( - FieldValueGetter getter, FieldValueGetter keyConverter, FieldValueGetter valueConverter) { + FieldValueGetter> getter, + FieldValueGetter<@NonNull K1, K2> keyConverter, + FieldValueGetter<@NonNull V1, V2> valueConverter) { super(getter); this.keyConverter = keyConverter; this.valueConverter = valueConverter; } @Override - Object convert(Map value) { - Map returnMap = Maps.newHashMapWithExpectedSize(value.size()); - for (Map.Entry entry : value.entrySet()) { - returnMap.put(keyConverter.get(entry.getKey()), valueConverter.get(entry.getValue())); + Map<@Nullable K2, @Nullable V2> convert(Map<@Nullable K1, @Nullable V1> value) { + Map<@Nullable K2, @Nullable V2> returnMap = Maps.newHashMapWithExpectedSize(value.size()); + for (Map.Entry<@Nullable K1, @Nullable V1> entry : value.entrySet()) { + returnMap.put( + Optional.ofNullable(entry.getKey()).map(keyConverter::get).orElse(null), + Optional.ofNullable(entry.getValue()).map(valueConverter::get).orElse(null)); } return returnMap; } } - static class GetLogicalInputType extends Converter { + static class GetLogicalInputType extends Converter { final LogicalType logicalType; - GetLogicalInputType(FieldValueGetter getter, LogicalType 
logicalType) { + GetLogicalInputType(FieldValueGetter getter, LogicalType logicalType) { super(getter); this.logicalType = logicalType; } @@ -359,12 +386,14 @@ Object convert(Object value) { } } - static class GetOneOf extends Converter { + static class GetOneOf extends Converter { final OneOfType oneOfType; - final Map converters; + final Map> converters; GetOneOf( - FieldValueGetter getter, Map converters, OneOfType oneOfType) { + FieldValueGetter getter, + Map> converters, + OneOfType oneOfType) { super(getter); this.converters = converters; this.oneOfType = oneOfType; @@ -373,24 +402,31 @@ static class GetOneOf extends Converter { @Override Object convert(OneOfType.Value value) { EnumerationType.Value caseType = value.getCaseType(); - FieldValueGetter converter = converters.get(caseType.getValue()); - checkState(converter != null, "Missing OneOf converter for case %s.", caseType); + + @NonNull + FieldValueGetter<@NonNull Object, Object> converter = + Verify.verifyNotNull( + converters.get(caseType.getValue()), + "Missing OneOf converter for case %s.", + caseType); + return oneOfType.createValue(caseType, converter.get(value.getValue())); } } - abstract static class Converter implements FieldValueGetter { - final FieldValueGetter getter; + abstract static class Converter + implements FieldValueGetter { + final FieldValueGetter getter; - public Converter(FieldValueGetter getter) { + public Converter(FieldValueGetter getter) { this.getter = getter; } - abstract Object convert(T value); + abstract Object convert(ValueT value); @Override - public @Nullable Object get(Object object) { - T value = (T) getter.get(object); + public @Nullable Object get(ObjectT object) { + ValueT value = getter.get(object); if (value == null) { return null; } @@ -398,7 +434,7 @@ public Converter(FieldValueGetter getter) { } @Override - public @Nullable Object getRaw(Object object) { + public @Nullable Object getRaw(ObjectT object) { return getter.getRaw(object); } @@ -408,16 +444,16 @@ 
public String name() { } } - private static final FieldValueGetter IDENTITY = - new FieldValueGetter() { + private static final FieldValueGetter<@NonNull Object, Object> IDENTITY = + new FieldValueGetter<@NonNull Object, Object>() { @Override - public @Nullable Object get(Object object) { + public Object get(@NonNull Object object) { return object; } @Override public String name() { - return null; + return "IDENTITY"; } }; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java index de31f9947c36..e7214d8f663a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java @@ -19,6 +19,7 @@ import java.util.List; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A newer version of {@link GetterBasedSchemaProvider}, which works with {@link TypeDescriptor}s, @@ -28,12 +29,12 @@ public abstract class GetterBasedSchemaProviderV2 extends GetterBasedSchemaProvider { @Override public List fieldValueGetters(Class targetClass, Schema schema) { - return fieldValueGetters(TypeDescriptor.of(targetClass), schema); + return (List) fieldValueGetters(TypeDescriptor.of(targetClass), schema); } @Override - public abstract List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema); + public abstract List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema); @Override public List fieldValueTypeInformations( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java index a9cf01c52057..14adf2f6603e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java +++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java @@ -19,7 +19,6 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; -import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; @@ -34,6 +33,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -49,10 +49,7 @@ *

TODO: Validate equals() method is provided, and if not generate a "slow" equals method based * on the schema. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class JavaBeanSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on getter methods. */ @VisibleForTesting @@ -68,9 +65,9 @@ public List get(TypeDescriptor typeDescriptor) { .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(typeDescriptor, methods.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); return types; } @@ -114,29 +111,32 @@ public List get(TypeDescriptor typeDescriptor) { return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream() .filter(ReflectUtils::isSetter) .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) - .map(FieldValueTypeInformation::forSetter) + .map(m -> FieldValueTypeInformation.forSetter(typeDescriptor, m)) .map( t -> { - if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) { + Method m = + Preconditions.checkNotNull( + t.getMethod(), JavaBeanUtils.SETTER_WITH_NULL_METHOD_ERROR); + if (m.getAnnotation(SchemaFieldNumber.class) != null) { throw new RuntimeException( String.format( "@SchemaFieldNumber can only be used on getters in Java Beans. Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } - if (t.getMethod().getAnnotation(SchemaFieldName.class) != null) { + if (m.getAnnotation(SchemaFieldName.class) != null) { throw new RuntimeException( String.format( "@SchemaFieldName can only be used on getters in Java Beans. 
Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } - if (t.getMethod().getAnnotation(SchemaCaseFormat.class) != null) { + if (m.getAnnotation(SchemaCaseFormat.class) != null) { throw new RuntimeException( String.format( "@SchemaCaseFormat can only be used on getters in Java Beans. Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } return t; }) @@ -172,8 +172,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return JavaBeanUtils.getGetters( targetTypeDescriptor, schema, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java index 21f07c47b47f..9a8eef2bf2c8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java @@ -21,20 +21,22 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.DefaultTypeConversionsFactory; import org.apache.beam.sdk.schemas.utils.FieldValueTypeSupplier; +import org.apache.beam.sdk.schemas.utils.JavaBeanUtils; import org.apache.beam.sdk.schemas.utils.POJOUtils; import org.apache.beam.sdk.schemas.utils.ReflectUtils; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A {@link SchemaProvider} for Java POJO objects. @@ -49,7 +51,6 @@ *

TODO: Validate equals() method is provided, and if not generate a "slow" equals method based * on the schema. */ -@SuppressWarnings({"nullness", "rawtypes"}) public class JavaFieldSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on public fields. */ @VisibleForTesting @@ -64,9 +65,9 @@ public List get(TypeDescriptor typeDescriptor) { .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(fields.size()); for (int i = 0; i < fields.size(); ++i) { - types.add(FieldValueTypeInformation.forField(fields.get(i), i)); + types.add(FieldValueTypeInformation.forField(typeDescriptor, fields.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); // If there are no creators registered, then make sure none of the schema fields are final, @@ -75,7 +76,9 @@ public List get(TypeDescriptor typeDescriptor) { && ReflectUtils.getAnnotatedConstructor(typeDescriptor.getRawType()) == null) { Optional finalField = types.stream() - .map(FieldValueTypeInformation::getField) + .flatMap( + fvti -> + Optional.ofNullable(fvti.getField()).map(Stream::of).orElse(Stream.empty())) .filter(f -> Modifier.isFinal(f.getModifiers())) .findAny(); if (finalField.isPresent()) { @@ -115,8 +118,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return POJOUtils.getGetters( targetTypeDescriptor, schema, @@ -149,7 +152,7 @@ public SchemaUserTypeCreator schemaTypeCreator( ReflectUtils.getAnnotatedConstructor(targetTypeDescriptor.getRawType()); if (constructor != null) { return POJOUtils.getConstructorCreator( - targetTypeDescriptor, + (TypeDescriptor) targetTypeDescriptor, constructor, schema, 
JavaFieldTypeSupplier.INSTANCE, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java index 255d411028f9..3196db01b778 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java @@ -89,6 +89,7 @@ public String toString() { return Arrays.toString(array); } } + // A mapping between field names an indices. private final BiMap fieldIndices = HashBiMap.create(); private Map encodingPositions = Maps.newHashMap(); @@ -823,10 +824,11 @@ public static FieldType iterable(FieldType elementType) { public static FieldType map(FieldType keyType, FieldType valueType) { if (FieldType.BYTES.equals(keyType)) { LOG.warn( - "Using byte arrays as keys in a Map may lead to unexpected behavior and may not work as intended. " - + "Since arrays do not override equals() or hashCode, comparisons will be done on reference equality only. " - + "ByteBuffers, when used as keys, present similar challenges because Row stores ByteBuffer as a byte array. " - + "Consider using a different type of key for more consistent and predictable behavior."); + "Using byte arrays as keys in a Map may lead to unexpected behavior and may not work as" + + " intended. Since arrays do not override equals() or hashCode, comparisons will" + + " be done on reference equality only. ByteBuffers, when used as keys, present" + + " similar challenges because Row stores ByteBuffer as a byte array. Consider" + + " using a different type of key for more consistent and predictable behavior."); } return FieldType.forTypeName(TypeName.MAP) .setMapKeyType(keyType) @@ -1436,7 +1438,7 @@ private static Schema fromFields(List fields) { } /** Return the list of all field names. 
*/ - public List getFieldNames() { + public List<@NonNull String> getFieldNames() { return getFields().stream().map(Schema.Field::getName).collect(Collectors.toList()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java index d7fddd8abfed..300dce61e2ea 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java @@ -27,6 +27,7 @@ import java.lang.reflect.Parameter; import java.lang.reflect.Type; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -62,21 +63,25 @@ import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversionsFactory; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for managing AutoValue schemas. 
*/ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class AutoValueUtils { - public static TypeDescriptor getBaseAutoValueClass(TypeDescriptor typeDescriptor) { + public static @Nullable TypeDescriptor getBaseAutoValueClass( + TypeDescriptor typeDescriptor) { // AutoValue extensions may be nested - while (typeDescriptor != null && typeDescriptor.getRawType().getName().contains("AutoValue_")) { - typeDescriptor = TypeDescriptor.of(typeDescriptor.getRawType().getSuperclass()); + @Nullable TypeDescriptor baseTypeDescriptor = typeDescriptor; + while (baseTypeDescriptor != null + && baseTypeDescriptor.getRawType().getName().contains("AutoValue_")) { + baseTypeDescriptor = + Optional.ofNullable(baseTypeDescriptor.getRawType().getSuperclass()) + .map(TypeDescriptor::of) + .orElse(null); } - return typeDescriptor; + return baseTypeDescriptor; } private static TypeDescriptor getAutoValueGenerated(TypeDescriptor typeDescriptor) { @@ -154,7 +159,11 @@ private static boolean matchConstructor( getterTypes.stream() .collect( Collectors.toMap( - f -> ReflectUtils.stripGetterPrefix(f.getMethod().getName()), + f -> + ReflectUtils.stripGetterPrefix( + Preconditions.checkNotNull( + f.getMethod(), JavaBeanUtils.GETTER_WITH_NULL_METHOD_ERROR) + .getName()), Function.identity())); boolean valid = true; @@ -196,18 +205,23 @@ private static boolean matchConstructor( return null; } - Map setterTypes = - ReflectUtils.getMethods(builderClass).stream() - .filter(ReflectUtils::isSetter) - .map(FieldValueTypeInformation::forSetter) - .collect(Collectors.toMap(FieldValueTypeInformation::getName, Function.identity())); + Map setterTypes = new HashMap<>(); + + ReflectUtils.getMethods(builderClass).stream() + .filter(ReflectUtils::isSetter) + .map(m -> FieldValueTypeInformation.forSetter(TypeDescriptor.of(builderClass), m)) + .forEach(fv -> setterTypes.putIfAbsent(fv.getName(), fv)); List setterMethods = 
Lists.newArrayList(); // The builder methods to call in order. List schemaTypes = fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); for (FieldValueTypeInformation type : schemaTypes) { - String autoValueFieldName = ReflectUtils.stripGetterPrefix(type.getMethod().getName()); + String autoValueFieldName = + ReflectUtils.stripGetterPrefix( + Preconditions.checkNotNull( + type.getMethod(), JavaBeanUtils.GETTER_WITH_NULL_METHOD_ERROR) + .getName()); FieldValueTypeInformation setterType = setterTypes.get(autoValueFieldName); if (setterType == null) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index c2b33c2d2315..ccfad18d7536 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -21,6 +21,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import java.lang.reflect.Constructor; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Parameter; @@ -33,6 +34,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.SortedMap; import net.bytebuddy.ByteBuddy; @@ -41,6 +43,7 @@ import net.bytebuddy.asm.AsmVisitorWrapper; import net.bytebuddy.description.method.MethodDescription.ForLoadedConstructor; import net.bytebuddy.description.method.MethodDescription.ForLoadedMethod; +import net.bytebuddy.description.type.PackageDescription; import net.bytebuddy.description.type.TypeDescription; import net.bytebuddy.description.type.TypeDescription.ForLoadedType; import net.bytebuddy.dynamic.DynamicType; @@ -77,6 +80,8 @@ import org.apache.beam.sdk.values.TypeDescriptor; import 
org.apache.beam.sdk.values.TypeParameter; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Function; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -84,6 +89,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Primitives; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ClassUtils; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTimeZone; import org.joda.time.Instant; @@ -94,8 +100,6 @@ @Internal @SuppressWarnings({ "keyfor", - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" }) public class ByteBuddyUtils { private static final ForLoadedType ARRAYS_TYPE = new ForLoadedType(Arrays.class); @@ -146,7 +150,11 @@ protected String name(TypeDescription superClass) { // If the target class is in a prohibited package (java.*) then leave the original package // alone. String realPackage = - overridePackage(targetPackage) ? targetPackage : superClass.getPackage().getName(); + overridePackage(targetPackage) + ? targetPackage + : Optional.ofNullable(superClass.getPackage()) + .map(PackageDescription::getName) + .orElse(""); return realPackage + className + "$" + SUFFIX + "$" + randomString.nextString(); } @@ -201,25 +209,27 @@ static class ShortCircuitReturnNull extends IfNullElse { // Create a new FieldValueGetter subclass. 
@SuppressWarnings("unchecked") - public static DynamicType.Builder subclassGetterInterface( - ByteBuddy byteBuddy, Type objectType, Type fieldType) { + public static + DynamicType.Builder> subclassGetterInterface( + ByteBuddy byteBuddy, Type objectType, Type fieldType) { TypeDescription.Generic getterGenericType = TypeDescription.Generic.Builder.parameterizedType( FieldValueGetter.class, objectType, fieldType) .build(); - return (DynamicType.Builder) + return (DynamicType.Builder>) byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(getterGenericType); } // Create a new FieldValueSetter subclass. @SuppressWarnings("unchecked") - public static DynamicType.Builder subclassSetterInterface( - ByteBuddy byteBuddy, Type objectType, Type fieldType) { + public static + DynamicType.Builder> subclassSetterInterface( + ByteBuddy byteBuddy, Type objectType, Type fieldType) { TypeDescription.Generic setterGenericType = TypeDescription.Generic.Builder.parameterizedType( FieldValueSetter.class, objectType, fieldType) .build(); - return (DynamicType.Builder) + return (DynamicType.Builder) byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(setterGenericType); } @@ -251,9 +261,11 @@ public TypeConversion createSetterConversions(StackManipulati // Base class used below to convert types. @SuppressWarnings("unchecked") public abstract static class TypeConversion { - public T convert(TypeDescriptor typeDescriptor) { + public T convert(TypeDescriptor typeDescriptor) { if (typeDescriptor.isArray() - && !typeDescriptor.getComponentType().getRawType().equals(byte.class)) { + && !Preconditions.checkNotNull(typeDescriptor.getComponentType()) + .getRawType() + .equals(byte.class)) { // Byte arrays are special, so leave those alone. 
return convertArray(typeDescriptor); } else if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Map.class))) { @@ -338,25 +350,32 @@ protected ConvertType(boolean returnRawTypes) { @Override protected Type convertArray(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(type.getComponentType()); + TypeDescriptor ret = + createCollectionType(Preconditions.checkNotNull(type.getComponentType())); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertCollection(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createCollectionType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertList(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createCollectionType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertIterable(TypeDescriptor type) { - TypeDescriptor ret = createIterableType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createIterableType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? 
ret.getRawType() : ret.getType(); } @@ -398,8 +417,9 @@ protected Type convertDefault(TypeDescriptor type) { @SuppressWarnings("unchecked") private TypeDescriptor> createCollectionType( TypeDescriptor componentType) { - TypeDescriptor wrappedComponentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + TypeDescriptor wrappedComponentType = + (TypeDescriptor) + TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); return new TypeDescriptor>() {}.where( new TypeParameter() {}, wrappedComponentType); } @@ -407,8 +427,9 @@ private TypeDescriptor> createCollectionType( @SuppressWarnings("unchecked") private TypeDescriptor> createIterableType( TypeDescriptor componentType) { - TypeDescriptor wrappedComponentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + TypeDescriptor wrappedComponentType = + (TypeDescriptor) + TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); return new TypeDescriptor>() {}.where( new TypeParameter() {}, wrappedComponentType); } @@ -420,7 +441,7 @@ private TypeDescriptor> createIterableType( // This function // generates a subclass of Function that can be used to recursively transform each element of the // container. - static Class createCollectionTransformFunction( + static Class createCollectionTransformFunction( Type fromType, Type toType, Function convertElement) { // Generate a TypeDescription for the class we want to generate. 
TypeDescription.Generic functionGenericType = @@ -428,8 +449,8 @@ static Class createCollectionTransformFunction( Function.class, Primitives.wrap((Class) fromType), Primitives.wrap((Class) toType)) .build(); - DynamicType.Builder builder = - (DynamicType.Builder) + DynamicType.Builder> builder = + (DynamicType.Builder) BYTE_BUDDY .with(new InjectPackageStrategy((Class) fromType)) .subclass(functionGenericType) @@ -463,9 +484,11 @@ public InstrumentedType prepare(InstrumentedType instrumentedType) { .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader(((Class) fromType).getClassLoader()), + ReflectHelpers.findClassLoader(((Class) fromType).getClassLoader()), getClassLoadingStrategy( - ((Class) fromType).getClassLoader() == null ? Function.class : (Class) fromType)) + ((Class) fromType).getClassLoader() == null + ? Function.class + : (Class) fromType)) .getLoaded(); } @@ -547,17 +570,17 @@ public boolean containsValue(Object value) { } @Override - public V2 get(Object key) { + public @Nullable V2 get(Object key) { return delegateMap.get(key); } @Override - public V2 put(K2 key, V2 value) { + public @Nullable V2 put(K2 key, V2 value) { return delegateMap.put(key, value); } @Override - public V2 remove(Object key) { + public @Nullable V2 remove(Object key) { return delegateMap.remove(key); } @@ -635,12 +658,12 @@ protected StackManipulation convertArray(TypeDescriptor type) { // return isComponentTypePrimitive ? Arrays.asList(ArrayUtils.toObject(value)) // : Arrays.asList(value); - TypeDescriptor componentType = type.getComponentType(); + TypeDescriptor componentType = Preconditions.checkNotNull(type.getComponentType()); ForLoadedType loadedArrayType = new ForLoadedType(type.getRawType()); StackManipulation readArrayValue = readValue; // Row always expects to get an Iterable back for array types. Wrap this array into a // List using Arrays.asList before returning. 
- if (loadedArrayType.getComponentType().isPrimitive()) { + if (Preconditions.checkNotNull(loadedArrayType.getComponentType()).isPrimitive()) { // Arrays.asList doesn't take primitive arrays, so convert first using ArrayUtils.toObject. readArrayValue = new Compound( @@ -668,7 +691,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Generate a SerializableFunction to convert the element-type objects. StackManipulation stackManipulation; - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); @@ -687,10 +710,11 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -707,9 +731,10 @@ protected StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = 
ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -726,9 +751,10 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -745,8 +771,8 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); - final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); + final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); + final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); Type convertedKeyType = getFactory().createTypeConversion(true).convert(keyType); Type convertedValueType = getFactory().createTypeConversion(true).convert(valueType); @@ -970,16 +996,18 @@ protected StackManipulation convertArray(TypeDescriptor type) { // return isPrimitive ? toArray : ArrayUtils.toPrimitive(toArray); ForLoadedType loadedType = new ForLoadedType(type.getRawType()); + TypeDescription loadedTypeComponentType = Verify.verifyNotNull(loadedType.getComponentType()); + // The type of the array containing the (possibly) boxed values. 
TypeDescription arrayType = - TypeDescription.Generic.Builder.rawType(loadedType.getComponentType().asBoxed()) + TypeDescription.Generic.Builder.rawType(loadedTypeComponentType.asBoxed()) .asArray() .build() .asErasure(); - Type rowElementType = - getFactory().createTypeConversion(false).convert(type.getComponentType()); - final TypeDescriptor arrayElementType = ReflectUtils.boxIfPrimitive(type.getComponentType()); + TypeDescriptor componentType = Preconditions.checkNotNull(type.getComponentType()); + Type rowElementType = getFactory().createTypeConversion(false).convert(componentType); + final TypeDescriptor arrayElementType = ReflectUtils.boxIfPrimitive(componentType); StackManipulation readTransformedValue = readValue; if (!arrayElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = @@ -999,7 +1027,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Call Collection.toArray(T[[]) to extract the array. Push new T[0] on the stack // before // calling toArray. - ArrayFactory.forType(loadedType.getComponentType().asBoxed().asGenericType()) + ArrayFactory.forType(loadedTypeComponentType.asBoxed().asGenericType()) .withValues(Collections.emptyList()), MethodInvocation.invoke( COLLECTION_TYPE @@ -1016,7 +1044,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Cast the result to T[]. TypeCasting.to(arrayType)); - if (loadedType.getComponentType().isPrimitive()) { + if (loadedTypeComponentType.isPrimitive()) { // The array we extract will be an array of objects. If the pojo field is an array of // primitive types, we need to then convert to an array of unboxed objects. 
stackManipulation = @@ -1035,11 +1063,9 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor iterableElementType = ReflectUtils.getIterableComponentType(type); + final TypeDescriptor iterableElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(iterableElementType); if (!iterableElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = new ForLoadedType( @@ -1057,11 +1083,9 @@ protected StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + final TypeDescriptor collectionElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(collectionElementType); if (!collectionElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = @@ -1080,11 +1104,9 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + final TypeDescriptor collectionElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = 
getFactory().createTypeConversion(false).convert(collectionElementType); StackManipulation readTrasformedValue = readValue; if (!collectionElementType.hasUnresolvedParameters()) { @@ -1112,12 +1134,12 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - Type rowKeyType = - getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 0)); - final TypeDescriptor keyElementType = ReflectUtils.getMapType(type, 0); - Type rowValueType = - getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 1)); - final TypeDescriptor valueElementType = ReflectUtils.getMapType(type, 1); + final TypeDescriptor keyElementType = + Preconditions.checkNotNull(ReflectUtils.getMapType(type, 0)); + final TypeDescriptor valueElementType = + Preconditions.checkNotNull(ReflectUtils.getMapType(type, 1)); + Type rowKeyType = getFactory().createTypeConversion(false).convert(keyElementType); + Type rowValueType = getFactory().createTypeConversion(false).convert(valueElementType); StackManipulation readTrasformedValue = readValue; if (!keyElementType.hasUnresolvedParameters() @@ -1332,12 +1354,12 @@ protected StackManipulation convertDefault(TypeDescriptor type) { * constructor. 
*/ static class ConstructorCreateInstruction extends InvokeUserCreateInstruction { - private final Constructor constructor; + private final Constructor constructor; ConstructorCreateInstruction( List fields, - Class targetClass, - Constructor constructor, + Class targetClass, + Constructor constructor, TypeConversionsFactory typeConversionsFactory) { super( fields, @@ -1375,7 +1397,7 @@ static class StaticFactoryMethodInstruction extends InvokeUserCreateInstruction StaticFactoryMethodInstruction( List fields, - Class targetClass, + Class targetClass, Method creator, TypeConversionsFactory typeConversionsFactory) { super( @@ -1399,14 +1421,14 @@ protected StackManipulation afterPushingParameters() { static class InvokeUserCreateInstruction implements Implementation { protected final List fields; - protected final Class targetClass; + protected final Class targetClass; protected final List parameters; protected final Map fieldMapping; private final TypeConversionsFactory typeConversionsFactory; protected InvokeUserCreateInstruction( List fields, - Class targetClass, + Class targetClass, List parameters, TypeConversionsFactory typeConversionsFactory) { this.fields = fields; @@ -1424,11 +1446,15 @@ protected InvokeUserCreateInstruction( // actual Java field or method names. 
FieldValueTypeInformation fieldValue = checkNotNull(fields.get(i)); fieldsByLogicalName.put(fieldValue.getName(), i); - if (fieldValue.getField() != null) { - fieldsByJavaClassMember.put(fieldValue.getField().getName(), i); - } else if (fieldValue.getMethod() != null) { - String name = ReflectUtils.stripGetterPrefix(fieldValue.getMethod().getName()); - fieldsByJavaClassMember.put(name, i); + Field field = fieldValue.getField(); + if (field != null) { + fieldsByJavaClassMember.put(field.getName(), i); + } else { + Method method = fieldValue.getMethod(); + if (method != null) { + String name = ReflectUtils.stripGetterPrefix(method.getName()); + fieldsByJavaClassMember.put(name, i); + } } } @@ -1482,7 +1508,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { StackManipulation readParameter = new StackManipulation.Compound( MethodVariableAccess.REFERENCE.loadFrom(1), - IntegerConstant.forValue(fieldMapping.get(i)), + IntegerConstant.forValue(Preconditions.checkNotNull(fieldMapping.get(i))), ArrayAccess.REFERENCE.load(), TypeCasting.to(convertedType)); stackManipulation = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java index 911f79f6eeed..ee4868ddb2b6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java @@ -22,9 +22,11 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import net.bytebuddy.ByteBuddy; import net.bytebuddy.asm.AsmVisitorWrapper; @@ -54,14 +56,22 @@ import org.apache.beam.sdk.schemas.utils.ReflectUtils.TypeDescriptorWithSchema; 
import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; /** A set of utilities to generate getter and setter classes for JavaBean objects. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class JavaBeanUtils { + + private static final String X_WITH_NULL_METHOD_ERROR_FMT = + "a %s FieldValueTypeInformation object has a null method field"; + public static final String GETTER_WITH_NULL_METHOD_ERROR = + String.format(X_WITH_NULL_METHOD_ERROR_FMT, "getter"); + public static final String SETTER_WITH_NULL_METHOD_ERROR = + String.format(X_WITH_NULL_METHOD_ERROR_FMT, "setter"); + /** Create a {@link Schema} for a Java Bean class. */ public static Schema schemaFromJavaBeanClass( TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { @@ -69,7 +79,9 @@ public static Schema schemaFromJavaBeanClass( } private static final String CONSTRUCTOR_HELP_STRING = - "In order to infer a Schema from a Java Bean, it must have a constructor annotated with @SchemaCreate, or it must have a compatible setter for every getter used as a Schema field."; + "In order to infer a Schema from a Java Bean, it must have a constructor annotated with" + + " @SchemaCreate, or it must have a compatible setter for every getter used as a Schema" + + " field."; // Make sure that there are matching setters and getters. 
public static void validateJavaBean( @@ -88,23 +100,26 @@ public static void validateJavaBean( for (FieldValueTypeInformation type : getters) { FieldValueTypeInformation setterType = setterMap.get(type.getName()); + Method m = Preconditions.checkNotNull(type.getMethod(), GETTER_WITH_NULL_METHOD_ERROR); if (setterType == null) { throw new RuntimeException( String.format( - "Java Bean '%s' contains a getter for field '%s', but does not contain a matching setter. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + "Java Bean '%s' contains a getter for field '%s', but does not contain a matching" + + " setter. %s", + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } if (!type.getType().equals(setterType.getType())) { throw new RuntimeException( String.format( "Java Bean '%s' contains a setter for field '%s' that has a mismatching type. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } if (!type.isNullable() == setterType.isNullable()) { throw new RuntimeException( String.format( - "Java Bean '%s' contains a setter for field '%s' that has a mismatching nullable attribute. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + "Java Bean '%s' contains a setter for field '%s' that has a mismatching nullable" + + " attribute. %s", + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } } } @@ -126,36 +141,41 @@ public static List getFieldTypes( // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map, List> CACHED_GETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_GETTERS = Maps.newConcurrentMap(); /** * Return the list of {@link FieldValueGetter}s for a Java Bean class * *

The returned list is ordered by the order of fields in the schema. */ - public static List getGetters( - TypeDescriptor typeDescriptor, + public static List> getGetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { - return CACHED_GETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - return types.stream() - .map(t -> createGetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return types.stream() + .map(t -> JavaBeanUtils.createGetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + }); } - public static FieldValueGetter createGetter( - FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - DynamicType.Builder builder = + public static + FieldValueGetter createGetter( + FieldValueTypeInformation typeInformation, + TypeConversionsFactory typeConversionsFactory) { + final Method m = + Preconditions.checkNotNull(typeInformation.getMethod(), GETTER_WITH_NULL_METHOD_ERROR); + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface( BYTE_BUDDY, - typeInformation.getMethod().getDeclaringClass(), + m.getDeclaringClass(), typeConversionsFactory.createTypeConversion(false).convert(typeInformation.getType())); builder = implementGetterMethods(builder, typeInformation, typeConversionsFactory); try { @@ -163,9 +183,8 @@ public static FieldValueGetter createGetter( .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader( - typeInformation.getMethod().getDeclaringClass().getClassLoader()), - 
getClassLoadingStrategy(typeInformation.getMethod().getDeclaringClass())) + ReflectHelpers.findClassLoader(m.getDeclaringClass().getClassLoader()), + getClassLoadingStrategy(m.getDeclaringClass())) .getLoaded() .getDeclaredConstructor() .newInstance(); @@ -178,10 +197,11 @@ public static FieldValueGetter createGetter( } } - private static DynamicType.Builder implementGetterMethods( - DynamicType.Builder builder, - FieldValueTypeInformation typeInformation, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementGetterMethods( + DynamicType.Builder> builder, + FieldValueTypeInformation typeInformation, + TypeConversionsFactory typeConversionsFactory) { return builder .method(ElementMatchers.named("name")) .intercept(FixedValue.reference(typeInformation.getName())) @@ -215,12 +235,14 @@ public static List getSetters( }); } - public static FieldValueSetter createSetter( + public static FieldValueSetter createSetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - DynamicType.Builder builder = + final Method m = + Preconditions.checkNotNull(typeInformation.getMethod(), SETTER_WITH_NULL_METHOD_ERROR); + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, - typeInformation.getMethod().getDeclaringClass(), + m.getDeclaringClass(), typeConversionsFactory.createTypeConversion(false).convert(typeInformation.getType())); builder = implementSetterMethods(builder, typeInformation, typeConversionsFactory); try { @@ -228,9 +250,8 @@ public static FieldValueSetter createSetter( .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader( - typeInformation.getMethod().getDeclaringClass().getClassLoader()), - getClassLoadingStrategy(typeInformation.getMethod().getDeclaringClass())) + ReflectHelpers.findClassLoader(m.getDeclaringClass().getClassLoader()), + 
getClassLoadingStrategy(m.getDeclaringClass())) .getLoaded() .getDeclaredConstructor() .newInstance(); @@ -243,10 +264,11 @@ public static FieldValueSetter createSetter( } } - private static DynamicType.Builder implementSetterMethods( - DynamicType.Builder builder, - FieldValueTypeInformation fieldValueTypeInformation, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementSetterMethods( + DynamicType.Builder> builder, + FieldValueTypeInformation fieldValueTypeInformation, + TypeConversionsFactory typeConversionsFactory) { return builder .method(ElementMatchers.named("name")) .intercept(FixedValue.reference(fieldValueTypeInformation.getName())) @@ -358,6 +380,11 @@ public static SchemaUserTypeCreator createStaticCreator( } } + public static > Comparator comparingNullFirst( + Function keyExtractor) { + return Comparator.comparing(keyExtractor, Comparator.nullsFirst(Comparator.naturalOrder())); + } + // Implements a method to read a public getter out of an object. private static class InvokeGetterInstruction implements Implementation { private final FieldValueTypeInformation typeInformation; @@ -386,7 +413,10 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Method param is offset 1 (offset 0 is the this parameter). MethodVariableAccess.REFERENCE.loadFrom(1), // Invoke the getter - MethodInvocation.invoke(new ForLoadedMethod(typeInformation.getMethod()))); + MethodInvocation.invoke( + new ForLoadedMethod( + Preconditions.checkNotNull( + typeInformation.getMethod(), GETTER_WITH_NULL_METHOD_ERROR)))); StackManipulation stackManipulation = new StackManipulation.Compound( @@ -428,7 +458,9 @@ public ByteCodeAppender appender(final Target implementationTarget) { // The instruction to read the field. 
StackManipulation readField = MethodVariableAccess.REFERENCE.loadFrom(2); - Method method = fieldValueTypeInformation.getMethod(); + Method method = + Preconditions.checkNotNull( + fieldValueTypeInformation.getMethod(), SETTER_WITH_NULL_METHOD_ERROR); boolean setterMethodReturnsVoid = method.getReturnType().equals(Void.TYPE); // Read the object onto the stack. StackManipulation stackManipulation = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java index 571b9c690900..8e33d321a1c6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java @@ -62,8 +62,9 @@ import org.apache.beam.sdk.schemas.utils.ReflectUtils.TypeDescriptorWithSchema; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; -import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.NonNull; /** A set of utilities to generate getter and setter classes for POJOs. */ @SuppressWarnings({ @@ -94,38 +95,40 @@ public static List getFieldTypes( // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_GETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_GETTERS = Maps.newConcurrentMap(); - public static List getGetters( - TypeDescriptor typeDescriptor, + public static List> getGetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { // Return the getters ordered by their position in the schema. 
- return CACHED_GETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - List getters = - types.stream() - .map(t -> createGetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - if (getters.size() != schema.getFieldCount()) { - throw new RuntimeException( - "Was not able to generate getters for schema: " - + schema - + " class: " - + typeDescriptor); - } - return getters; - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + List> getters = + types.stream() + .>map( + t -> POJOUtils.createGetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + if (getters.size() != schema.getFieldCount()) { + throw new RuntimeException( + "Was not able to generate getters for schema: " + + schema + + " class: " + + typeDescriptor); + } + return (List) getters; + }); } // The list of constructors for a class is cached, so we only create the classes the first time // getConstructor is called. - public static final Map CACHED_CREATORS = + public static final Map, SchemaUserTypeCreator> CACHED_CREATORS = Maps.newConcurrentMap(); public static SchemaUserTypeCreator getSetFieldCreator( @@ -150,7 +153,9 @@ private static SchemaUserTypeCreator createSetFieldCreator( TypeConversionsFactory typeConversionsFactory) { // Get the list of class fields ordered by schema. 
List fields = - types.stream().map(FieldValueTypeInformation::getField).collect(Collectors.toList()); + types.stream() + .map(type -> Preconditions.checkNotNull(type.getField())) + .collect(Collectors.toList()); try { DynamicType.Builder builder = BYTE_BUDDY @@ -175,14 +180,16 @@ private static SchemaUserTypeCreator createSetFieldCreator( | InvocationTargetException e) { throw new RuntimeException( String.format( - "Unable to generate a creator for POJO '%s' with inferred schema: %s%nNote POJOs must have a zero-argument constructor, or a constructor annotated with @SchemaCreate.", + "Unable to generate a creator for POJO '%s' with inferred schema: %s%nNote POJOs must" + + " have a zero-argument constructor, or a constructor annotated with" + + " @SchemaCreate.", clazz, schema)); } } - public static SchemaUserTypeCreator getConstructorCreator( - TypeDescriptor typeDescriptor, - Constructor constructor, + public static SchemaUserTypeCreator getConstructorCreator( + TypeDescriptor typeDescriptor, + Constructor constructor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { @@ -191,13 +198,13 @@ public static SchemaUserTypeCreator getConstructorCreator( c -> { List types = fieldValueTypeSupplier.get(typeDescriptor, schema); - return createConstructorCreator( + return POJOUtils.createConstructorCreator( typeDescriptor.getRawType(), constructor, schema, types, typeConversionsFactory); }); } public static SchemaUserTypeCreator createConstructorCreator( - Class clazz, + Class clazz, Constructor constructor, Schema schema, List types, @@ -291,11 +298,10 @@ public static SchemaUserTypeCreator createStaticCreator( * } * */ - @SuppressWarnings("unchecked") - static @Nullable FieldValueGetter createGetter( + static FieldValueGetter<@NonNull ObjectT, ValueT> createGetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - Field field = typeInformation.getField(); - 
DynamicType.Builder builder = + Field field = Preconditions.checkNotNull(typeInformation.getField()); + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface( BYTE_BUDDY, field.getDeclaringClass(), @@ -322,11 +328,12 @@ public static SchemaUserTypeCreator createStaticCreator( } } - private static DynamicType.Builder implementGetterMethods( - DynamicType.Builder builder, - Field field, - String name, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementGetterMethods( + DynamicType.Builder> builder, + Field field, + String name, + TypeConversionsFactory typeConversionsFactory) { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .method(ElementMatchers.named("name")) @@ -337,24 +344,25 @@ private static DynamicType.Builder implementGetterMethods( // The list of setters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_SETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_SETTERS = Maps.newConcurrentMap(); - public static List getSetters( - TypeDescriptor typeDescriptor, + public static List> getSetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { // Return the setters, ordered by their position in the schema. 
- return CACHED_SETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - return types.stream() - .map(t -> createSetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_SETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return types.stream() + .map(t -> createSetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + }); } /** @@ -376,8 +384,8 @@ public static List getSetters( @SuppressWarnings("unchecked") private static FieldValueSetter createSetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - Field field = typeInformation.getField(); - DynamicType.Builder builder = + Field field = Preconditions.checkNotNull(typeInformation.getField()); + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, field.getDeclaringClass(), @@ -403,10 +411,11 @@ private static FieldValueSetter createSetter( } } - private static DynamicType.Builder implementSetterMethods( - DynamicType.Builder builder, - Field field, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementSetterMethods( + DynamicType.Builder> builder, + Field field, + TypeConversionsFactory typeConversionsFactory) { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .method(ElementMatchers.named("name")) @@ -505,11 +514,11 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Implements a method to construct an object. 
static class SetFieldCreateInstruction implements Implementation { private final List fields; - private final Class pojoClass; + private final Class pojoClass; private final TypeConversionsFactory typeConversionsFactory; SetFieldCreateInstruction( - List fields, Class pojoClass, TypeConversionsFactory typeConversionsFactory) { + List fields, Class pojoClass, TypeConversionsFactory typeConversionsFactory) { this.fields = fields; this.pojoClass = pojoClass; this.typeConversionsFactory = typeConversionsFactory; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java index 4349a04c28ad..423fea4c3845 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java @@ -32,7 +32,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.SchemaCreate; @@ -88,14 +87,23 @@ public static List getMethods(Class clazz) { return DECLARED_METHODS.computeIfAbsent( clazz, c -> { - return Arrays.stream(c.getDeclaredMethods()) - .filter( - m -> !m.isBridge()) // Covariant overloads insert bridge functions, which we must - // ignore. - .filter(m -> !Modifier.isPrivate(m.getModifiers())) - .filter(m -> !Modifier.isProtected(m.getModifiers())) - .filter(m -> !Modifier.isStatic(m.getModifiers())) - .collect(Collectors.toList()); + List methods = Lists.newArrayList(); + do { + if (c.getPackage() != null && c.getPackage().getName().startsWith("java.")) { + break; // skip java built-in classes + } + Arrays.stream(c.getDeclaredMethods()) + .filter( + m -> + !m.isBridge()) // Covariant overloads insert bridge functions, which we must + // ignore. 
+ .filter(m -> !Modifier.isPrivate(m.getModifiers())) + .filter(m -> !Modifier.isProtected(m.getModifiers())) + .filter(m -> !Modifier.isStatic(m.getModifiers())) + .forEach(methods::add); + c = c.getSuperclass(); + } while (c != null); + return methods; }); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java index aeb76492bb6d..c2d945bbaac1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java @@ -44,6 +44,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSortedSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; +import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for working with with {@link Class Classes} and {@link Method Methods}. */ @SuppressWarnings({"nullness", "keyfor"}) // TODO(https://github.com/apache/beam/issues/20497) @@ -216,7 +217,7 @@ public static Iterable loadServicesOrdered(Class iface) { * which by default would use the proposed {@code ClassLoader}, which can be null. The fallback is * as follows: context ClassLoader, class ClassLoader and finally the system ClassLoader. 
*/ - public static ClassLoader findClassLoader(final ClassLoader proposed) { + public static ClassLoader findClassLoader(@Nullable final ClassLoader proposed) { ClassLoader classLoader = proposed; if (classLoader == null) { classLoader = ReflectHelpers.class.getClassLoader(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index ee3852d70bbe..591a83600561 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -48,6 +48,7 @@ import org.apache.beam.sdk.values.RowUtils.RowFieldMatcher; import org.apache.beam.sdk.values.RowUtils.RowPosition; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.joda.time.ReadableDateTime; @@ -771,6 +772,7 @@ public FieldValueBuilder withFieldValue( checkState(values.isEmpty()); return new FieldValueBuilder(schema, null).withFieldValue(fieldAccessDescriptor, value); } + /** * Sets field values using the field names. Nested values can be set using the field selection * syntax. 
@@ -836,10 +838,10 @@ public int nextFieldId() { } @Internal - public Row withFieldValueGetters( - Factory> fieldValueGetterFactory, Object getterTarget) { + public <@NonNull T> Row withFieldValueGetters( + Factory>> fieldValueGetterFactory, T getterTarget) { checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + return new RowWithGetters<>(schema, fieldValueGetterFactory, getterTarget); } public Row build() { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index 9731507fb0f6..35e0ac20d3f7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -42,13 +42,13 @@ * the appropriate fields from the POJO. */ @SuppressWarnings("rawtypes") -public class RowWithGetters extends Row { - private final Object getterTarget; - private final List getters; +public class RowWithGetters extends Row { + private final T getterTarget; + private final List> getters; private @Nullable Map cache = null; RowWithGetters( - Schema schema, Factory> getterFactory, Object getterTarget) { + Schema schema, Factory>> getterFactory, T getterTarget) { super(schema); this.getterTarget = getterTarget; this.getters = getterFactory.create(TypeDescriptor.of(getterTarget.getClass()), schema); @@ -56,7 +56,7 @@ public class RowWithGetters extends Row { @Override @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) - public @Nullable T getValue(int fieldIdx) { + public W getValue(int fieldIdx) { Field field = getSchema().getField(fieldIdx); boolean cacheField = cacheFieldType(field); @@ -64,7 +64,7 @@ public class RowWithGetters extends Row { cache = new TreeMap<>(); } - Object fieldValue; + @Nullable Object fieldValue; if (cacheField) { if (cache == null) { cache = new 
TreeMap<>(); @@ -72,15 +72,12 @@ public class RowWithGetters extends Row { fieldValue = cache.computeIfAbsent( fieldIdx, - new Function() { + new Function() { @Override - public Object apply(Integer idx) { - FieldValueGetter getter = getters.get(idx); + public @Nullable Object apply(Integer idx) { + FieldValueGetter getter = getters.get(idx); checkStateNotNull(getter); - @SuppressWarnings("nullness") - @NonNull - Object value = getter.get(getterTarget); - return value; + return getter.get(getterTarget); } }); } else { @@ -90,7 +87,7 @@ public Object apply(Integer idx) { if (fieldValue == null && !field.getType().getNullable()) { throw new RuntimeException("Null value set on non-nullable field " + field); } - return (T) fieldValue; + return (W) fieldValue; } private boolean cacheFieldType(Field field) { @@ -116,7 +113,7 @@ public int getFieldCount() { return rawValues; } - public List getGetters() { + public List> getGetters() { return getters; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java new file mode 100644 index 000000000000..26e3278df025 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.schemas; + +import static org.junit.Assert.assertEquals; + +import java.util.Map; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.junit.Test; + +public class FieldValueTypeInformationTest { + public static class GenericClass { + public T t; + + public GenericClass(T t) { + this.t = t; + } + + public T getT() { + return t; + } + + public void setT(T t) { + this.t = t; + } + } + + private final TypeDescriptor>> typeDescriptor = + new TypeDescriptor>>() {}; + private final TypeDescriptor> expectedFieldTypeDescriptor = + new TypeDescriptor>() {}; + + @Test + public void testForGetter() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forGetter( + typeDescriptor, GenericClass.class.getMethod("getT"), 0); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } + + @Test + public void testForField() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forField(typeDescriptor, GenericClass.class.getField("t"), 0); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } + + @Test + public void testForSetter() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forSetter( + typeDescriptor, GenericClass.class.getMethod("setT", Object.class)); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java 
index 021e39b84849..7e9cf9a894b9 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java @@ -53,6 +53,7 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.PrimitiveMapBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.SimpleBean; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.DateTime; import org.junit.Test; @@ -142,11 +143,11 @@ public void testGeneratedSimpleGetters() { simpleBean.setBigDecimal(new BigDecimal(42)); simpleBean.setStringBuilder(new StringBuilder("stringBuilder")); - List getters = + List> getters = JavaBeanUtils.getGetters( new TypeDescriptor() {}, SIMPLE_BEAN_SCHEMA, - new JavaBeanSchema.GetterTypeSupplier(), + new GetterTypeSupplier(), new DefaultTypeConversionsFactory()); assertEquals(12, getters.size()); assertEquals("str", getters.get(0).name()); @@ -220,7 +221,7 @@ public void testGeneratedSimpleBoxedGetters() { bean.setaLong(44L); bean.setaBoolean(true); - List getters = + List> getters = JavaBeanUtils.getGetters( new TypeDescriptor() {}, BEAN_WITH_BOXED_FIELDS_SCHEMA, diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java index 723353ed8d15..378cdc06805f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java @@ -52,6 +52,7 @@ import org.apache.beam.sdk.schemas.utils.TestPOJOs.PrimitiveMapPOJO; import org.apache.beam.sdk.schemas.utils.TestPOJOs.SimplePOJO; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.DateTime; import org.joda.time.Instant; import org.junit.Test; @@ -158,7 +159,7 
@@ public void testGeneratedSimpleGetters() { new BigDecimal(42), new StringBuilder("stringBuilder")); - List getters = + List> getters = POJOUtils.getGetters( new TypeDescriptor() {}, SIMPLE_POJO_SCHEMA, @@ -184,7 +185,7 @@ public void testGeneratedSimpleGetters() { @Test public void testGeneratedSimpleSetters() { SimplePOJO simplePojo = new SimplePOJO(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, SIMPLE_POJO_SCHEMA, @@ -223,7 +224,7 @@ public void testGeneratedSimpleSetters() { public void testGeneratedSimpleBoxedGetters() { POJOWithBoxedFields pojo = new POJOWithBoxedFields((byte) 41, (short) 42, 43, 44L, true); - List getters = + List> getters = POJOUtils.getGetters( new TypeDescriptor() {}, POJO_WITH_BOXED_FIELDS_SCHEMA, @@ -239,7 +240,7 @@ public void testGeneratedSimpleBoxedGetters() { @Test public void testGeneratedSimpleBoxedSetters() { POJOWithBoxedFields pojo = new POJOWithBoxedFields(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, POJO_WITH_BOXED_FIELDS_SCHEMA, @@ -262,7 +263,7 @@ public void testGeneratedSimpleBoxedSetters() { @Test public void testGeneratedByteBufferSetters() { POJOWithByteArray pojo = new POJOWithByteArray(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, POJO_WITH_BYTE_ARRAY_SCHEMA, diff --git a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java index 4b6538157fd0..78ba610ad4d1 100644 --- a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java +++ b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java @@ -276,11 +276,11 @@ public static class RecordBatchRowIterator implements Iterator, AutoCloseab new ArrowValueConverterVisitor(); private final Schema schema; private final 
VectorSchemaRoot vectorSchemaRoot; - private final Factory> fieldValueGetters; + private final Factory>> fieldValueGetters; private Integer currRowIndex; private static class FieldVectorListValueGetterFactory - implements Factory> { + implements Factory>> { private final List fieldVectors; static FieldVectorListValueGetterFactory of(List fieldVectors) { @@ -292,7 +292,8 @@ private FieldVectorListValueGetterFactory(List fieldVectors) { } @Override - public List create(TypeDescriptor typeDescriptor, Schema schema) { + public List> create( + TypeDescriptor typeDescriptor, Schema schema) { return this.fieldVectors.stream() .map( (fieldVector) -> { diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java index e75647a2ccfa..203bcccbf562 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.schemas.SchemaProvider; import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A {@link SchemaProvider} for AVRO generated SpecificRecords and POJOs. 
@@ -44,8 +45,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return AvroUtils.getGetters(targetTypeDescriptor, schema); } diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java index 1b1c45969307..bfbab6fe87f6 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java @@ -94,11 +94,13 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.EnsuresNonNullIf; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; import org.joda.time.Days; import org.joda.time.Duration; import org.joda.time.Instant; @@ -139,10 +141,7 @@ * * is used. 
*/ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class AvroUtils { private static final ForLoadedType BYTES = new ForLoadedType(byte[].class); private static final ForLoadedType JAVA_INSTANT = new ForLoadedType(java.time.Instant.class); @@ -152,6 +151,38 @@ public class AvroUtils { new ForLoadedType(ReadableInstant.class); private static final ForLoadedType JODA_INSTANT = new ForLoadedType(Instant.class); + // contains workarounds for third-party methods that accept nullable arguments but lack proper + // annotations + private static class NullnessCheckerWorkarounds { + + private static ReflectData newReflectData(Class clazz) { + // getClassLoader returns @Nullable Classloader, but it's ok, as ReflectData constructor + // actually tolerates null classloader argument despite lacking the @Nullable annotation + @SuppressWarnings("nullness") + @NonNull + ClassLoader classLoader = clazz.getClassLoader(); + return new ReflectData(classLoader); + } + + private static void builderSet( + GenericRecordBuilder builder, String fieldName, @Nullable Object value) { + // the value argument can actually be null here, it's not annotated as such in the method + // though, hence this wrapper + builder.set(fieldName, castToNonNull(value)); + } + + private static Object createFixed( + @Nullable Object old, byte[] bytes, org.apache.avro.Schema schema) { + // old is tolerated when null, due to an instanceof check + return GenericData.get().createFixed(castToNonNull(old), bytes, schema); + } + + @SuppressWarnings("nullness") + private static @NonNull T castToNonNull(@Nullable T value) { + return value; + } + } + public static void addLogicalTypeConversions(final GenericData data) { // do not add DecimalConversion by default as schema must have extra 'scale' and 'precision' // properties. 
avro reflect already handles BigDecimal as string with the 'java-class' property @@ -235,7 +266,9 @@ public static FixedBytesField withSize(int size) { /** Create a {@link FixedBytesField} from a Beam {@link FieldType}. */ public static @Nullable FixedBytesField fromBeamFieldType(FieldType fieldType) { if (fieldType.getTypeName().isLogicalType() - && fieldType.getLogicalType().getIdentifier().equals(FixedBytes.IDENTIFIER)) { + && checkNotNull(fieldType.getLogicalType()) + .getIdentifier() + .equals(FixedBytes.IDENTIFIER)) { int length = fieldType.getLogicalType(FixedBytes.class).getLength(); return new FixedBytesField(length); } else { @@ -264,7 +297,7 @@ public FieldType toBeamType() { /** Convert to an AVRO type. */ public org.apache.avro.Schema toAvroType(String name, String namespace) { - return org.apache.avro.Schema.createFixed(name, null, namespace, size); + return org.apache.avro.Schema.createFixed(name, "", namespace, size); } } @@ -451,8 +484,7 @@ public static Field toBeamField(org.apache.avro.Schema.Field field) { public static org.apache.avro.Schema.Field toAvroField(Field field, String namespace) { org.apache.avro.Schema fieldSchema = getFieldSchema(field.getType(), field.getName(), namespace); - return new org.apache.avro.Schema.Field( - field.getName(), fieldSchema, field.getDescription(), (Object) null); + return new org.apache.avro.Schema.Field(field.getName(), fieldSchema, field.getDescription()); } private AvroUtils() {} @@ -463,7 +495,7 @@ private AvroUtils() {} * @param clazz avro class */ public static Schema toBeamSchema(Class clazz) { - ReflectData data = new ReflectData(clazz.getClassLoader()); + ReflectData data = NullnessCheckerWorkarounds.newReflectData(clazz); return toBeamSchema(data.getSchema(clazz)); } @@ -486,10 +518,17 @@ public static Schema toBeamSchema(org.apache.avro.Schema schema) { return builder.build(); } + @EnsuresNonNullIf( + expression = {"#1"}, + result = false) + private static boolean isNullOrEmpty(@Nullable String 
str) { + return str == null || str.isEmpty(); + } + /** Converts a Beam Schema into an AVRO schema. */ public static org.apache.avro.Schema toAvroSchema( Schema beamSchema, @Nullable String name, @Nullable String namespace) { - final String schemaName = Strings.isNullOrEmpty(name) ? "topLevelRecord" : name; + final String schemaName = isNullOrEmpty(name) ? "topLevelRecord" : name; final String schemaNamespace = namespace == null ? "" : namespace; String childNamespace = !"".equals(schemaNamespace) ? schemaNamespace + "." + schemaName : schemaName; @@ -498,7 +537,7 @@ public static org.apache.avro.Schema toAvroSchema( org.apache.avro.Schema.Field recordField = toAvroField(field, childNamespace); fields.add(recordField); } - return org.apache.avro.Schema.createRecord(schemaName, null, schemaNamespace, false, fields); + return org.apache.avro.Schema.createRecord(schemaName, "", schemaNamespace, false, fields); } public static org.apache.avro.Schema toAvroSchema(Schema beamSchema) { @@ -557,7 +596,8 @@ public static GenericRecord toGenericRecord( GenericRecordBuilder builder = new GenericRecordBuilder(avroSchema); for (int i = 0; i < beamSchema.getFieldCount(); ++i) { Field field = beamSchema.getField(i); - builder.set( + NullnessCheckerWorkarounds.builderSet( + builder, field.getName(), genericFromBeamField( field.getType(), avroSchema.getField(field.getName()).schema(), row.getValue(i))); @@ -567,7 +607,7 @@ public static GenericRecord toGenericRecord( @SuppressWarnings("unchecked") public static SerializableFunction getToRowFunction( - Class clazz, org.apache.avro.@Nullable Schema schema) { + Class clazz, org.apache.avro.Schema schema) { if (GenericRecord.class.equals(clazz)) { Schema beamSchema = toBeamSchema(schema); return (SerializableFunction) getGenericRecordToRowFunction(beamSchema); @@ -662,9 +702,9 @@ public static SerializableFunction getGenericRecordToRowFunc } private static class GenericRecordToRowFn implements SerializableFunction { - private final 
Schema schema; + private final @Nullable Schema schema; - GenericRecordToRowFn(Schema schema) { + GenericRecordToRowFn(@Nullable Schema schema) { this.schema = schema; } @@ -701,7 +741,7 @@ public static SerializableFunction getRowToGenericRecordFunc } private static class RowToGenericRecordFn implements SerializableFunction { - private transient org.apache.avro.Schema avroSchema; + private transient org.apache.avro.@Nullable Schema avroSchema; RowToGenericRecordFn(org.apache.avro.@Nullable Schema avroSchema) { this.avroSchema = avroSchema; @@ -751,7 +791,8 @@ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundE public static SchemaCoder schemaCoder(TypeDescriptor type) { @SuppressWarnings("unchecked") Class clazz = (Class) type.getRawType(); - org.apache.avro.Schema avroSchema = new ReflectData(clazz.getClassLoader()).getSchema(clazz); + org.apache.avro.Schema avroSchema = + NullnessCheckerWorkarounds.newReflectData(clazz).getSchema(clazz); Schema beamSchema = toBeamSchema(avroSchema); return SchemaCoder.of( beamSchema, type, getToRowFunction(clazz, avroSchema), getFromRowFunction(clazz)); @@ -790,7 +831,7 @@ public static SchemaCoder schemaCoder(org.apache.avro.Schema sche */ public static SchemaCoder schemaCoder(Class clazz, org.apache.avro.Schema schema) { return SchemaCoder.of( - getSchema(clazz, schema), + checkNotNull(getSchema(clazz, schema)), TypeDescriptor.of(clazz), getToRowFunction(clazz, schema), getFromRowFunction(clazz)); @@ -821,7 +862,7 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = methods.get(i); if (ReflectUtils.isGetter(method)) { FieldValueTypeInformation fieldValueTypeInformation = - FieldValueTypeInformation.forGetter(method, i); + FieldValueTypeInformation.forGetter(typeDescriptor, method, i); String name = mapping.get(fieldValueTypeInformation.getName()); if (name != null) { types.add(fieldValueTypeInformation.withName(name)); @@ -871,7 +912,8 @@ public List get(TypeDescriptor 
typeDescriptor) { for (int i = 0; i < classFields.size(); ++i) { java.lang.reflect.Field f = classFields.get(i); if (!f.isAnnotationPresent(AvroIgnore.class)) { - FieldValueTypeInformation typeInformation = FieldValueTypeInformation.forField(f, i); + FieldValueTypeInformation typeInformation = + FieldValueTypeInformation.forField(typeDescriptor, f, i); AvroName avroname = f.getAnnotation(AvroName.class); if (avroname != null) { typeInformation = typeInformation.withName(avroname.value()); @@ -895,7 +937,7 @@ public static List getFieldTypes( } /** Get generated getters for an AVRO-generated SpecificRecord or a POJO. */ - public static List getGetters( + public static List> getGetters( TypeDescriptor typeDescriptor, Schema schema) { if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { return JavaBeanUtils.getGetters( @@ -968,7 +1010,7 @@ private static FieldType toFieldType(TypeWithNullability type) { break; case FIXED: - fieldType = FixedBytesField.fromAvroType(type.type).toBeamType(); + fieldType = checkNotNull(FixedBytesField.fromAvroType(type.type)).toBeamType(); break; case STRING: @@ -1066,7 +1108,8 @@ private static org.apache.avro.Schema getFieldSchema( break; case LOGICAL_TYPE: - String identifier = fieldType.getLogicalType().getIdentifier(); + Schema.LogicalType logicalType = checkNotNull(fieldType.getLogicalType()); + String identifier = logicalType.getIdentifier(); if (FixedBytes.IDENTIFIER.equals(identifier)) { FixedBytesField fixedBytesField = checkNotNull(FixedBytesField.fromBeamFieldType(fieldType)); @@ -1077,15 +1120,13 @@ private static org.apache.avro.Schema getFieldSchema( } else if (FixedString.IDENTIFIER.equals(identifier) || "CHAR".equals(identifier) || "NCHAR".equals(identifier)) { - baseType = - buildHiveLogicalTypeSchema("char", (int) fieldType.getLogicalType().getArgument()); + baseType = buildHiveLogicalTypeSchema("char", checkNotNull(logicalType.getArgument())); } else if 
(VariableString.IDENTIFIER.equals(identifier) || "NVARCHAR".equals(identifier) || "VARCHAR".equals(identifier) || "LONGNVARCHAR".equals(identifier) || "LONGVARCHAR".equals(identifier)) { - baseType = - buildHiveLogicalTypeSchema("varchar", (int) fieldType.getLogicalType().getArgument()); + baseType = buildHiveLogicalTypeSchema("varchar", checkNotNull(logicalType.getArgument())); } else if (EnumerationType.IDENTIFIER.equals(identifier)) { EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); baseType = @@ -1103,7 +1144,7 @@ private static org.apache.avro.Schema getFieldSchema( baseType = LogicalTypes.timeMillis().addToSchema(org.apache.avro.Schema.create(Type.INT)); } else { throw new RuntimeException( - "Unhandled logical type " + fieldType.getLogicalType().getIdentifier()); + "Unhandled logical type " + checkNotNull(fieldType.getLogicalType()).getIdentifier()); } break; @@ -1111,22 +1152,23 @@ private static org.apache.avro.Schema getFieldSchema( case ITERABLE: baseType = org.apache.avro.Schema.createArray( - getFieldSchema(fieldType.getCollectionElementType(), fieldName, namespace)); + getFieldSchema( + checkNotNull(fieldType.getCollectionElementType()), fieldName, namespace)); break; case MAP: - if (fieldType.getMapKeyType().getTypeName().isStringType()) { + if (checkNotNull(fieldType.getMapKeyType()).getTypeName().isStringType()) { // Avro only supports string keys in maps. 
baseType = org.apache.avro.Schema.createMap( - getFieldSchema(fieldType.getMapValueType(), fieldName, namespace)); + getFieldSchema(checkNotNull(fieldType.getMapValueType()), fieldName, namespace)); } else { throw new IllegalArgumentException("Avro only supports maps with string keys"); } break; case ROW: - baseType = toAvroSchema(fieldType.getRowSchema(), fieldName, namespace); + baseType = toAvroSchema(checkNotNull(fieldType.getRowSchema()), fieldName, namespace); break; default: @@ -1167,7 +1209,9 @@ private static org.apache.avro.Schema getFieldSchema( case DECIMAL: BigDecimal decimal = (BigDecimal) value; LogicalType logicalType = typeWithNullability.type.getLogicalType(); - return new Conversions.DecimalConversion().toBytes(decimal, null, logicalType); + @SuppressWarnings("nullness") + ByteBuffer result = new Conversions.DecimalConversion().toBytes(decimal, null, logicalType); + return result; case DATETIME: if (typeWithNullability.type.getType() == Type.INT) { @@ -1185,7 +1229,7 @@ private static org.apache.avro.Schema getFieldSchema( return ByteBuffer.wrap((byte[]) value); case LOGICAL_TYPE: - String identifier = fieldType.getLogicalType().getIdentifier(); + String identifier = checkNotNull(fieldType.getLogicalType()).getIdentifier(); if (FixedBytes.IDENTIFIER.equals(identifier)) { FixedBytesField fixedBytesField = checkNotNull(FixedBytesField.fromBeamFieldType(fieldType)); @@ -1193,9 +1237,11 @@ private static org.apache.avro.Schema getFieldSchema( if (byteArray.length != fixedBytesField.getSize()) { throw new IllegalArgumentException("Incorrectly sized byte array."); } - return GenericData.get().createFixed(null, (byte[]) value, typeWithNullability.type); + return NullnessCheckerWorkarounds.createFixed( + null, (byte[]) value, typeWithNullability.type); } else if (VariableBytes.IDENTIFIER.equals(identifier)) { - return GenericData.get().createFixed(null, (byte[]) value, typeWithNullability.type); + return NullnessCheckerWorkarounds.createFixed( + null, 
(byte[]) value, typeWithNullability.type); } else if (FixedString.IDENTIFIER.equals(identifier) || "CHAR".equals(identifier) || "NCHAR".equals(identifier)) { @@ -1239,26 +1285,27 @@ private static org.apache.avro.Schema getFieldSchema( case ARRAY: case ITERABLE: Iterable iterable = (Iterable) value; - List translatedArray = Lists.newArrayListWithExpectedSize(Iterables.size(iterable)); + List<@Nullable Object> translatedArray = + Lists.newArrayListWithExpectedSize(Iterables.size(iterable)); for (Object arrayElement : iterable) { translatedArray.add( genericFromBeamField( - fieldType.getCollectionElementType(), + checkNotNull(fieldType.getCollectionElementType()), typeWithNullability.type.getElementType(), arrayElement)); } return translatedArray; case MAP: - Map map = Maps.newHashMap(); + Map map = Maps.newHashMap(); Map valueMap = (Map) value; for (Map.Entry entry : valueMap.entrySet()) { - Utf8 key = new Utf8((String) entry.getKey()); + Utf8 key = new Utf8((String) checkNotNull(entry.getKey())); map.put( key, genericFromBeamField( - fieldType.getMapValueType(), + checkNotNull(fieldType.getMapValueType()), typeWithNullability.type.getValueType(), entry.getValue())); } @@ -1282,8 +1329,8 @@ private static org.apache.avro.Schema getFieldSchema( * @return value converted for {@link Row} */ @SuppressWarnings("unchecked") - public static @Nullable Object convertAvroFieldStrict( - @Nullable Object value, + public static @PolyNull Object convertAvroFieldStrict( + @PolyNull Object value, @Nonnull org.apache.avro.Schema avroSchema, @Nonnull FieldType fieldType) { if (value == null) { @@ -1383,7 +1430,8 @@ private static Object convertBytesStrict(ByteBuffer bb, FieldType fieldType) { private static Object convertFixedStrict(GenericFixed fixed, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "fixed"); - checkArgument(FixedBytes.IDENTIFIER.equals(fieldType.getLogicalType().getIdentifier())); + checkArgument( + 
FixedBytes.IDENTIFIER.equals(checkNotNull(fieldType.getLogicalType()).getIdentifier())); return fixed.bytes().clone(); // clone because GenericFixed is mutable } @@ -1434,7 +1482,10 @@ private static Object convertBooleanStrict(Boolean value, FieldType fieldType) { private static Object convertEnumStrict(Object value, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "enum"); - checkArgument(fieldType.getLogicalType().getIdentifier().equals(EnumerationType.IDENTIFIER)); + checkArgument( + checkNotNull(fieldType.getLogicalType()) + .getIdentifier() + .equals(EnumerationType.IDENTIFIER)); EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); return enumerationType.valueOf(value.toString()); } @@ -1442,7 +1493,8 @@ private static Object convertEnumStrict(Object value, FieldType fieldType) { private static Object convertUnionStrict( Object value, org.apache.avro.Schema unionAvroSchema, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "oneOfType"); - checkArgument(fieldType.getLogicalType().getIdentifier().equals(OneOfType.IDENTIFIER)); + checkArgument( + checkNotNull(fieldType.getLogicalType()).getIdentifier().equals(OneOfType.IDENTIFIER)); OneOfType oneOfType = fieldType.getLogicalType(OneOfType.class); int fieldNumber = GenericData.get().resolveUnion(unionAvroSchema, value); FieldType baseFieldType = oneOfType.getOneOfSchema().getField(fieldNumber).getType(); @@ -1459,7 +1511,7 @@ private static Object convertArrayStrict( FieldType elemFieldType = fieldType.getCollectionElementType(); for (Object value : values) { - ret.add(convertAvroFieldStrict(value, elemAvroSchema, elemFieldType)); + ret.add(convertAvroFieldStrict(value, elemAvroSchema, checkNotNull(elemFieldType))); } return ret; @@ -1470,10 +1522,10 @@ private static Object convertMapStrict( org.apache.avro.Schema valueAvroSchema, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.MAP, "map"); 
- checkNotNull(fieldType.getMapKeyType()); - checkNotNull(fieldType.getMapValueType()); + FieldType mapKeyType = checkNotNull(fieldType.getMapKeyType()); + FieldType mapValueType = checkNotNull(fieldType.getMapValueType()); - if (!fieldType.getMapKeyType().equals(FieldType.STRING)) { + if (!FieldType.STRING.equals(fieldType.getMapKeyType())) { throw new IllegalArgumentException( "Can't convert 'string' map keys to " + fieldType.getMapKeyType()); } @@ -1482,8 +1534,8 @@ private static Object convertMapStrict( for (Map.Entry value : values.entrySet()) { ret.put( - convertStringStrict(value.getKey(), fieldType.getMapKeyType()), - convertAvroFieldStrict(value.getValue(), valueAvroSchema, fieldType.getMapValueType())); + convertStringStrict(value.getKey(), mapKeyType), + convertAvroFieldStrict(value.getValue(), valueAvroSchema, mapValueType)); } return ret; diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java index d159e9de44a8..9fe6162ec936 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java @@ -104,16 +104,14 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) class ProtoByteBuddyUtils { private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); private static final TypeDescriptor BYTE_STRING_TYPE_DESCRIPTOR = @@ -270,7 +268,7 @@ static class ProtoConvertType extends ConvertType { .build(); @Override - public Type convert(TypeDescriptor typeDescriptor) { + public Type convert(TypeDescriptor typeDescriptor) { if (typeDescriptor.equals(BYTE_STRING_TYPE_DESCRIPTOR) || typeDescriptor.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { return byte[].class; @@ -297,7 +295,7 @@ protected ProtoTypeConversionsFactory getFactory() { } @Override - public StackManipulation convert(TypeDescriptor type) { + public StackManipulation convert(TypeDescriptor type) { if (type.equals(BYTE_STRING_TYPE_DESCRIPTOR) || type.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { return new Compound( @@ -372,7 +370,7 @@ protected ProtoTypeConversionsFactory getFactory() { } @Override - public StackManipulation convert(TypeDescriptor type) { + public StackManipulation convert(TypeDescriptor type) { if (type.isSubtypeOf(TypeDescriptor.of(ByteString.class))) { return new Compound( readValue, @@ -459,7 +457,7 @@ public TypeConversion createSetterConversions(StackManipulati // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_GETTERS = + private static final Map>> CACHED_GETTERS = Maps.newConcurrentMap(); /** @@ -467,35 +465,36 @@ public TypeConversion createSetterConversions(StackManipulati * *

The returned list is ordered by the order of fields in the schema. */ - public static List getGetters( - Class clazz, + public static List> getGetters( + Class clazz, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { Multimap methods = ReflectUtils.getMethodsMap(clazz); - return CACHED_GETTERS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), - c -> { - List types = - fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); - return types.stream() - .map( - t -> - createGetter( - t, - typeConversionsFactory, - clazz, - methods, - schema.getField(t.getName()), - fieldValueTypeSupplier)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + ClassWithSchema.create(clazz, schema), + c -> { + List types = + fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); + return types.stream() + .map( + t -> + createGetter( + t, + typeConversionsFactory, + clazz, + methods, + schema.getField(t.getName()), + fieldValueTypeSupplier)) + .collect(Collectors.toList()); + }); } - static FieldValueGetter createOneOfGetter( + static FieldValueGetter<@NonNull ProtoT, OneOfType.Value> createOneOfGetter( FieldValueTypeInformation typeInformation, - TreeMap> getterMethodMap, - Class protoClass, + TreeMap> getterMethodMap, + Class protoClass, OneOfType oneOfType, Method getCaseMethod) { Set indices = getterMethodMap.keySet(); @@ -505,7 +504,7 @@ static FieldValueGetter createOneOfGetter( int[] keys = getterMethodMap.keySet().stream().mapToInt(Integer::intValue).toArray(); - DynamicType.Builder builder = + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface(BYTE_BUDDY, protoClass, OneOfType.Value.class); builder = builder @@ -514,7 +513,8 @@ static FieldValueGetter createOneOfGetter( .method(ElementMatchers.named("get")) .intercept(new OneOfGetterInstruction(contiguous, keys, getCaseMethod)); - List getters = Lists.newArrayList(getterMethodMap.values()); + 
List> getters = + Lists.newArrayList(getterMethodMap.values()); builder = builder // Store a field with the list of individual getters. The get() instruction will pick @@ -556,12 +556,12 @@ static FieldValueGetter createOneOfGetter( FieldValueSetter createOneOfSetter( String name, TreeMap> setterMethodMap, - Class protoBuilderClass) { + Class protoBuilderClass) { Set indices = setterMethodMap.keySet(); boolean contiguous = isContiguous(indices); int[] keys = setterMethodMap.keySet().stream().mapToInt(Integer::intValue).toArray(); - DynamicType.Builder builder = + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, protoBuilderClass, OneOfType.Value.class); builder = @@ -585,7 +585,8 @@ FieldValueSetter createOneOfSetter( .withParameters(List.class) .intercept(new OneOfSetterConstructor()); - List setters = Lists.newArrayList(setterMethodMap.values()); + List> setters = + Lists.newArrayList(setterMethodMap.values()); try { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) @@ -947,10 +948,10 @@ public ByteCodeAppender appender(final Target implementationTarget) { } } - private static FieldValueGetter createGetter( + private static FieldValueGetter<@NonNull ProtoT, ?> createGetter( FieldValueTypeInformation fieldValueTypeInformation, TypeConversionsFactory typeConversionsFactory, - Class clazz, + Class clazz, Multimap methods, Field field, FieldValueTypeSupplier fieldValueTypeSupplier) { @@ -964,21 +965,23 @@ private static FieldValueGetter createGetter( field.getName() + "_case", FieldType.logicalType(oneOfType.getCaseEnumType())); // Create a map of case enum value to getter. This must be sorted, so store in a TreeMap. 
- TreeMap> oneOfGetters = Maps.newTreeMap(); + TreeMap> oneOfGetters = + Maps.newTreeMap(); Map oneOfFieldTypes = fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), oneOfType.getOneOfSchema()).stream() .collect(Collectors.toMap(FieldValueTypeInformation::getName, f -> f)); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { int protoFieldIndex = getFieldNumber(oneOfField); - FieldValueGetter oneOfFieldGetter = + FieldValueGetter<@NonNull ProtoT, ?> oneOfFieldGetter = createGetter( - oneOfFieldTypes.get(oneOfField.getName()), + Verify.verifyNotNull(oneOfFieldTypes.get(oneOfField.getName())), typeConversionsFactory, clazz, methods, oneOfField, fieldValueTypeSupplier); - oneOfGetters.put(protoFieldIndex, oneOfFieldGetter); + oneOfGetters.put( + protoFieldIndex, (FieldValueGetter<@NonNull ProtoT, OneOfType.Value>) oneOfFieldGetter); } return createOneOfGetter( fieldValueTypeInformation, oneOfGetters, clazz, oneOfType, caseMethod); @@ -987,10 +990,11 @@ private static FieldValueGetter createGetter( } } - private static Class getProtoGeneratedBuilder(Class clazz) { + private static @Nullable Class getProtoGeneratedBuilder( + Class clazz) { String builderClassName = clazz.getName() + "$Builder"; try { - return Class.forName(builderClassName); + return (Class) Class.forName(builderClassName); } catch (ClassNotFoundException e) { return null; } @@ -1018,51 +1022,59 @@ static Method getProtoGetter(Multimap methods, String name, Fiel public static @Nullable SchemaUserTypeCreator getBuilderCreator( - Class protoClass, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { - Class builderClass = getProtoGeneratedBuilder(protoClass); + TypeDescriptor protoTypeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier) { + Class builderClass = getProtoGeneratedBuilder(protoTypeDescriptor.getRawType()); if (builderClass == null) { return null; } Multimap methods = ReflectUtils.getMethodsMap(builderClass); List> setters = 
schema.getFields().stream() - .map(f -> getProtoFieldValueSetter(f, methods, builderClass)) + .map(f -> getProtoFieldValueSetter(protoTypeDescriptor, f, methods, builderClass)) .collect(Collectors.toList()); - return createBuilderCreator(protoClass, builderClass, setters, schema); + return createBuilderCreator(protoTypeDescriptor.getRawType(), builderClass, setters, schema); } private static FieldValueSetter getProtoFieldValueSetter( - Field field, Multimap methods, Class builderClass) { + TypeDescriptor typeDescriptor, + Field field, + Multimap methods, + Class builderClass) { if (field.getType().isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = field.getType().getLogicalType(OneOfType.class); TreeMap> oneOfSetters = Maps.newTreeMap(); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { - FieldValueSetter setter = getProtoFieldValueSetter(oneOfField, methods, builderClass); + FieldValueSetter setter = + getProtoFieldValueSetter(typeDescriptor, oneOfField, methods, builderClass); oneOfSetters.put(getFieldNumber(oneOfField), setter); } return createOneOfSetter(field.getName(), oneOfSetters, builderClass); } else { Method method = getProtoSetter(methods, field.getName(), field.getType()); return JavaBeanUtils.createSetter( - FieldValueTypeInformation.forSetter(method, protoSetterPrefix(field.getType())), + FieldValueTypeInformation.forSetter( + typeDescriptor, method, protoSetterPrefix(field.getType())), new ProtoTypeConversionsFactory()); } } static SchemaUserTypeCreator createBuilderCreator( Class protoClass, - Class builderClass, + Class builderClass, List> setters, Schema schema) { try { - DynamicType.Builder builder = - BYTE_BUDDY - .with(new InjectPackageStrategy(builderClass)) - .subclass(Supplier.class) - .method(ElementMatchers.named("get")) - .intercept(new BuilderSupplier(protoClass)); - Supplier supplier = + DynamicType.Builder> builder = + (DynamicType.Builder) + BYTE_BUDDY + .with(new InjectPackageStrategy(builderClass)) + 
.subclass(Supplier.class) + .method(ElementMatchers.named("get")) + .intercept(new BuilderSupplier(protoClass)); + Supplier supplier = builder .visit( new AsmVisitorWrapper.ForDeclaredMethods() diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java index faf3ad407af5..b0bb9071524b 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java @@ -43,12 +43,9 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) public class ProtoMessageSchema extends GetterBasedSchemaProviderV2 { private static final class ProtoClassFieldValueTypeSupplier implements FieldValueTypeSupplier { @@ -72,7 +69,8 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = getProtoGetter(methods, oneOfField.getName(), oneOfField.getType()); oneOfTypes.put( oneOfField.getName(), - FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + FieldValueTypeInformation.forGetter(typeDescriptor, method, i) + .withName(field.getName())); } // Add an entry that encapsulates information about all possible getters. types.add( @@ -82,7 +80,9 @@ public List get(TypeDescriptor typeDescriptor, Sch } else { // This is a simple field. Add the getter. 
Method method = getProtoGetter(methods, field.getName(), field.getType()); - types.add(FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + types.add( + FieldValueTypeInformation.forGetter(typeDescriptor, method, i) + .withName(field.getName())); } } return types; @@ -96,8 +96,8 @@ public List get(TypeDescriptor typeDescriptor, Sch } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return ProtoByteBuddyUtils.getGetters( targetTypeDescriptor.getRawType(), schema, @@ -117,7 +117,7 @@ public SchemaUserTypeCreator schemaTypeCreator( TypeDescriptor targetTypeDescriptor, Schema schema) { SchemaUserTypeCreator creator = ProtoByteBuddyUtils.getBuilderCreator( - targetTypeDescriptor.getRawType(), schema, new ProtoClassFieldValueTypeSupplier()); + targetTypeDescriptor, schema, new ProtoClassFieldValueTypeSupplier()); if (creator == null) { throw new RuntimeException("Cannot create creator for " + targetTypeDescriptor); } @@ -152,7 +152,8 @@ public static SimpleFunction getRowToProtoBytesFn(Class claz private void checkForDynamicType(TypeDescriptor typeDescriptor) { if (typeDescriptor.getRawType().equals(DynamicMessage.class)) { throw new RuntimeException( - "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use ProtoDynamicMessageSchema instead."); + "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use" + + " ProtoDynamicMessageSchema instead."); } } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java index acdfcfc1ad09..e8b05a8a319e 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java +++ 
b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java @@ -19,7 +19,6 @@ import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; -import static org.apache.beam.sdk.io.aws2.schemas.AwsSchemaUtils.getter; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets.difference; @@ -46,6 +45,7 @@ import org.apache.beam.sdk.values.RowWithGetters; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; @@ -73,17 +73,20 @@ public class AwsSchemaProvider extends GetterBasedSchemaProviderV2 { return AwsTypes.schemaFor(sdkFields((Class) type.getRawType())); } - @SuppressWarnings("rawtypes") @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { ConverterFactory fromAws = ConverterFactory.fromAws(); Map> sdkFields = sdkFieldsByName((Class) targetTypeDescriptor.getRawType()); - List getters = new ArrayList<>(schema.getFieldCount()); - for (String field : schema.getFieldNames()) { + List> getters = new ArrayList<>(schema.getFieldCount()); + for (@NonNull String field : schema.getFieldNames()) { SdkField sdkField = checkStateNotNull(sdkFields.get(field), "Unknown field"); - getters.add(getter(field, fromAws.create(sdkField::getValueOrDefault, sdkField))); + getters.add( + AwsSchemaUtils.getter( + field, + (SerializableFunction<@NonNull T, Object>) + 
fromAws.create(sdkField::getValueOrDefault, sdkField))); } return getters; } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java index d36c197d80a4..9e994702fe61 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.awssdk.core.SdkPojo; import software.amazon.awssdk.utils.builder.SdkBuilder; @@ -78,7 +79,7 @@ static SdkBuilderSetter setter(String name, BiConsumer, Object> return new ValueSetter(name, setter); } - static FieldValueGetter getter( + static FieldValueGetter getter( String name, SerializableFunction getter) { return new ValueGetter<>(name, getter); } @@ -107,7 +108,8 @@ public String name() { } } - private static class ValueGetter implements FieldValueGetter { + private static class ValueGetter + implements FieldValueGetter { private final SerializableFunction getter; private final String name; diff --git a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java index 5f4e195f227f..3094ea47d6ad 100644 --- a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java +++ b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java @@ -202,10 +202,10 @@ private Schema.Field beamField(FieldMetaData fieldDescriptor) { 
@SuppressWarnings("rawtypes") @Override - public @NonNull List fieldValueGetters( - @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { + public @NonNull List> fieldValueGetters( + @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { return schemaFieldDescriptors(targetTypeDescriptor.getRawType(), schema).keySet().stream() - .map(FieldExtractor::new) + .>map(FieldExtractor::new) .collect(Collectors.toList()); } @@ -242,10 +242,12 @@ private FieldValueTypeInformation fieldValueTypeInfo(Class type, String field if (factoryMethods.size() > 1) { throw new IllegalStateException("Overloaded factory methods: " + factoryMethods); } - return FieldValueTypeInformation.forSetter(factoryMethods.get(0), ""); + return FieldValueTypeInformation.forSetter( + TypeDescriptor.of(type), factoryMethods.get(0), ""); } else { try { - return FieldValueTypeInformation.forField(type.getDeclaredField(fieldName), 0); + return FieldValueTypeInformation.forField( + TypeDescriptor.of(type), type.getDeclaredField(fieldName), 0); } catch (NoSuchFieldException e) { throw new IllegalArgumentException(e); } @@ -373,7 +375,7 @@ private & TEnum> FieldType beamType(FieldValueMetaDat } } - private static class FieldExtractor> + private static class FieldExtractor implements FieldValueGetter { private final FieldT field; @@ -383,8 +385,9 @@ private FieldExtractor(FieldT field) { @Override public @Nullable Object get(T thrift) { - if (!(thrift instanceof TUnion) || thrift.isSet(field)) { - final Object value = thrift.getFieldValue(field); + TBase t = (TBase) thrift; + if (!(thrift instanceof TUnion) || t.isSet(field)) { + final Object value = t.getFieldValue(field); if (value instanceof Enum) { return ((Enum) value).ordinal(); } else {