
Commit 4eb694c

gengliangwang authored and cloud-fan committed
[SPARK-27443][SQL] Support UDF input_file_name in file source V2
## What changes were proposed in this pull request?

Currently, if we select the UDF `input_file_name` as a column in file source V2, the results are empty. We should support it in file source V2.

## How was this patch tested?

Unit test.

Closes apache#24347 from gengliangwang/input_file_name.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
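For context, the behavior being fixed is observable straight from the DataFrame API. A minimal sketch, assuming an existing `spark` session and an illustrative `/tmp/t` path:

import org.apache.spark.sql.functions.input_file_name

// Write a few ORC files, then ask each row which file it came from.
spark.range(10).write.orc("/tmp/t")
spark.read.orc("/tmp/t")
  .select(input_file_name())
  .show(truncate = false)
// Before this patch, the V2 file source returned empty strings here;
// with it, each row carries the path of the file it was read from.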
1 parent bbbe54a commit 4eb694c

File tree: 3 files changed (+30, -7 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReader.scala

Lines changed: 14 additions & 4 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import java.io.{FileNotFoundException, IOException}
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.InputFileBlockHolder
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.reader.PartitionReader
 
@@ -35,8 +36,7 @@ class FilePartitionReader[T](readers: Iterator[PartitionedFileReader[T]])
     if (readers.hasNext) {
       if (ignoreMissingFiles || ignoreCorruptFiles) {
         try {
-          currentReader = readers.next()
-          logInfo(s"Reading file $currentReader")
+          currentReader = getNextReader()
         } catch {
           case e: FileNotFoundException if ignoreMissingFiles =>
             logWarning(s"Skipped missing file: $currentReader", e)
@@ -48,11 +48,11 @@ class FilePartitionReader[T](readers: Iterator[PartitionedFileReader[T]])
             logWarning(
               s"Skipped the rest of the content in the corrupted file: $currentReader", e)
             currentReader = null
+            InputFileBlockHolder.unset()
             return false
         }
       } else {
-        currentReader = readers.next()
-        logInfo(s"Reading file $currentReader")
+        currentReader = getNextReader()
       }
     } else {
       return false
@@ -84,5 +84,15 @@ class FilePartitionReader[T](readers: Iterator[PartitionedFileReader[T]])
     if (currentReader != null) {
       currentReader.close()
     }
+    InputFileBlockHolder.unset()
+  }
+
+  private def getNextReader(): PartitionedFileReader[T] = {
+    val reader = readers.next()
+    logInfo(s"Reading file $reader")
+    // Sets InputFileBlockHolder for the file block's information
+    val file = reader.file
+    InputFileBlockHolder.set(file.filePath, file.start, file.length)
+    reader
   }
 }
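The core of the fix is the thread-local `org.apache.spark.rdd.InputFileBlockHolder`, which the `input_file_name` expression consults at evaluation time: the reader now publishes each file's path and block range before emitting rows, and clears the holder on corrupt-file bailout and on close. A simplified sketch of that lifecycle, not the actual Spark source, using a hypothetical `FileSlice` stand-in for the `PartitionedFile` fields the patch touches:

import org.apache.spark.rdd.InputFileBlockHolder

// Hypothetical stand-in for the PartitionedFile fields used by the patch.
case class FileSlice(filePath: String, start: Long, length: Long)

def readAll(files: Iterator[FileSlice]): Unit = {
  try {
    files.foreach { file =>
      // Publish the current file so input_file_name() sees it on this thread.
      InputFileBlockHolder.set(file.filePath, file.start, file.length)
      // ... emit the rows of this file to the query ...
    }
  } finally {
    // Always clear the holder so a stale path cannot leak into later reads.
    InputFileBlockHolder.unset()
  }
}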

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@ abstract class FilePartitionReaderFactory extends PartitionReaderFactory {
     assert(partition.isInstanceOf[FilePartition])
     val filePartition = partition.asInstanceOf[FilePartition]
     val iter = filePartition.files.toIterator.map { file =>
-      new PartitionedFileReader(file, buildReader(file))
+      PartitionedFileReader(file, buildReader(file))
     }
     new FilePartitionReader[InternalRow](iter)
   }
@@ -36,7 +36,7 @@ abstract class FilePartitionReaderFactory extends PartitionReaderFactory {
     assert(partition.isInstanceOf[FilePartition])
     val filePartition = partition.asInstanceOf[FilePartition]
     val iter = filePartition.files.toIterator.map { file =>
-      new PartitionedFileReader(file, buildColumnarReader(file))
+      PartitionedFileReader(file, buildColumnarReader(file))
     }
     new FilePartitionReader[ColumnarBatch](iter)
   }
@@ -49,7 +49,7 @@ abstract class FilePartitionReaderFactory extends PartitionReaderFactory {
   }
 
 // A compound class for combining file and its corresponding reader.
-private[v2] class PartitionedFileReader[T](
+private[v2] case class PartitionedFileReader[T](
    file: PartitionedFile,
    reader: PartitionReader[T]) extends PartitionReader[T] {
  override def next(): Boolean = reader.next()
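A note on the `case class` change above: a plain-class constructor parameter is not a member unless declared `val`, so the new `getNextReader()` could not have read `reader.file` against the old definition. Making `PartitionedFileReader` a case class turns both parameters into public vals, adds the companion `apply` that lets the factory drop `new`, and generates a `toString` that makes the `logInfo(s"Reading file $reader")` line readable. A minimal illustration with hypothetical names:

class Plain(file: String)      // `file` is only a constructor argument
case class Cased(file: String) // `file` becomes a public val

// new Plain("a.orc").file     // does not compile: `file` is not a member
Cased("a.orc").file            // returns "a.orc"
Cased("a.orc").toString        // returns "Cased(a.orc)", useful in logs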

sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala

Lines changed: 13 additions & 0 deletions
@@ -526,6 +526,19 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
     }
   }
 
+  test("UDF input_file_name()") {
+    Seq("", "orc").foreach { useV1SourceReaderList =>
+      withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1SourceReaderList) {
+        withTempPath { dir =>
+          val path = dir.getCanonicalPath
+          spark.range(10).write.orc(path)
+          val row = spark.read.orc(path).select(input_file_name).first()
+          assert(row.getString(0).contains(path))
+        }
+      }
+    }
+  }
+
   test("Return correct results when data columns overlap with partition columns") {
     Seq("parquet", "orc", "json").foreach { format =>
       withTempPath { path =>
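The test's `Seq("", "orc")` loop runs the assertion twice: with the V1 fallback list empty, ORC is read through file source V2 (the code path fixed here); with "orc" in the list, the legacy V1 reader handles it, so both paths are guarded against regression. Roughly the same check can be run by hand; a sketch assuming an existing `spark` session and an illustrative path:

import org.apache.spark.sql.functions.input_file_name
import org.apache.spark.sql.internal.SQLConf

// Force the V2 ORC reader by leaving the V1 fallback list empty.
spark.conf.set(SQLConf.USE_V1_SOURCE_READER_LIST.key, "")
spark.range(10).write.orc("/tmp/input_file_name_demo")
val row = spark.read.orc("/tmp/input_file_name_demo")
  .select(input_file_name()).first()
assert(row.getString(0).contains("input_file_name_demo"))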
