 import java.net.URI;
 import java.nio.channels.Channels;
 import java.util.*;
+import java.util.stream.Collectors;

 import scala.Option;
 import scala.collection.JavaConverters;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.comet.parquet.CometParquetReadSupport;
 import org.apache.spark.sql.comet.util.Utils$;
+import org.apache.spark.sql.errors.QueryExecutionErrors;
 import org.apache.spark.sql.execution.datasources.PartitionedFile;
 import org.apache.spark.sql.execution.datasources.parquet.ParquetColumn;
 import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter;
+import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils;
 import org.apache.spark.sql.execution.metric.SQLMetric;
 import org.apache.spark.sql.internal.SQLConf;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.*;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 import org.apache.spark.util.AccumulatorV2;

@@ -235,12 +236,6 @@ public NativeBatchReader(AbstractColumnReader[] columnReaders) {
    */
   public void init() throws Throwable {

-    conf.set("spark.sql.parquet.binaryAsString", "false");
-    conf.set("spark.sql.parquet.int96AsTimestamp", "false");
-    conf.set("spark.sql.caseSensitive", "false");
-    conf.set("spark.sql.parquet.inferTimestampNTZ.enabled", "true");
-    conf.set("spark.sql.legacy.parquet.nanosAsLong", "false");
-
     useDecimal128 =
         conf.getBoolean(
             CometConf.COMET_USE_DECIMAL_128().key(),
@@ -268,9 +263,9 @@ public void init() throws Throwable {

     requestedSchema = footer.getFileMetaData().getSchema();
     fileSchema = requestedSchema;
-    ParquetToSparkSchemaConverter converter = new ParquetToSparkSchemaConverter(conf);

     if (sparkSchema == null) {
+      ParquetToSparkSchemaConverter converter = new ParquetToSparkSchemaConverter(conf);
       sparkSchema = converter.convert(requestedSchema);
     } else {
       requestedSchema =
@@ -283,8 +278,18 @@ public void init() throws Throwable {
                 sparkSchema.size(), requestedSchema.getFieldCount()));
       }
     }
-    this.parquetColumn =
-        converter.convertParquetColumn(requestedSchema, Option.apply(this.sparkSchema));
+
+    boolean caseSensitive =
+        conf.getBoolean(
+            SQLConf.CASE_SENSITIVE().key(),
+            (boolean) SQLConf.CASE_SENSITIVE().defaultValue().get());
+    // Rename Spark schema fields based on their field IDs so that each Spark field name
+    // matches the corresponding parquet field name.
+    if (useFieldId && ParquetUtils.hasFieldIds(sparkSchema)) {
+      sparkSchema =
+          getSparkSchemaByFieldId(sparkSchema, requestedSchema.asGroupType(), caseSensitive);
+    }
+    this.parquetColumn = getParquetColumn(requestedSchema, this.sparkSchema);
     String timeZoneId = conf.get("spark.sql.session.timeZone");
     // Native code always uses "UTC" as the timeZoneId when converting from spark to arrow schema.
@@ -404,10 +409,6 @@ public void init() throws Throwable {
             conf.getInt(
                 CometConf.COMET_BATCH_SIZE().key(),
                 (Integer) CometConf.COMET_BATCH_SIZE().defaultValue().get());
-    boolean caseSensitive =
-        conf.getBoolean(
-            SQLConf.CASE_SENSITIVE().key(),
-            (boolean) SQLConf.CASE_SENSITIVE().defaultValue().get());
     this.handle =
         Native.initRecordBatchReader(
             filePath,
@@ -424,6 +425,187 @@ public void init() throws Throwable {
     isInitialized = true;
   }

+  private ParquetColumn getParquetColumn(MessageType schema, StructType sparkSchema) {
+    // We use a different config from the one that was passed in.
+    // This follows the settings used in Spark's SpecificParquetRecordReaderBase.
+    Configuration config = new Configuration();
+    config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key(), false);
+    config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false);
+    config.setBoolean(SQLConf.CASE_SENSITIVE().key(), false);
+    config.setBoolean(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED().key(), false);
+    config.setBoolean(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG().key(), false);
+    ParquetToSparkSchemaConverter converter = new ParquetToSparkSchemaConverter(config);
+    return converter.convertParquetColumn(schema, Option.apply(sparkSchema));
+  }
+
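+  // Groups the parquet fields of a group type by field ID; fields that carry no ID are skipped.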
+  private Map<Integer, List<Type>> getIdToParquetFieldMap(GroupType type) {
+    return type.getFields().stream()
+        .filter(f -> f.getId() != null)
+        .collect(Collectors.groupingBy(f -> f.getId().intValue()));
+  }
+
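+  // Name-keyed lookup maps, used as a fallback when a Spark field carries no field ID.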
+  private Map<String, List<Type>> getCaseSensitiveParquetFieldMap(GroupType schema) {
+    return schema.getFields().stream().collect(Collectors.toMap(Type::getName, Arrays::asList));
+  }
+
+  private Map<String, List<Type>> getCaseInsensitiveParquetFieldMap(GroupType schema) {
+    return schema.getFields().stream()
+        .collect(Collectors.groupingBy(f -> f.getName().toLowerCase(Locale.ROOT)));
+  }
+
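+  // Matches a Spark field to a parquet field by field ID when the Spark field carries one,
+  // falling back to an (optionally case-insensitive) name lookup otherwise. An ambiguous ID
+  // match is an error.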
+  private Type getMatchingParquetFieldById(
+      StructField f,
+      Map<Integer, List<Type>> idToParquetFieldMap,
+      Map<String, List<Type>> nameToParquetFieldMap,
+      boolean isCaseSensitive) {
+    List<Type> matched = null;
+    int fieldId = 0;
+    if (ParquetUtils.hasFieldId(f)) {
+      fieldId = ParquetUtils.getFieldId(f);
+      matched = idToParquetFieldMap.get(fieldId);
+    } else {
+      String fieldName = isCaseSensitive ? f.name() : f.name().toLowerCase(Locale.ROOT);
+      matched = nameToParquetFieldMap.get(fieldName);
+    }
+
+    if (matched == null || matched.isEmpty()) {
+      return null;
+    }
+    if (matched.size() > 1) {
+      // Need to fail if there is ambiguity, i.e. more than one field is matched.
+      String parquetTypesString =
+          matched.stream().map(Type::getName).collect(Collectors.joining(", ", "[", "]"));
+      throw QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError(
+          fieldId, parquetTypesString);
+    } else {
+      return matched.get(0);
+    }
+  }
+
+  // Derived from CometParquetReadSupport.matchFieldId
+  private String getMatchingNameById(
+      StructField f,
+      Map<Integer, List<Type>> idToParquetFieldMap,
+      Map<String, List<Type>> nameToParquetFieldMap,
+      boolean isCaseSensitive) {
+    Type matched =
+        getMatchingParquetFieldById(f, idToParquetFieldMap, nameToParquetFieldMap, isCaseSensitive);
+
+    // When there is no ID match, we use a fake name to avoid a name match by accident.
+    // We need this name to be unique as well, otherwise there will be type conflicts.
+    if (matched == null) {
+      return CometParquetReadSupport.generateFakeColumnName();
+    } else {
+      return matched.getName();
+    }
+  }
+
+  // Rebuilds the Spark schema against the parquet group type, renaming each Spark field to
+  // its ID-matched parquet field name and recursing into nested types.
+  private StructType getSparkSchemaByFieldId(
+      StructType schema, GroupType parquetSchema, boolean caseSensitive) {
+    StructType newSchema = new StructType();
+    Map<Integer, List<Type>> idToParquetFieldMap = getIdToParquetFieldMap(parquetSchema);
+    Map<String, List<Type>> nameToParquetFieldMap =
+        caseSensitive
+            ? getCaseSensitiveParquetFieldMap(parquetSchema)
+            : getCaseInsensitiveParquetFieldMap(parquetSchema);
+    for (StructField f : schema.fields()) {
+      DataType newDataType;
+      String fieldName = caseSensitive ? f.name() : f.name().toLowerCase(Locale.ROOT);
+      List<Type> parquetFieldList = nameToParquetFieldMap.get(fieldName);
+      if (parquetFieldList == null) {
+        newDataType = f.dataType();
+      } else {
+        Type fieldType = parquetFieldList.get(0);
+        if (f.dataType() instanceof StructType) {
+          newDataType =
+              getSparkSchemaByFieldId(
+                  (StructType) f.dataType(), fieldType.asGroupType(), caseSensitive);
+        } else {
+          newDataType = getSparkTypeByFieldId(f.dataType(), fieldType, caseSensitive);
+        }
+      }
+      String matchedName =
+          getMatchingNameById(f, idToParquetFieldMap, nameToParquetFieldMap, caseSensitive);
+      StructField newField = f.copy(matchedName, newDataType, f.nullable(), f.metadata());
+      newSchema = newSchema.add(newField);
+    }
+    return newSchema;
+  }
+
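+  // Resolves the ID-matched Spark type for a single field, recursing into struct, array, and
+  // map types; primitive types are returned unchanged.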
+  private DataType getSparkTypeByFieldId(
+      DataType dataType, Type parquetType, boolean caseSensitive) {
+    DataType newDataType;
+    if (dataType instanceof StructType) {
+      newDataType =
+          getSparkSchemaByFieldId((StructType) dataType, parquetType.asGroupType(), caseSensitive);
+    } else if (dataType instanceof ArrayType) {
+      newDataType =
+          getSparkArrayTypeByFieldId(
+              (ArrayType) dataType, parquetType.asGroupType(), caseSensitive);
+    } else if (dataType instanceof MapType) {
+      MapType mapType = (MapType) dataType;
+      DataType keyType = mapType.keyType();
+      DataType valueType = mapType.valueType();
+      DataType newKeyType;
+      DataType newValueType;
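+      // Parquet encodes a MAP as a group wrapping a single repeated group that holds the
+      // "key" and "value" fields.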
+      Type parquetMapType = parquetType.asGroupType().getFields().get(0);
+      Type parquetKeyType = parquetMapType.asGroupType().getType("key");
+      Type parquetValueType = parquetMapType.asGroupType().getType("value");
+      if (keyType instanceof StructType) {
+        newKeyType =
+            getSparkSchemaByFieldId(
+                (StructType) keyType, parquetKeyType.asGroupType(), caseSensitive);
+      } else {
+        newKeyType = keyType;
+      }
+      if (valueType instanceof StructType) {
+        newValueType =
+            getSparkSchemaByFieldId(
+                (StructType) valueType, parquetValueType.asGroupType(), caseSensitive);
+      } else {
+        newValueType = valueType;
+      }
+      newDataType = new MapType(newKeyType, newValueType, mapType.valueContainsNull());
+    } else {
+      newDataType = dataType;
+    }
+    return newDataType;
+  }
+
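+  // Resolves the array element type, unwrapping both the standard 3-level LIST encoding and
+  // the legacy 2-level encodings.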
+  private DataType getSparkArrayTypeByFieldId(
+      ArrayType arrayType, GroupType parquetType, boolean caseSensitive) {
+    DataType newDataType;
+    DataType elementType = arrayType.elementType();
+    DataType newElementType;
+    Type parquetList = parquetType.getFields().get(0);
+    Type parquetElementType;
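+    // A repeated field without a LIST annotation is a legacy 2-level list whose repeated
+    // field is the element type itself.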
+    if (parquetList.getLogicalTypeAnnotation() == null
+        && parquetList.isRepetition(Type.Repetition.REPEATED)) {
+      parquetElementType = parquetList;
+    } else {
+      // We expect only non-primitive types here (see clipParquetListType for related logic).
+      GroupType repeatedGroup = parquetList.asGroupType().getType(0).asGroupType();
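+      // The repeated group is itself the element type when it has multiple fields or uses one
+      // of the legacy element names ("array" or "<listName>_tuple").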
+      if (repeatedGroup.getFieldCount() > 1
+          || Objects.equals(repeatedGroup.getName(), "array")
+          || Objects.equals(repeatedGroup.getName(), parquetList.getName() + "_tuple")) {
+        parquetElementType = repeatedGroup;
+      } else {
+        parquetElementType = repeatedGroup.getType(0);
+      }
+    }
+    if (elementType instanceof StructType) {
+      newElementType =
+          getSparkSchemaByFieldId(
+              (StructType) elementType, parquetElementType.asGroupType(), caseSensitive);
+    } else {
+      newElementType = getSparkTypeByFieldId(elementType, parquetElementType, caseSensitive);
+    }
+    newDataType = new ArrayType(newElementType, arrayType.containsNull());
+    return newDataType;
+  }
+
   private void checkParquetType(ParquetColumn column) throws IOException {
     String[] path = JavaConverters.seqAsJavaList(column.path()).toArray(new String[0]);
     if (containsPath(fileSchema, path)) {