Skip to content

Commit 3065e8d

Browse files
authored
GH-2961: Cycle detection in AvroSchemaConverter to prevent infinite recursion (#3272)
1 parent a1d8412 commit 3065e8d

File tree

2 files changed

+153
-12
lines changed

2 files changed

+153
-12
lines changed

parquet-avro/src/main/java/org/apache/parquet/avro/AvroSchemaConverter.java

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import java.util.Collections;
5252
import java.util.HashMap;
5353
import java.util.HashSet;
54+
import java.util.IdentityHashMap;
5455
import java.util.List;
5556
import java.util.Map;
5657
import java.util.Optional;
@@ -150,16 +151,23 @@ public MessageType convert(Schema avroSchema) {
150151
if (!avroSchema.getType().equals(Schema.Type.RECORD)) {
151152
throw new IllegalArgumentException("Avro schema must be a record.");
152153
}
153-
return new MessageType(avroSchema.getFullName(), convertFields(avroSchema.getFields(), ""));
154+
return new MessageType(
155+
avroSchema.getFullName(),
156+
convertFields(avroSchema.getFields(), "", new IdentityHashMap<Schema, Void>()));
154157
}
155158

156159
private List<Type> convertFields(List<Schema.Field> fields, String schemaPath) {
160+
return convertFields(fields, schemaPath, new IdentityHashMap<Schema, Void>());
161+
}
162+
163+
private List<Type> convertFields(
164+
List<Schema.Field> fields, String schemaPath, IdentityHashMap<Schema, Void> seenSchemas) {
157165
List<Type> types = new ArrayList<Type>();
158166
for (Schema.Field field : fields) {
159167
if (field.schema().getType().equals(Schema.Type.NULL)) {
160168
continue; // Avro nulls are not encoded, unless they are null unions
161169
}
162-
types.add(convertField(field, appendPath(schemaPath, field.name())));
170+
types.add(convertField(field, appendPath(schemaPath, field.name()), seenSchemas));
163171
}
164172
return types;
165173
}
@@ -168,11 +176,37 @@ private Type convertField(String fieldName, Schema schema, String schemaPath) {
168176
return convertField(fieldName, schema, Type.Repetition.REQUIRED, schemaPath);
169177
}
170178

179+
private Type convertField(
180+
String fieldName, Schema schema, String schemaPath, IdentityHashMap<Schema, Void> seenSchemas) {
181+
return convertField(fieldName, schema, Type.Repetition.REQUIRED, schemaPath, seenSchemas);
182+
}
183+
171184
@SuppressWarnings("deprecation")
172185
private Type convertField(String fieldName, Schema schema, Type.Repetition repetition, String schemaPath) {
173-
Types.PrimitiveBuilder<PrimitiveType> builder;
186+
return convertField(fieldName, schema, repetition, schemaPath, new IdentityHashMap<Schema, Void>());
187+
}
188+
189+
@SuppressWarnings("deprecation")
190+
private Type convertField(
191+
String fieldName,
192+
Schema schema,
193+
Type.Repetition repetition,
194+
String schemaPath,
195+
IdentityHashMap<Schema, Void> seenSchemas) {
174196
Schema.Type type = schema.getType();
175197
LogicalType logicalType = schema.getLogicalType();
198+
199+
if (type.equals(Schema.Type.RECORD) || type.equals(Schema.Type.ENUM) || type.equals(Schema.Type.FIXED)) {
200+
// If this schema has already been seen in the current branch, we have a recursion loop
201+
if (seenSchemas.containsKey(schema)) {
202+
throw new UnsupportedOperationException(
203+
"Recursive Avro schemas are not supported by parquet-avro: " + schema.getFullName());
204+
}
205+
seenSchemas = new IdentityHashMap<>(seenSchemas);
206+
seenSchemas.put(schema, null);
207+
}
208+
209+
Types.PrimitiveBuilder<PrimitiveType> builder;
176210
if (type.equals(Schema.Type.BOOLEAN)) {
177211
builder = Types.primitive(BOOLEAN, repetition);
178212
} else if (type.equals(Schema.Type.INT)) {
@@ -195,21 +229,24 @@ private Type convertField(String fieldName, Schema schema, Type.Repetition repet
195229
builder = Types.primitive(BINARY, repetition).as(stringType());
196230
}
197231
} else if (type.equals(Schema.Type.RECORD)) {
198-
return new GroupType(repetition, fieldName, convertFields(schema.getFields(), schemaPath));
232+
return new GroupType(repetition, fieldName, convertFields(schema.getFields(), schemaPath, seenSchemas));
199233
} else if (type.equals(Schema.Type.ENUM)) {
200234
builder = Types.primitive(BINARY, repetition).as(enumType());
201235
} else if (type.equals(Schema.Type.ARRAY)) {
202236
if (writeOldListStructure) {
203237
return ConversionPatterns.listType(
204-
repetition, fieldName, convertField("array", schema.getElementType(), REPEATED, schemaPath));
238+
repetition,
239+
fieldName,
240+
convertField("array", schema.getElementType(), REPEATED, schemaPath, seenSchemas));
205241
} else {
206242
return ConversionPatterns.listOfElements(
207243
repetition,
208244
fieldName,
209-
convertField(AvroWriteSupport.LIST_ELEMENT_NAME, schema.getElementType(), schemaPath));
245+
convertField(
246+
AvroWriteSupport.LIST_ELEMENT_NAME, schema.getElementType(), schemaPath, seenSchemas));
210247
}
211248
} else if (type.equals(Schema.Type.MAP)) {
212-
Type valType = convertField("value", schema.getValueType(), schemaPath);
249+
Type valType = convertField("value", schema.getValueType(), schemaPath, seenSchemas);
213250
// avro map key type is always string
214251
return ConversionPatterns.stringKeyMapType(repetition, fieldName, valType);
215252
} else if (type.equals(Schema.Type.FIXED)) {
@@ -223,7 +260,7 @@ private Type convertField(String fieldName, Schema schema, Type.Repetition repet
223260
builder = Types.primitive(FIXED_LEN_BYTE_ARRAY, repetition).length(schema.getFixedSize());
224261
}
225262
} else if (type.equals(Schema.Type.UNION)) {
226-
return convertUnion(fieldName, schema, repetition, schemaPath);
263+
return convertUnion(fieldName, schema, repetition, schemaPath, seenSchemas);
227264
} else {
228265
throw new UnsupportedOperationException("Cannot convert Avro type " + type);
229266
}
@@ -246,6 +283,15 @@ private Type convertField(String fieldName, Schema schema, Type.Repetition repet
246283
}
247284

