Skip to content

Commit 326ad8e

Browse files
committed
DynamoDB Enhanced Client Polymorphic Types Support
1 parent 7f787d5 commit 326ad8e

File tree

2 files changed

+578
-257
lines changed

2 files changed

+578
-257
lines changed

services-custom/dynamodb-enhanced/src/main/java/software/amazon/awssdk/enhanced/dynamodb/mapper/StaticPolymorphicTableSchema.java

Lines changed: 129 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -35,100 +35,87 @@
3535

3636
@SdkPublicApi
3737
public final class StaticPolymorphicTableSchema<T> implements TableSchema<T> {
38+
3839
private final TableSchema<T> rootTableSchema;
3940
private final String discriminatorAttributeName;
40-
private final Map<String, StaticSubtype<? extends T>> subtypeByName;
41-
private final List<StaticSubtype<? extends T>> subtypes;
42-
43-
private StaticPolymorphicTableSchema(Builder<T> builder) {
44-
this.rootTableSchema = Validate.paramNotNull(builder.rootTableSchema, "rootTableSchema");
45-
this.discriminatorAttributeName = Validate.notEmpty(builder.discriminatorAttributeName, "discriminatorAttributeName");
46-
Validate.notEmpty(builder.staticSubtypes, "A polymorphic TableSchema must have at least one subtype");
47-
48-
Map<String, StaticSubtype<? extends T>> map = new LinkedHashMap<>();
49-
for (StaticSubtype<? extends T> subtype : builder.staticSubtypes) {
50-
map.compute(subtype.name(), (name, existing) -> {
51-
if (existing != null) {
52-
throw new IllegalArgumentException("Duplicate subtype names are not permitted. [name = \"" + name + "\"]");
53-
}
54-
return subtype;
55-
});
56-
}
57-
58-
this.subtypeByName = Collections.unmodifiableMap(map);
59-
this.subtypes = Collections.unmodifiableList(new ArrayList<>(builder.staticSubtypes));
41+
private final Map<String, StaticSubtype<? extends T>> subtypeByName; // discriminator -> subtype
42+
private final List<StaticSubtype<? extends T>> subtypes; // ordered most-specific -> least-specific
43+
private final boolean allowMissingDiscriminatorFallbackToRoot;
44+
45+
private StaticPolymorphicTableSchema(TableSchema<T> rootTableSchema,
46+
String discriminatorAttributeName,
47+
Map<String, StaticSubtype<? extends T>> subtypeByName,
48+
List<StaticSubtype<? extends T>> subtypes,
49+
boolean allowMissingDiscriminatorFallbackToRoot) {
50+
this.rootTableSchema = rootTableSchema;
51+
this.discriminatorAttributeName = discriminatorAttributeName;
52+
this.subtypeByName = subtypeByName;
53+
this.subtypes = subtypes;
54+
this.allowMissingDiscriminatorFallbackToRoot = allowMissingDiscriminatorFallbackToRoot;
6055
}
6156

62-
public static <T> Builder<T> builder(Class<T> itemClass) {
57+
public static <U> Builder<U> builder(Class<U> itemClass) {
6358
return new Builder<>(itemClass);
6459
}
6560

61+
// Serialization
6662
@Override
6763
public Map<String, AttributeValue> itemToMap(T item, boolean ignoreNulls) {
68-
StaticSubtype<T> subtype = (StaticSubtype<T>) resolveByInstance(item);
69-
T castItem = subtype.tableSchema()
70-
.itemType()
71-
.rawClass()
72-
.cast(item);
64+
StaticSubtype<T> subtype = cast(resolveByInstance(item));
65+
T castItem = subtype.tableSchema().itemType().rawClass().cast(item);
7366

74-
// copy into a mutable map
7567
Map<String, AttributeValue> result = new HashMap<>(subtype.tableSchema().itemToMap(castItem, ignoreNulls));
7668

77-
// inject discriminator
7869
result.put(discriminatorAttributeName, AttributeValue.builder().s(subtype.name()).build());
7970
return result;
8071
}
8172

8273
@Override
8374
public Map<String, AttributeValue> itemToMap(T item, Collection<String> attributes) {
84-
StaticSubtype<T> subtype = (StaticSubtype<T>) resolveByInstance(item);
85-
T castItem = subtype.tableSchema()
86-
.itemType()
87-
.rawClass()
88-
.cast(item);
75+
StaticSubtype<T> subtype = cast(resolveByInstance(item));
76+
T castItem = subtype.tableSchema().itemType().rawClass().cast(item);
8977

90-
// Copy into a mutable map so we can inject the discriminator
91-
Map<String, AttributeValue> result = new HashMap<>(subtype.tableSchema().itemToMap(castItem, attributes));
78+
Map<String, AttributeValue> result =
79+
new HashMap<>(subtype.tableSchema().itemToMap(castItem, attributes));
9280

93-
// Only inject if they explicitly requested the discriminator field
9481
if (attributes.contains(discriminatorAttributeName)) {
9582
result.put(discriminatorAttributeName, AttributeValue.builder().s(subtype.name()).build());
9683
}
97-
9884
return result;
9985
}
10086

87+
// Deserialization
10188
@Override
10289
public T mapToItem(Map<String, AttributeValue> attributeMap) {
10390
String discriminator = Optional.ofNullable(attributeMap.get(discriminatorAttributeName))
104-
.map(AttributeValue::s)
105-
.orElseThrow(() -> new IllegalArgumentException(
106-
"Missing discriminator '" + discriminatorAttributeName + "' in item map"));
91+
.map(AttributeValue::s)
92+
.orElse(null);
93+
94+
if (discriminator == null) {
95+
if (allowMissingDiscriminatorFallbackToRoot) {
96+
// Legacy record (no discriminator) → use root schema
97+
return rootTableSchema.mapToItem(attributeMap);
98+
}
99+
throw new IllegalArgumentException("Missing discriminator '" + discriminatorAttributeName + "' in item map");
100+
}
107101

108102
StaticSubtype<? extends T> subtype = subtypeByName.get(discriminator);
109103
if (subtype == null) {
110104
throw new IllegalArgumentException("Unknown discriminator '" + discriminator + "'");
111105
}
106+
112107
return returnWithSubtypeCast(subtype, ts -> ts.mapToItem(attributeMap));
113108
}
114109

115110
@Override
116111
public AttributeValue attributeValue(T item, String attributeName) {
117-
// If we want to get the discriminator itself, just return it
118112
if (discriminatorAttributeName.equals(attributeName)) {
119-
StaticSubtype<? extends T> raw = resolveByInstance(item);
120-
return AttributeValue.builder().s(raw.name()).build();
113+
StaticSubtype<? extends T> s = resolveByInstance(item);
114+
return AttributeValue.builder().s(s.name()).build();
121115
}
122116

123-
// Otherwise delegate to the concrete subtype
124-
StaticSubtype<T> subtype = (StaticSubtype<T>) resolveByInstance(item);
125-
126-
// Cast the item into the subtype's class
127-
T castItem = subtype.tableSchema()
128-
.itemType()
129-
.rawClass()
130-
.cast(item);
131-
117+
StaticSubtype<T> subtype = cast(resolveByInstance(item));
118+
T castItem = subtype.tableSchema().itemType().rawClass().cast(item);
132119
return subtype.tableSchema().attributeValue(castItem, attributeName);
133120
}
134121

@@ -145,9 +132,15 @@ public TableSchema<? extends T> subtypeTableSchema(T itemContext) {
145132
@Override
146133
public TableSchema<? extends T> subtypeTableSchema(Map<String, AttributeValue> itemContext) {
147134
String discriminator = Optional.ofNullable(itemContext.get(discriminatorAttributeName))
148-
.map(AttributeValue::s)
149-
.orElseThrow(() -> new IllegalArgumentException(
150-
"Missing discriminator '" + discriminatorAttributeName + "' in item map"));
135+
.map(AttributeValue::s)
136+
.orElse(null);
137+
138+
if (discriminator == null) {
139+
if (allowMissingDiscriminatorFallbackToRoot) {
140+
return rootTableSchema;
141+
}
142+
throw new IllegalArgumentException("Missing discriminator '" + discriminatorAttributeName + "' in item map");
143+
}
151144

152145
StaticSubtype<? extends T> subtype = subtypeByName.get(discriminator);
153146
if (subtype == null) {
@@ -182,7 +175,6 @@ private StaticSubtype<? extends T> resolveByInstance(T item) {
182175
return s;
183176
}
184177
}
185-
186178
throw new IllegalArgumentException("Cannot serialize item of type " + item.getClass().getName());
187179
}
188180

@@ -191,43 +183,113 @@ private static <T, S extends T> S returnWithSubtypeCast(StaticSubtype<S> subtype
191183
return subtype.tableSchema().itemType().rawClass().cast(r);
192184
}
193185

186+
@SuppressWarnings("unchecked")
187+
private static <T> StaticSubtype<T> cast(StaticSubtype<? extends T> s) {
188+
return (StaticSubtype<T>) s;
189+
}
190+
194191
public static final class Builder<T> {
195192
private TableSchema<T> rootTableSchema;
196193
private String discriminatorAttributeName;
197-
private List<StaticSubtype<? extends T>> staticSubtypes;
194+
private final List<StaticSubtype<? extends T>> staticSubtypes = new ArrayList<>();
195+
private boolean allowMissingDiscriminatorFallbackToRoot = false;
198196

199-
private Builder(Class<T> itemClass) {
197+
private Builder(Class<T> ignored) {
200198
}
201199

202200
/**
203-
* The root (monomorphic) schema for the supertype.
201+
* Root (non-polymorphic) schema for the supertype.
204202
*/
205203
public Builder<T> rootTableSchema(TableSchema<T> root) {
206204
this.rootTableSchema = root;
207205
return this;
208206
}
209207

210208
/**
211-
* Optional: override the attribute name used for the discriminator. Defaults to `"type"`.
209+
* Discriminator attribute name (defaults to "type").
212210
*/
213211
public Builder<T> discriminatorAttributeName(String name) {
214212
this.discriminatorAttributeName = Validate.notEmpty(name, "discriminatorAttributeName");
215213
return this;
216214
}
217215

218216
/**
219-
* Register one or more (discriminatorValue → subtypeSchema) pairs.
217+
* Register one or more subtypes. Order is not required; we will sort most-specific first.
220218
*/
221-
public Builder<T> addStaticSubtype(StaticSubtype<? extends T>... subs) {
222-
if (this.staticSubtypes == null) {
223-
this.staticSubtypes = new ArrayList<>();
224-
}
219+
@SafeVarargs
220+
public final Builder<T> addStaticSubtype(StaticSubtype<? extends T>... subs) {
225221
Collections.addAll(this.staticSubtypes, subs);
226222
return this;
227223
}
228224

225+
/**
226+
* If true, legacy items without a discriminator are deserialized using the root schema. Defaults to false (strict mode).
227+
*/
228+
public Builder<T> allowMissingDiscriminatorFallbackToRoot(boolean allow) {
229+
this.allowMissingDiscriminatorFallbackToRoot = allow;
230+
return this;
231+
}
232+
229233
public StaticPolymorphicTableSchema<T> build() {
230-
return new StaticPolymorphicTableSchema<>(this);
234+
// Validate required fields
235+
Validate.paramNotNull(rootTableSchema, "rootTableSchema");
236+
Validate.notEmpty(discriminatorAttributeName, "discriminatorAttributeName");
237+
Validate.notEmpty(staticSubtypes, "A polymorphic TableSchema must have at least one subtype");
238+
239+
// Each subtype must be assignable to root
240+
Class<?> root = rootTableSchema.itemType().rawClass();
241+
for (StaticSubtype<? extends T> s : staticSubtypes) {
242+
Class<?> sub = s.tableSchema().itemType().rawClass();
243+
if (!root.isAssignableFrom(sub)) {
244+
throw new IllegalArgumentException(
245+
"Subtype " + sub.getSimpleName() + " is not assignable to " + root.getSimpleName());
246+
}
247+
}
248+
249+
// Build discriminator map with uniqueness check
250+
Map<String, StaticSubtype<? extends T>> byName = new LinkedHashMap<>();
251+
for (StaticSubtype<? extends T> s : staticSubtypes) {
252+
String key = s.name();
253+
if (byName.putIfAbsent(key, s) != null) {
254+
throw new IllegalArgumentException("Duplicate subtype discriminator: " + key);
255+
}
256+
}
257+
258+
// Sort subtypes: deeper (more specific) before shallower
259+
List<StaticSubtype<? extends T>> ordered = new ArrayList<>(staticSubtypes);
260+
sortSubtypesMostSpecificFirst(ordered, root);
261+
262+
return new StaticPolymorphicTableSchema<>(
263+
rootTableSchema,
264+
discriminatorAttributeName,
265+
Collections.unmodifiableMap(byName),
266+
Collections.unmodifiableList(ordered),
267+
allowMissingDiscriminatorFallbackToRoot
268+
);
269+
}
270+
271+
/**
272+
* Orders subtypes so that deeper subclasses (more specific) are checked first by resolveByInstance.
273+
*/
274+
private static <T> void sortSubtypesMostSpecificFirst(List<StaticSubtype<? extends T>> subs, Class<?> root) {
275+
subs.sort((first, second) -> Integer.compare(
276+
inheritanceDepthFromRoot(second.tableSchema().itemType().rawClass(), root),
277+
inheritanceDepthFromRoot(first.tableSchema().itemType().rawClass(), root)
278+
));
279+
}
280+
281+
/**
282+
* Counts how many superclass steps it takes to reach the given root. Example: if Manager extends Employee extends Person
283+
* (root), then Manager → depth 2, Employee → depth 1, Person → depth 0.
284+
*/
285+
private static int inheritanceDepthFromRoot(Class<?> type, Class<?> root) {
286+
int depth = 0;
287+
Class<?> current = type;
288+
while (current != null && !current.equals(root)) {
289+
current = current.getSuperclass();
290+
depth++;
291+
}
292+
return depth;
231293
}
232294
}
233-
}
295+
}

0 commit comments

Comments
 (0)