Skip to content

Commit 8c2849a

Browse files
johnhany97 and cloud-fan
authored and committed
[SPARK-30082][SQL] Do not replace Zeros when replacing NaNs
### What changes were proposed in this pull request?

Do not cast `NaN` to an `Integer`, `Long`, `Short` or `Byte`. This is because casting `NaN` to those types results in a `0`, which erroneously replaces `0`s while only `NaN`s should be replaced.

### Why are the changes needed?

This Scala code snippet:
```
import scala.math;
println(Double.NaN.toLong)
```
returns `0`, which is problematic: if you run the following Spark code, `0`s get replaced as well:
```
>>> df = spark.createDataFrame([(1.0, 0), (0.0, 3), (float('nan'), 0)], ("index", "value"))
>>> df.show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  NaN|    0|
+-----+-----+
>>> df.replace(float('nan'), 2).show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    2|
|  0.0|    3|
|  2.0|    2|
+-----+-----+
```

### Does this PR introduce any user-facing change?

Yes. After the PR, running the same code snippet above returns the correct expected results:
```
>>> df = spark.createDataFrame([(1.0, 0), (0.0, 3), (float('nan'), 0)], ("index", "value"))
>>> df.show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  NaN|    0|
+-----+-----+
>>> df.replace(float('nan'), 2).show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  2.0|    0|
+-----+-----+
```

### How was this patch tested?

Added unit tests to verify that replacing `NaN` only affects columns of type `Float` and `Double`.

Closes apache#26738 from johnhany97/SPARK-30082.

Lead-authored-by: John Ayad <[email protected]>
Co-authored-by: John Ayad <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 65552a8 commit 8c2849a

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -456,7 +456,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
456456
val keyExpr = df.col(col.name).expr
457457
def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
458458
val branches = replacementMap.flatMap { case (source, target) =>
459-
Seq(buildExpr(source), buildExpr(target))
459+
Seq(Literal(source), buildExpr(target))
460460
}.toSeq
461461
new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
462462
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
3737
).toDF("name", "age", "height")
3838
}
3939

40+
def createNaNDF(): DataFrame = {
41+
Seq[(java.lang.Integer, java.lang.Long, java.lang.Short,
42+
java.lang.Byte, java.lang.Float, java.lang.Double)](
43+
(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0),
44+
(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN)
45+
).toDF("int", "long", "short", "byte", "float", "double")
46+
}
47+
4048
test("drop") {
4149
val input = createDF()
4250
val rows = input.collect()
@@ -404,4 +412,40 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
404412
df.na.drop("any"),
405413
Row("5", "6", "6") :: Nil)
406414
}
415+
416+
test("replace nan with float") {
417+
checkAnswer(
418+
createNaNDF().na.replace("*", Map(
419+
Float.NaN -> 10.0f
420+
)),
421+
Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) ::
422+
Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil)
423+
}
424+
425+
test("replace nan with double") {
426+
checkAnswer(
427+
createNaNDF().na.replace("*", Map(
428+
Double.NaN -> 10.0
429+
)),
430+
Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) ::
431+
Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil)
432+
}
433+
434+
test("replace float with nan") {
435+
checkAnswer(
436+
createNaNDF().na.replace("*", Map(
437+
1.0f -> Float.NaN
438+
)),
439+
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
440+
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
441+
}
442+
443+
test("replace double with nan") {
444+
checkAnswer(
445+
createNaNDF().na.replace("*", Map(
446+
1.0 -> Double.NaN
447+
)),
448+
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
449+
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
450+
}
407451
}

0 commit comments

Comments (0)