248285
private Type convertUnion(String fieldName, Schema schema, Type.Repetition repetition, String schemaPath) {
286+
return convertUnion(fieldName, schema, repetition, schemaPath, new IdentityHashMap<Schema, Void>());
287+
}
288+
289+
private Type convertUnion(
290+
String fieldName,
291+
Schema schema,
292+
Type.Repetition repetition,
293+
String schemaPath,
294+
IdentityHashMap<Schema, Void> seenSchemas) {
249295
List<Schema> nonNullSchemas = new ArrayList<Schema>(schema.getTypes().size());
250296
// Did the union contain a null schema? Required for the edge case where the union contains only a single type.
251297
boolean foundNullSchema = false;
@@ -267,20 +313,31 @@ private Type convertUnion(String fieldName, Schema schema, Type.Repetition repet
267313

268314
case 1:
269315
return foundNullSchema
270-
? convertField(fieldName, nonNullSchemas.get(0), repetition, schemaPath)
271-
: convertUnionToGroupType(fieldName, repetition, nonNullSchemas, schemaPath);
316+
? convertField(fieldName, nonNullSchemas.get(0), repetition, schemaPath, seenSchemas)
317+
: convertUnionToGroupType(fieldName, repetition, nonNullSchemas, schemaPath, seenSchemas);
272318

273319
default: // complex union type
274-
return convertUnionToGroupType(fieldName, repetition, nonNullSchemas, schemaPath);
320+
return convertUnionToGroupType(fieldName, repetition, nonNullSchemas, schemaPath, seenSchemas);
275321
}
276322
}
277323

278324
private Type convertUnionToGroupType(
279325
String fieldName, Type.Repetition repetition, List<Schema> nonNullSchemas, String schemaPath) {
326+
return convertUnionToGroupType(
327+
fieldName, repetition, nonNullSchemas, schemaPath, new IdentityHashMap<Schema, Void>());
328+
}
329+
330+
private Type convertUnionToGroupType(
331+
String fieldName,
332+
Type.Repetition repetition,
333+
List<Schema> nonNullSchemas,
334+
String schemaPath,
335+
IdentityHashMap<Schema, Void> seenSchemas) {
280336
List<Type> unionTypes = new ArrayList<Type>(nonNullSchemas.size());
281337
int index = 0;
282338
for (Schema childSchema : nonNullSchemas) {
283-
unionTypes.add(convertField("member" + index++, childSchema, Type.Repetition.OPTIONAL, schemaPath));
339+
unionTypes.add(
340+
convertField("member" + index++, childSchema, Type.Repetition.OPTIONAL, schemaPath, seenSchemas));
284341
}
285342
return new GroupType(repetition, fieldName, unionTypes);
286343
}
@@ -289,6 +346,10 @@ private Type convertField(Schema.Field field, String schemaPath) {
289346
return convertField(field.name(), field.schema(), schemaPath);
290347
}
291348

349+
private Type convertField(Schema.Field field, String schemaPath, IdentityHashMap<Schema, Void> seenSchemas) {
350+
return convertField(field.name(), field.schema(), schemaPath, seenSchemas);
351+
}
352+
292353
public Schema convert(MessageType parquetSchema) {
293354
return convertFields(parquetSchema.getName(), parquetSchema.getFields(), new HashMap<>());
294355
}

