Commit 530ba7e

fix: map parquet field_id correctly (native_iceberg_compat) (#1815)

1 parent d33f903 commit 530ba7e
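Why this fix is needed: in Comet's native_iceberg_compat read path, Iceberg tracks columns by numeric field ID rather than by name. After a column rename, the name in the Spark (Iceberg) schema no longer matches the name stored in the Parquet footer; the field ID, surfaced to Spark as "parquet.field.id" metadata on each StructField, is the only reliable link between the two schemas. This commit renames Spark schema fields to their ID-matched Parquet names before the ParquetColumn is built. Below is a minimal sketch (not part of the commit; assumes parquet-mr and Spark on the classpath, all names illustrative) of the two schemas involved:

import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;
import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils;
import org.apache.spark.sql.types.*;

public class FieldIdExample {
  public static void main(String[] args) {
    // Parquet file schema: on disk the column is still called "x", with field_id = 1.
    MessageType fileSchema =
        Types.buildMessage()
            .addField(
                Types.primitive(PrimitiveTypeName.INT32, Type.Repetition.OPTIONAL)
                    .id(1)
                    .named("x"))
            .named("spark_schema");

    // Spark (Iceberg) schema: the column was renamed to "y"; the link back to the
    // file is the "parquet.field.id" metadata entry, not the name.
    Metadata meta = new MetadataBuilder().putLong("parquet.field.id", 1L).build();
    StructType sparkSchema =
        new StructType().add(new StructField("y", DataTypes.IntegerType, true, meta));

    System.out.println(ParquetUtils.hasFieldIds(sparkSchema)); // true
    System.out.println(fileSchema.getFields().get(0).getId()); // 1
  }
}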

File tree: 1 file changed

common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java (198 additions, 16 deletions)
@@ -27,6 +27,7 @@
 import java.net.URI;
 import java.nio.channels.Channels;
 import java.util.*;
+import java.util.stream.Collectors;

 import scala.Option;
 import scala.collection.JavaConverters;
@@ -61,14 +62,14 @@
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.comet.parquet.CometParquetReadSupport;
 import org.apache.spark.sql.comet.util.Utils$;
+import org.apache.spark.sql.errors.QueryExecutionErrors;
 import org.apache.spark.sql.execution.datasources.PartitionedFile;
 import org.apache.spark.sql.execution.datasources.parquet.ParquetColumn;
 import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter;
+import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils;
 import org.apache.spark.sql.execution.metric.SQLMetric;
 import org.apache.spark.sql.internal.SQLConf;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.*;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 import org.apache.spark.util.AccumulatorV2;

@@ -235,12 +236,6 @@ public NativeBatchReader(AbstractColumnReader[] columnReaders) {
   */
  public void init() throws Throwable {

-    conf.set("spark.sql.parquet.binaryAsString", "false");
-    conf.set("spark.sql.parquet.int96AsTimestamp", "false");
-    conf.set("spark.sql.caseSensitive", "false");
-    conf.set("spark.sql.parquet.inferTimestampNTZ.enabled", "true");
-    conf.set("spark.sql.legacy.parquet.nanosAsLong", "false");
-
    useDecimal128 =
        conf.getBoolean(
            CometConf.COMET_USE_DECIMAL_128().key(),
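The deleted conf.set(...) calls above wrote converter defaults into the shared Hadoop Configuration, silently overriding user-visible settings such as spark.sql.caseSensitive for anything else reading that conf. The commit instead confines these defaults to a private Configuration inside the new getParquetColumn helper (added later in this diff). A condensed sketch of the pattern (illustrative only; the class name is hypothetical):

import org.apache.hadoop.conf.Configuration;

public class PrivateConfSketch {
  public static void main(String[] args) {
    // The conf handed to the reader carries the user's settings.
    Configuration shared = new Configuration(false);
    shared.set("spark.sql.caseSensitive", "true");

    // Converter defaults now go into a fresh, private Configuration,
    // mirroring what getParquetColumn(...) does further down in this diff.
    Configuration privateConf = new Configuration(false);
    privateConf.setBoolean("spark.sql.caseSensitive", false);

    // The user's conf is untouched.
    System.out.println(shared.get("spark.sql.caseSensitive")); // prints: true
  }
}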
@@ -268,9 +263,9 @@ public void init() throws Throwable {

    requestedSchema = footer.getFileMetaData().getSchema();
    fileSchema = requestedSchema;
-    ParquetToSparkSchemaConverter converter = new ParquetToSparkSchemaConverter(conf);

    if (sparkSchema == null) {
+      ParquetToSparkSchemaConverter converter = new ParquetToSparkSchemaConverter(conf);
      sparkSchema = converter.convert(requestedSchema);
    } else {
      requestedSchema =
@@ -283,8 +278,18 @@ public void init() throws Throwable {
                sparkSchema.size(), requestedSchema.getFieldCount()));
      }
    }
-    this.parquetColumn =
-        converter.convertParquetColumn(requestedSchema, Option.apply(this.sparkSchema));
+
+    boolean caseSensitive =
+        conf.getBoolean(
+            SQLConf.CASE_SENSITIVE().key(),
+            (boolean) SQLConf.CASE_SENSITIVE().defaultValue().get());
+    // rename spark fields based on field_id so name of spark schema field matches the parquet
+    // field name
+    if (useFieldId && ParquetUtils.hasFieldIds(sparkSchema)) {
+      sparkSchema =
+          getSparkSchemaByFieldId(sparkSchema, requestedSchema.asGroupType(), caseSensitive);
+    }
+    this.parquetColumn = getParquetColumn(requestedSchema, this.sparkSchema);

    String timeZoneId = conf.get("spark.sql.session.timeZone");
    // Native code uses "UTC" always as the timeZoneId when converting from spark to arrow schema.
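The hunk above is the heart of the fix: when field IDs are present, init() rewrites the Spark schema so that each field carries its Parquet-side name before convertParquetColumn runs. The helpers it calls are added further down; the following self-contained sketch (illustrative, parquet-mr Types builder plus Java streams) shows the ID grouping that getIdToParquetFieldMap performs and how a renamed column resolves through it:

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;

public class IdLookupSketch {
  public static void main(String[] args) {
    MessageType fileSchema =
        Types.buildMessage()
            .addField(
                Types.primitive(PrimitiveTypeName.INT64, Type.Repetition.OPTIONAL)
                    .id(1)
                    .named("old_name"))
            .named("spark_schema");

    // Same grouping as getIdToParquetFieldMap: fields without an ID are skipped.
    Map<Integer, List<Type>> byId =
        fileSchema.getFields().stream()
            .filter(f -> f.getId() != null)
            .collect(Collectors.groupingBy(f -> f.getId().intValue()));

    // A Spark field carrying parquet.field.id = 1 resolves to "old_name" on disk,
    // regardless of what the column is called after an Iceberg rename.
    System.out.println(byId.get(1).get(0).getName()); // prints: old_name
  }
}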
@@ -404,10 +409,6 @@ public void init() throws Throwable {
        conf.getInt(
            CometConf.COMET_BATCH_SIZE().key(),
            (Integer) CometConf.COMET_BATCH_SIZE().defaultValue().get());
-    boolean caseSensitive =
-        conf.getBoolean(
-            SQLConf.CASE_SENSITIVE().key(),
-            (boolean) SQLConf.CASE_SENSITIVE().defaultValue().get());
    this.handle =
        Native.initRecordBatchReader(
            filePath,
@@ -424,6 +425,187 @@
    isInitialized = true;
  }

+  private ParquetColumn getParquetColumn(MessageType schema, StructType sparkSchema) {
+    // We use a different config from the config that is passed in.
+    // This follows the setting used in Spark's SpecificParquetRecordReaderBase.
+    Configuration config = new Configuration();
+    config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key(), false);
+    config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false);
+    config.setBoolean(SQLConf.CASE_SENSITIVE().key(), false);
+    config.setBoolean(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED().key(), false);
+    config.setBoolean(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG().key(), false);
+    ParquetToSparkSchemaConverter converter = new ParquetToSparkSchemaConverter(config);
+    return converter.convertParquetColumn(schema, Option.apply(sparkSchema));
+  }
+
+  private Map<Integer, List<Type>> getIdToParquetFieldMap(GroupType type) {
+    return type.getFields().stream()
+        .filter(f -> f.getId() != null)
+        .collect(Collectors.groupingBy(f -> f.getId().intValue()));
+  }
+
+  private Map<String, List<Type>> getCaseSensitiveParquetFieldMap(GroupType schema) {
+    return schema.getFields().stream().collect(Collectors.toMap(Type::getName, Arrays::asList));
+  }
+
+  private Map<String, List<Type>> getCaseInsensitiveParquetFieldMap(GroupType schema) {
+    return schema.getFields().stream()
+        .collect(Collectors.groupingBy(f -> f.getName().toLowerCase(Locale.ROOT)));
+  }
+
+  private Type getMatchingParquetFieldById(
+      StructField f,
+      Map<Integer, List<Type>> idToParquetFieldMap,
+      Map<String, List<Type>> nameToParquetFieldMap,
+      boolean isCaseSensitive) {
+    List<Type> matched = null;
+    int fieldId = 0;
+    if (ParquetUtils.hasFieldId(f)) {
+      fieldId = ParquetUtils.getFieldId(f);
+      matched = idToParquetFieldMap.get(fieldId);
+    } else {
+      String fieldName = isCaseSensitive ? f.name() : f.name().toLowerCase(Locale.ROOT);
+      matched = nameToParquetFieldMap.get(fieldName);
+    }
+
+    if (matched == null || matched.isEmpty()) {
+      return null;
+    }
+    if (matched.size() > 1) {
+      // Need to fail if there is ambiguity, i.e. more than one field is matched
+      String parquetTypesString =
+          matched.stream().map(Type::getName).collect(Collectors.joining(", ", "[", "]"));
+      throw QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError(
+          fieldId, parquetTypesString);
+    } else {
+      return matched.get(0);
+    }
+  }
+
+  // Derived from CometParquetReadSupport.matchFieldId
+  private String getMatchingNameById(
+      StructField f,
+      Map<Integer, List<Type>> idToParquetFieldMap,
+      Map<String, List<Type>> nameToParquetFieldMap,
+      boolean isCaseSensitive) {
+    Type matched =
+        getMatchingParquetFieldById(f, idToParquetFieldMap, nameToParquetFieldMap, isCaseSensitive);
+
+    // When there is no ID match, we use a fake name to avoid a name match by accident.
+    // We need this name to be unique as well, otherwise there will be type conflicts.
+    if (matched == null) {
+      return CometParquetReadSupport.generateFakeColumnName();
+    } else {
+      return matched.getName();
+    }
+  }
+
+  // Clip the Parquet GroupType against the Spark schema, matching fields by field ID.
+  private StructType getSparkSchemaByFieldId(
+      StructType schema, GroupType parquetSchema, boolean caseSensitive) {
+    StructType newSchema = new StructType();
+    Map<Integer, List<Type>> idToParquetFieldMap = getIdToParquetFieldMap(parquetSchema);
+    Map<String, List<Type>> nameToParquetFieldMap =
+        caseSensitive
+            ? getCaseSensitiveParquetFieldMap(parquetSchema)
+            : getCaseInsensitiveParquetFieldMap(parquetSchema);
+    for (StructField f : schema.fields()) {
+      DataType newDataType;
+      String fieldName = caseSensitive ? f.name() : f.name().toLowerCase(Locale.ROOT);
+      List<Type> parquetFieldList = nameToParquetFieldMap.get(fieldName);
+      if (parquetFieldList == null) {
+        newDataType = f.dataType();
+      } else {
+        Type fieldType = parquetFieldList.get(0);
+        if (f.dataType() instanceof StructType) {
+          newDataType =
+              getSparkSchemaByFieldId(
+                  (StructType) f.dataType(), fieldType.asGroupType(), caseSensitive);
+        } else {
+          newDataType = getSparkTypeByFieldId(f.dataType(), fieldType, caseSensitive);
+        }
+      }
+      String matchedName =
+          getMatchingNameById(f, idToParquetFieldMap, nameToParquetFieldMap, caseSensitive);
+      StructField newField = f.copy(matchedName, newDataType, f.nullable(), f.metadata());
+      newSchema = newSchema.add(newField);
+    }
+    return newSchema;
+  }
+
+  private DataType getSparkTypeByFieldId(
+      DataType dataType, Type parquetType, boolean caseSensitive) {
+    DataType newDataType;
+    if (dataType instanceof StructType) {
+      newDataType =
+          getSparkSchemaByFieldId((StructType) dataType, parquetType.asGroupType(), caseSensitive);
+    } else if (dataType instanceof ArrayType) {
+      newDataType =
+          getSparkArrayTypeByFieldId(
+              (ArrayType) dataType, parquetType.asGroupType(), caseSensitive);
+    } else if (dataType instanceof MapType) {
+      MapType mapType = (MapType) dataType;
+      DataType keyType = mapType.keyType();
+      DataType valueType = mapType.valueType();
+      DataType newKeyType;
+      DataType newValueType;
+      Type parquetMapType = parquetType.asGroupType().getFields().get(0);
+      Type parquetKeyType = parquetMapType.asGroupType().getType("key");
+      Type parquetValueType = parquetMapType.asGroupType().getType("value");
+      if (keyType instanceof StructType) {
+        newKeyType =
+            getSparkSchemaByFieldId(
+                (StructType) keyType, parquetKeyType.asGroupType(), caseSensitive);
+      } else {
+        newKeyType = keyType;
+      }
+      if (valueType instanceof StructType) {
+        newValueType =
+            getSparkSchemaByFieldId(
+                (StructType) valueType, parquetValueType.asGroupType(), caseSensitive);
+      } else {
+        newValueType = valueType;
+      }
+      newDataType = new MapType(newKeyType, newValueType, mapType.valueContainsNull());
+    } else {
+      newDataType = dataType;
+    }
+    return newDataType;
+  }
+
+  private DataType getSparkArrayTypeByFieldId(
+      ArrayType arrayType, GroupType parquetType, boolean caseSensitive) {
+    DataType elementType = arrayType.elementType();
+    DataType newElementType;
+    Type parquetList = parquetType.getFields().get(0);
+    Type parquetElementType;
+    if (parquetList.getLogicalTypeAnnotation() == null
+        && parquetList.isRepetition(Type.Repetition.REPEATED)) {
+      parquetElementType = parquetList;
+    } else {
+      // we expect only non-primitive types here (see clipParquetListTypes for related logic)
+      GroupType repeatedGroup = parquetList.asGroupType().getType(0).asGroupType();
+      if (repeatedGroup.getFieldCount() > 1
+          || Objects.equals(repeatedGroup.getName(), "array")
+          || Objects.equals(repeatedGroup.getName(), parquetList.getName() + "_tuple")) {
+        parquetElementType = repeatedGroup;
+      } else {
+        parquetElementType = repeatedGroup.getType(0);
+      }
+    }
+    if (elementType instanceof StructType) {
+      newElementType =
+          getSparkSchemaByFieldId(
+              (StructType) elementType, parquetElementType.asGroupType(), caseSensitive);
+    } else {
+      newElementType = getSparkTypeByFieldId(elementType, parquetElementType, caseSensitive);
+    }
+    return new ArrayType(newElementType, arrayType.containsNull());
+  }
+
  private void checkParquetType(ParquetColumn column) throws IOException {
    String[] path = JavaConverters.seqAsJavaList(column.path()).toArray(new String[0]);
    if (containsPath(fileSchema, path)) {
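One subtlety in the added getMatchingParquetFieldById is the ambiguity check: field-ID lookup must fail loudly when a file carries duplicate IDs, because silently picking either column would return wrong data. A sketch of that failure path (illustrative; the real code raises QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError, and all names here are hypothetical):

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;

public class DuplicateIdSketch {
  public static void main(String[] args) {
    // Two physical columns that (incorrectly) share field_id = 7.
    MessageType fileSchema =
        Types.buildMessage()
            .addField(
                Types.primitive(PrimitiveTypeName.INT32, Type.Repetition.OPTIONAL)
                    .id(7)
                    .named("a"))
            .addField(
                Types.primitive(PrimitiveTypeName.INT32, Type.Repetition.OPTIONAL)
                    .id(7)
                    .named("b"))
            .named("spark_schema");

    Map<Integer, List<Type>> byId =
        fileSchema.getFields().stream()
            .filter(f -> f.getId() != null)
            .collect(Collectors.groupingBy(f -> f.getId().intValue()));

    List<Type> matched = byId.get(7);
    if (matched.size() > 1) {
      // Same formatting as the Collectors.joining(", ", "[", "]") call in the diff.
      String names =
          matched.stream().map(Type::getName).collect(Collectors.joining(", ", "[", "]"));
      throw new IllegalStateException("Found duplicate field(s) " + names + " for field id 7");
    }
  }
}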
