/*
 * Copyright (c) 2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.nvidia.spark.rapids

import com.nvidia.spark.rapids.sequencefile.GpuSequenceFileMultiFilePartitionReaderFactory
import com.nvidia.spark.rapids.sequencefile.GpuSequenceFilePartitionReaderFactory

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.datasources.{FileFormat, PartitionedFile}
import org.apache.spark.sql.rapids.GpuFileSourceScanExec
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

/**
 * A FileFormat that reads Hadoop SequenceFiles and returns the raw key/value bytes of each
 * record as Spark SQL BinaryType columns.
 *
 * The scan is GPU-enabled in the sense that it produces GPU-backed ColumnarBatch output; the
 * file I/O and byte parsing themselves run on the CPU. See the usage sketch at the end of
 * this file.
 */
class GpuReadSequenceFileBinaryFormat extends FileFormat with GpuReadFileFormatWithMetrics {

  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] =
    Some(SequenceFileBinaryFileFormat.dataSchema)
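
  // SequenceFiles embed sync markers, so a reader can start at an arbitrary split offset and
  // resynchronize to the next record boundary, which makes the format safely splittable.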
  override def isSplitable(
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = true

  override def buildReaderWithPartitionValuesAndMetrics(
      sparkSession: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration,
      metrics: Map[String, GpuMetric]): PartitionedFile => Iterator[InternalRow] = {
    val sqlConf = sparkSession.sessionState.conf
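    // Broadcast the Hadoop configuration once per query so that each task does not have to
    // serialize and ship its own copy.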
    val broadcastedHadoopConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
    val rapidsConf = new RapidsConf(sqlConf)
    val factory = GpuSequenceFilePartitionReaderFactory(
      sqlConf,
      broadcastedHadoopConf,
      requiredSchema,
      partitionSchema,
      rapidsConf,
      metrics,
      options)
    PartitionReaderIterator.buildReader(factory)
  }

  // Always use the multi-file reader path (the per-file path is never enabled); the combined
  // reader performs better when scanning many small files.
  override def isPerFileReadEnabled(conf: RapidsConf): Boolean = false

  override def createMultiFileReaderFactory(
      broadcastedConf: Broadcast[SerializableConfiguration],
      pushedFilters: Array[Filter],
      fileScan: GpuFileSourceScanExec): PartitionReaderFactory = {
    GpuSequenceFileMultiFilePartitionReaderFactory(
      fileScan.conf,
      broadcastedConf,
      fileScan.requiredSchema,
      fileScan.readPartitionSchema,
      fileScan.rapidsConf,
      fileScan.allMetrics,
      fileScan.queryUsesInputFile)
  }
}

object GpuReadSequenceFileBinaryFormat {
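  /**
   * Checks whether the CPU scan can be replaced by this GPU format: any requested "key" or
   * "value" column must be BinaryType, otherwise the plan is tagged as unable to run on
   * the GPU.
   */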
  def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
    val fsse = meta.wrapped
    val required = fsse.requiredSchema
    // Only support reading BinaryType columns named "key" and/or "value".
    required.fields.foreach { f =>
      val isKey = f.name.equalsIgnoreCase(SequenceFileBinaryFileFormat.KEY_FIELD)
      val isValue = f.name.equalsIgnoreCase(SequenceFileBinaryFileFormat.VALUE_FIELD)
      if ((isKey || isValue) && f.dataType != org.apache.spark.sql.types.BinaryType) {
        meta.willNotWorkOnGpu(
          s"SequenceFileBinary only supports BinaryType for " +
            s"'${SequenceFileBinaryFileFormat.KEY_FIELD}' and " +
            s"'${SequenceFileBinaryFileFormat.VALUE_FIELD}' columns, but saw " +
            s"${f.name}: ${f.dataType.catalogString}")
      }
    }
  }
}
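
// A minimal usage sketch, not part of the format implementation above. It assumes the RAPIDS
// Accelerator jar is on the classpath and that "sequencefilebinary" is the data source short
// name registered for SequenceFileBinaryFileFormat; the short name, the input path, and this
// example object are illustrative assumptions, not part of this file's API.
object GpuReadSequenceFileBinaryFormatExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("sequencefile-binary-read")
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      .getOrCreate()

    // Each row carries the raw bytes of one SequenceFile record:
    // key: BinaryType, value: BinaryType.
    val df = spark.read
      .format("sequencefilebinary")
      .load("/path/to/sequencefiles")
    df.printSchema()
    df.selectExpr("length(key)", "length(value)").show(5)

    spark.stop()
  }
}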