parquet-avro/src/test/java/org/apache/parquet/avro/TestAvroSchemaConverter.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,86 @@ public void testAvroFixed12AsParquetInt96Type() throws Exception {
965965
() -> new AvroSchemaConverter(conf).convert(schema));
966966
}
967967

968+
@Test
969+
public void testRecursiveSchemaThrowsException() {
970+
String recursiveSchemaJson = "{"
971+
+ "\"type\": \"record\", \"name\": \"Node\", \"fields\": ["
972+
+ " {\"name\": \"value\", \"type\": \"int\"},"
973+
+ " {\"name\": \"children\", \"type\": ["
974+
+ " \"null\", {"
975+
+ " \"type\": \"array\", \"items\": [\"null\", \"Node\"]"
976+
+ " }"
977+
+ " ], \"default\": null}"
978+
+ "]}";
979+
980+
Schema recursiveSchema = new Schema.Parser().parse(recursiveSchemaJson);
981+
982+
assertThrows(
983+
"Recursive Avro schema should throw UnsupportedOperationException for cycles",
984+
UnsupportedOperationException.class,
985+
() -> new AvroSchemaConverter().convert(recursiveSchema));
986+
}
987+
988+
@Test
989+
public void testRecursiveSchemaFromGitHubIssue() {
990+
String issueSchemaJson = "{"
991+
+ "\"type\": \"record\", \"name\": \"ObjXX\", \"fields\": ["
992+
+ " {\"name\": \"id\", \"type\": [\"null\", \"long\"], \"default\": null},"
993+
+ " {\"name\": \"struct_add_list\", \"type\": [\"null\", {"
994+
+ " \"type\": \"array\", \"items\": [\"null\", {"
995+
+ " \"type\": \"record\", \"name\": \"ObjStructAdd\", \"fields\": ["
996+
+ " {\"name\": \"name\", \"type\": [\"null\", \"string\"], \"default\": null},"
997+
+ " {\"name\": \"fld_list\", \"type\": [\"null\", {"
998+
+ " \"type\": \"array\", \"items\": [\"null\", {"
999+
+ " \"type\": \"record\", \"name\": \"ObjStructAddFld\", \"fields\": ["
1000+
+ " {\"name\": \"name\", \"type\": [\"null\", \"string\"], \"default\": null},"
1001+
+ " {\"name\": \"ref_val\", \"type\": [\"null\", \"ObjStructAdd\"], \"default\": null}"
1002+
+ " ]"
1003+
+ " }]"
1004+
+ " }], \"default\": null}"
1005+
+ " ]"
1006+
+ " }]"
1007+
+ " }], \"default\": null},"
1008+
+ " {\"name\": \"kafka_timestamp\", \"type\": {\"type\": \"long\", \"logicalType\": \"timestamp-millis\"}}"
1009+
+ "]}";
1010+
1011+
Schema issueSchema = new Schema.Parser().parse(issueSchemaJson);
1012+
1013+
assertThrows(
1014+
"Schema should throw UnsupportedOperationException for cycles",
1015+
UnsupportedOperationException.class,
1016+
() -> new AvroSchemaConverter().convert(issueSchema));
1017+
}
1018+
1019+
@Test
1020+
public void testRecursiveSchemaErrorMessage() {
1021+
String recursiveSchemaJson = "{"
1022+
+ "\"type\": \"record\", \"name\": \"TestRecord\", \"fields\": ["
1023+
+ " {\"name\": \"self\", \"type\": [\"null\", \"TestRecord\"], \"default\": null}"
1024+
+ "]}";
1025+
1026+
Schema recursiveSchema = new Schema.Parser().parse(recursiveSchemaJson);
1027+
1028+
// With our cycle detection fix, this should throw UnsupportedOperationException
1029+
assertThrows(
1030+
"Recursive schema should throw UnsupportedOperationException with clear error message",
1031+
UnsupportedOperationException.class,
1032+
() -> new AvroSchemaConverter().convert(recursiveSchema));
1033+
}
1034+
1035+
@Test
1036+
public void testDeeplyNestedNonRecursiveSchema() {
1037+
Schema level3 = record("Level3", field("value", primitive(STRING)));
1038+
Schema level2 = record("Level2", field("level3", level3));
1039+
Schema level1 = record("Level1", field("level2", level2));
1040+
Schema rootSchema = record("Root", field("level1", level1));
1041+
1042+
AvroSchemaConverter converter = new AvroSchemaConverter();
1043+
MessageType result = converter.convert(rootSchema);
1044+
Assert.assertNotNull("Non-recursive deep schema should convert successfully", result);
1045+
Assert.assertEquals("Root schema name should be preserved", "Root", result.getName());
1046+
}
1047+
9681048
public static Schema optional(Schema original) {
9691049
return Schema.createUnion(Lists.newArrayList(Schema.create(Schema.Type.NULL), original));
9701050
}

0 commit comments

Comments
 (0)