
Commit 8321c14

Jan Vrsovsky authored and srowen committed
[SPARK-21723][ML] Fix writing LibSVM (key not found: numFeatures)
## What changes were proposed in this pull request?

Check the "numFeatures" option only when reading LibSVM data, not when writing it. Previously, writing raised an exception (`key not found: numFeatures`) whenever the vector column lacked that metadata; after this change the option is ignored completely on the write path. liancheng HyukjinKwon (Maybe the usage should be forbidden when writing, in a major version change?)

## How was this patch tested?

Manual test that loading and writing LibSVM files work fine, both with and without the numFeatures option.

Author: Jan Vrsovsky <[email protected]>

Closes apache#18872 from ProtD/master.
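For illustration, a minimal sketch of the behavior this fix enables (the output path and toy data below are hypothetical): a DataFrame assembled in memory carries no numFeatures metadata on its vector column, so before this patch writing it failed with `key not found: numFeatures`, while after it the metadata is simply not checked on write:

```scala
import org.apache.spark.ml.linalg.Vectors

// A DataFrame built from scratch: its "features" column has no
// numFeatures metadata attached, unlike one loaded from LibSVM files.
val df = spark.createDataFrame(Seq(
  (1.0, Vectors.sparse(3, Seq((0, 2.0), (1, 3.0)))),
  (0.0, Vectors.sparse(3, Seq((2, 6.0))))
)).toDF("label", "features")

// Previously this threw "key not found: numFeatures"; with this patch
// the schema check skips the metadata lookup on the write path.
df.coalesce(1).write.format("libsvm").save("/tmp/libsvm-out")
```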
1 parent 8c54f1e commit 8321c14

File tree

2 files changed: +34 −10 lines changed

mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala (4 additions, 4 deletions)

```diff
@@ -76,12 +76,12 @@ private[libsvm] class LibSVMFileFormat

   override def toString: String = "LibSVM"

-  private def verifySchema(dataSchema: StructType): Unit = {
+  private def verifySchema(dataSchema: StructType, forWriting: Boolean): Unit = {
     if (
       dataSchema.size != 2 ||
         !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
         !dataSchema(1).dataType.sameType(new VectorUDT()) ||
-        !(dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
+        !(forWriting || dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
     ) {
       throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
@@ -119,7 +119,7 @@ private[libsvm] class LibSVMFileFormat
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    verifySchema(dataSchema)
+    verifySchema(dataSchema, true)
     new OutputWriterFactory {
       override def newInstance(
           path: String,
@@ -142,7 +142,7 @@ private[libsvm] class LibSVMFileFormat
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-    verifySchema(dataSchema)
+    verifySchema(dataSchema, false)
     val numFeatures = dataSchema("features").metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt
     assert(numFeatures > 0)
```
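The read path is unchanged and still enforces a positive feature count, which it takes from the "numFeatures" option or infers with an extra pass over the data. A minimal read sketch, following standard data source usage (the sample path is the one shipped with Spark):

```scala
// Pass the feature dimension explicitly to skip the inference pass;
// omit the option and Spark will determine it by scanning the data.
val df = spark.read.format("libsvm")
  .option("numFeatures", "780")
  .load("data/mllib/sample_libsvm_data.txt")
```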

mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala (30 additions, 6 deletions)

```diff
@@ -19,13 +19,16 @@ package org.apache.spark.ml.source.libsvm

 import java.io.{File, IOException}
 import java.nio.charset.StandardCharsets
+import java.util.List

 import com.google.common.io.Files

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.util.Utils


@@ -44,14 +47,14 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
       """
         |0 2:4.0 4:5.0 6:6.0
       """.stripMargin
-    val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
+    val dir = Utils.createTempDir()
     val succ = new File(dir, "_SUCCESS")
     val file0 = new File(dir, "part-00000")
     val file1 = new File(dir, "part-00001")
     Files.write("", succ, StandardCharsets.UTF_8)
     Files.write(lines0, file0, StandardCharsets.UTF_8)
     Files.write(lines1, file1, StandardCharsets.UTF_8)
-    path = dir.toURI.toString
+    path = dir.getPath
   }

   override def afterAll(): Unit = {
@@ -108,12 +111,12 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {

   test("write libsvm data and read it again") {
     val df = spark.read.format("libsvm").load(path)
-    val tempDir2 = new File(tempDir, "read_write_test")
-    val writepath = tempDir2.toURI.toString
+    val writePath = Utils.createTempDir().getPath
+
     // TODO: Remove requirement to coalesce by supporting multiple reads.
-    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
+    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writePath)

-    val df2 = spark.read.format("libsvm").load(writepath)
+    val df2 = spark.read.format("libsvm").load(writePath)
     val row1 = df2.first()
     val v = row1.getAs[SparseVector](1)
     assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
@@ -126,6 +129,27 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }

+  test("write libsvm data from scratch and read it again") {
+    val rawData = new java.util.ArrayList[Row]()
+    rawData.add(Row(1.0, Vectors.sparse(3, Seq((0, 2.0), (1, 3.0)))))
+    rawData.add(Row(4.0, Vectors.sparse(3, Seq((0, 5.0), (2, 6.0)))))
+
+    val struct = StructType(
+      StructField("labelFoo", DoubleType, false) ::
+      StructField("featuresBar", VectorType, false) :: Nil
+    )
+    val df = spark.sqlContext.createDataFrame(rawData, struct)
+
+    val writePath = Utils.createTempDir().getPath
+
+    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writePath)
+
+    val df2 = spark.read.format("libsvm").load(writePath)
+    val row1 = df2.first()
+    val v = row1.getAs[SparseVector](1)
+    assert(v == Vectors.sparse(3, Seq((0, 2.0), (1, 3.0))))
+  }
+
   test("select features from libsvm relation") {
     val df = spark.read.format("libsvm").load(path)
     df.select("features").rdd.map { case Row(d: Vector) => d }.first
```
