
Commit 0bb8d1f

author
Nick Pentreath
committed
[SPARK-13969][ML] Add FeatureHasher transformer
This PR adds a `FeatureHasher` transformer, modeled on [scikit-learn](http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.FeatureHasher.html) and [Vowpal Wabbit](https://github.com/JohnLangford/vowpal_wabbit/wiki/Feature-Hashing-and-Extraction).

The transformer operates on multiple input columns in one pass. Current behavior is:
* for numerical columns, the values are assumed to be real values and the feature index is `hash(column_name)` while feature value is `feature_value`
* for string columns, the values are assumed to be categorical and the feature index is `hash(column_name=feature_value)`, while feature value is `1.0`
* for hash collisions, feature values will be summed
* `null` (missing) values are ignored

The following dataframe illustrates the basic semantics:

```
+---+------+-----+---------+------+-----------------------------------------+
|int|double|float|stringNum|string|features                                 |
+---+------+-----+---------+------+-----------------------------------------+
|3  |4.0   |5.0  |1        |foo   |(16,[0,8,11,12,15],[5.0,3.0,1.0,4.0,1.0])|
|6  |7.0   |8.0  |2        |bar   |(16,[0,8,11,12,15],[8.0,6.0,1.0,7.0,1.0])|
+---+------+-----+---------+------+-----------------------------------------+
```

## How was this patch tested?

New unit tests and manual experiments.

Author: Nick Pentreath <[email protected]>

Closes apache#18513 from MLnick/FeatureHasher.
1 parent 8321c14 commit 0bb8d1f
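
As a rough illustration of the semantics listed in the commit message, the following plain-Scala sketch hashes a single row without Spark. The names `FeatureHashingSketch` and `hashRow`, and the use of `scala.util.hashing.MurmurHash3.stringHash`, are assumptions made for this example only. The patch itself hashes with `OldHashingTF.murmur3Hash`, so the indices produced here will generally differ from Spark's.

```scala
import scala.collection.mutable
import scala.util.hashing.MurmurHash3

object FeatureHashingSketch {
  // Assumption: a stand-in hash from the Scala standard library, not Spark's hash function.
  def hash(s: String): Int = MurmurHash3.stringHash(s)

  // Map a possibly negative hash value into the range [0, n).
  def nonNegativeMod(x: Int, n: Int): Int = {
    val raw = x % n
    if (raw < 0) raw + n else raw
  }

  /** Hashes one row (column name -> optional value) into a sparse map of index -> value. */
  def hashRow(row: Seq[(String, Option[Any])], numFeatures: Int): Map[Int, Double] = {
    val acc = mutable.Map.empty[Int, Double].withDefaultValue(0.0)
    row.foreach {
      case (_, None) =>
        // null (missing) values are ignored
      case (name, Some(num: java.lang.Number)) =>
        // numeric column: index from hash(column_name), value kept as is
        acc(nonNegativeMod(hash(name), numFeatures)) += num.doubleValue()
      case (name, Some(other)) =>
        // any other column: index from hash(column_name=value), indicator value 1.0
        acc(nonNegativeMod(hash(s"$name=$other"), numFeatures)) += 1.0
    }
    // colliding indices have been summed by the += updates above
    acc.toMap
  }
}
```

With `numFeatures = 1` every feature collides at index 0, so `FeatureHashingSketch.hashRow(Seq("real" -> Some(1.0), "string1" -> Some("foo"), "string2" -> Some("foo")), 1)` yields `Map(0 -> 3.0)`, which is the summing behavior the commit message describes.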

3 files changed, with 396 additions and 1 deletion.
Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashMap

/**
 * Feature hashing projects a set of categorical or numerical features into a feature vector of
 * specified dimension (typically substantially smaller than that of the original feature
 * space). This is done using the hashing trick (https://en.wikipedia.org/wiki/Feature_hashing)
 * to map features to indices in the feature vector.
 *
 * The [[FeatureHasher]] transformer operates on multiple columns. Each column may contain either
 * numeric or categorical features. Column data types are handled as follows:
 *  -Numeric columns: For numeric features, the hash value of the column name is used to map the
 *                    feature value to its index in the feature vector. Numeric features are never
 *                    treated as categorical, even when they are integers. You must explicitly
 *                    convert numeric columns containing categorical features to strings first.
 *  -String columns: For categorical features, the hash value of the string "column_name=value"
 *                   is used to map to the vector index, with an indicator value of `1.0`.
 *                   Thus, categorical features are "one-hot" encoded
 *                   (similarly to using [[OneHotEncoder]] with `dropLast=false`).
 *  -Boolean columns: Boolean values are treated in the same way as string columns. That is,
 *                    boolean features are represented as "column_name=true" or "column_name=false",
 *                    with an indicator value of `1.0`.
 *
 * Null (missing) values are ignored (implicitly zero in the resulting feature vector).
 *
 * Since a simple modulo is used to map the hash value to a vector index,
 * it is advisable to use a power of two as the numFeatures parameter;
 * otherwise the features will not be mapped evenly to the vector indices.
 *
 * {{{
 *   val df = Seq(
 *    (2.0, true, "1", "foo"),
 *    (3.0, false, "2", "bar")
 *   ).toDF("real", "bool", "stringNum", "string")
 *
 *   val hasher = new FeatureHasher()
 *    .setInputCols("real", "bool", "stringNum", "string")
 *    .setOutputCol("features")
 *
 *   hasher.transform(df).show()
 *
 *   +----+-----+---------+------+--------------------+
 *   |real| bool|stringNum|string|            features|
 *   +----+-----+---------+------+--------------------+
 *   | 2.0| true|        1|   foo|(262144,[51871,63...|
 *   | 3.0|false|        2|   bar|(262144,[6031,806...|
 *   +----+-----+---------+------+--------------------+
 * }}}
 */
@Experimental
@Since("2.3.0")
class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transformer
  with HasInputCols with HasOutputCol with DefaultParamsWritable {

  @Since("2.3.0")
  def this() = this(Identifiable.randomUID("featureHasher"))

  /**
   * Number of features. Should be greater than 0.
   * (default = 2^18^)
   * @group param
   */
  @Since("2.3.0")
  val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
    ParamValidators.gt(0))

  setDefault(numFeatures -> (1 << 18))

  /** @group getParam */
  @Since("2.3.0")
  def getNumFeatures: Int = $(numFeatures)

  /** @group setParam */
  @Since("2.3.0")
  def setNumFeatures(value: Int): this.type = set(numFeatures, value)

  /** @group setParam */
  @Since("2.3.0")
  def setInputCols(values: String*): this.type = setInputCols(values.toArray)

  /** @group setParam */
  @Since("2.3.0")
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  @Since("2.3.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  @Since("2.3.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val hashFunc: Any => Int = OldHashingTF.murmur3Hash
    val n = $(numFeatures)
    val localInputCols = $(inputCols)

    val outputSchema = transformSchema(dataset.schema)
    val realFields = outputSchema.fields.filter { f =>
      f.dataType.isInstanceOf[NumericType]
    }.map(_.name).toSet

    def getDouble(x: Any): Double = {
      x match {
        case n: java.lang.Number =>
          n.doubleValue()
        case other =>
          // will throw ClassCastException if it cannot be cast, as would row.getDouble
          other.asInstanceOf[Double]
      }
    }

    val hashFeatures = udf { row: Row =>
      val map = new OpenHashMap[Int, Double]()
      localInputCols.foreach { colName =>
        val fieldIndex = row.fieldIndex(colName)
        if (!row.isNullAt(fieldIndex)) {
          val (rawIdx, value) = if (realFields(colName)) {
            // numeric values are kept as is, with vector index based on hash of "column_name"
            val value = getDouble(row.get(fieldIndex))
            val hash = hashFunc(colName)
            (hash, value)
          } else {
            // string and boolean values are treated as categorical, with an indicator value of 1.0
            // and vector index based on hash of "column_name=value"
            val value = row.get(fieldIndex).toString
            val fieldName = s"$colName=$value"
            val hash = hashFunc(fieldName)
            (hash, 1.0)
          }
          val idx = Utils.nonNegativeMod(rawIdx, n)
          map.changeValue(idx, value, v => v + value)
        }
      }
      Vectors.sparse(n, map.toSeq)
    }

    val metadata = outputSchema($(outputCol)).metadata
    dataset.select(
      col("*"),
      hashFeatures(struct($(inputCols).map(col): _*)).as($(outputCol), metadata))
  }

  @Since("2.3.0")
  override def copy(extra: ParamMap): FeatureHasher = defaultCopy(extra)

  @Since("2.3.0")
  override def transformSchema(schema: StructType): StructType = {
    val fields = schema($(inputCols).toSet)
    fields.foreach { fieldSchema =>
      val dataType = fieldSchema.dataType
      val fieldName = fieldSchema.name
      require(dataType.isInstanceOf[NumericType] ||
        dataType.isInstanceOf[StringType] ||
        dataType.isInstanceOf[BooleanType],
        s"FeatureHasher requires columns to be of NumericType, BooleanType or StringType. " +
          s"Column $fieldName was $dataType")
    }
    val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
    SchemaUtils.appendColumn(schema, attrGroup.toStructField())
  }
}

@Since("2.3.0")
object FeatureHasher extends DefaultParamsReadable[FeatureHasher] {

  @Since("2.3.0")
  override def load(path: String): FeatureHasher = super.load(path)
}
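
The Scaladoc in this file recommends a power of two for `numFeatures` because the vector index is a plain modulo of the hash value. A minimal sketch of why that matters, using a toy 8-bit "hash" range so the effect is visible: `ModuloIndexSketch`, `bucketCounts` and the local `nonNegativeMod` are hypothetical names for this example, and the local helper is only assumed to behave like the `Utils.nonNegativeMod` call in `transform`.

```scala
object ModuloIndexSketch {
  // Map any Int, including negative ones, into the range [0, mod).
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    if (raw < 0) raw + mod else raw
  }

  // Count how many of the 256 values of a toy 8-bit hash (-128 to 127) land in each bucket.
  def bucketCounts(numFeatures: Int): Seq[Int] = {
    val counts = Array.fill(numFeatures)(0)
    (-128 to 127).foreach(h => counts(nonNegativeMod(h, numFeatures)) += 1)
    counts.toSeq
  }

  def main(args: Array[String]): Unit = {
    // Power of two: 256 is divisible by 16, so every bucket receives the same share.
    println(bucketCounts(16))
    // Not a power of two: 256 is not divisible by 10, so some buckets receive one extra value.
    println(bucketCounts(10))
  }
}
```

For a power of two the non-negative modulo also equals the low bits of the hash, that is `nonNegativeMod(h, 16) == (h & 15)` for every `Int` value `h`. With a real 32-bit hash the imbalance for other sizes is far smaller than in this toy range, but it follows the same pattern.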
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

class FeatureHasherSuite extends SparkFunSuite
  with MLlibTestSparkContext
  with DefaultReadWriteTest {

  import testImplicits._

  import HashingTFSuite.murmur3FeatureIdx

  implicit private val vectorEncoder = ExpressionEncoder[Vector]()

  test("params") {
    ParamsSuite.checkParams(new FeatureHasher)
  }

  test("specify input cols using varargs or array") {
    val featureHasher1 = new FeatureHasher()
      .setInputCols("int", "double", "float", "stringNum", "string")
    val featureHasher2 = new FeatureHasher()
      .setInputCols(Array("int", "double", "float", "stringNum", "string"))
    assert(featureHasher1.getInputCols === featureHasher2.getInputCols)
  }

  test("feature hashing") {
    val df = Seq(
      (2.0, true, "1", "foo"),
      (3.0, false, "2", "bar")
    ).toDF("real", "bool", "stringNum", "string")

    val n = 100
    val hasher = new FeatureHasher()
      .setInputCols("real", "bool", "stringNum", "string")
      .setOutputCol("features")
      .setNumFeatures(n)
    val output = hasher.transform(df)
    val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
    assert(attrGroup.numAttributes === Some(n))

    val features = output.select("features").as[Vector].collect()
    // Assume perfect hash on field names
    def idx: Any => Int = murmur3FeatureIdx(n)
    // check expected indices
    val expected = Seq(
      Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0),
        (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0))),
      Vectors.sparse(n, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0),
        (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0)))
    )
    assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 })
  }

  test("hashing works for all numeric types") {
    val df = Seq(5.0, 10.0, 15.0).toDF("real")

    val hasher = new FeatureHasher()
      .setInputCols("real")
      .setOutputCol("features")

    val expectedResult = hasher.transform(df).select("features").as[Vector].collect()
    // check all numeric types work as expected. String & boolean types are tested in default case
    val types =
      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
    types.foreach { t =>
      val castDF = df.select(col("real").cast(t))
      val castResult = hasher.transform(castDF).select("features").as[Vector].collect()
      withClue(s"FeatureHasher works for all numeric types (testing $t): ") {
        assert(castResult.zip(expectedResult).forall { case (actual, expected) =>
          actual ~== expected absTol 1e-14
        })
      }
    }
  }

  test("invalid input type should fail") {
    val df = Seq(
      Vectors.dense(1),
      Vectors.dense(2)
    ).toDF("vec")

    intercept[IllegalArgumentException] {
      new FeatureHasher().setInputCols("vec").transform(df)
    }
  }

  test("hash collisions sum feature values") {
    val df = Seq(
      (1.0, "foo", "foo"),
      (2.0, "bar", "baz")
    ).toDF("real", "string1", "string2")

    val n = 1
    val hasher = new FeatureHasher()
      .setInputCols("real", "string1", "string2")
      .setOutputCol("features")
      .setNumFeatures(n)

    val features = hasher.transform(df).select("features").as[Vector].collect()
    def idx: Any => Int = murmur3FeatureIdx(n)
    // everything should hash into one field
    assert(idx("real") === idx("string1=foo"))
    assert(idx("string1=foo") === idx("string2=foo"))
    assert(idx("string2=foo") === idx("string1=bar"))
    assert(idx("string1=bar") === idx("string2=baz"))
    val expected = Seq(
      Vectors.sparse(n, Seq((idx("string1=foo"), 3.0))),
      Vectors.sparse(n, Seq((idx("string2=baz"), 4.0)))
    )
    assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 })
  }

  test("ignores null values in feature hashing") {
    import org.apache.spark.sql.functions._

    val df = Seq(
      (2.0, "foo", null),
      (3.0, "bar", "baz")
    ).toDF("real", "string1", "string2").select(
      when(col("real") === 3.0, null).otherwise(col("real")).alias("real"),
      col("string1"),
      col("string2")
    )

    val n = 100
    val hasher = new FeatureHasher()
      .setInputCols("real", "string1", "string2")
      .setOutputCol("features")
      .setNumFeatures(n)

    val features = hasher.transform(df).select("features").as[Vector].collect()
    def idx: Any => Int = murmur3FeatureIdx(n)
    val expected = Seq(
      Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("string1=foo"), 1.0))),
      Vectors.sparse(n, Seq((idx("string1=bar"), 1.0), (idx("string2=baz"), 1.0)))
    )
    assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 })
  }

  test("unicode column names and values") {
    // scalastyle:off nonascii
    val df = Seq((2.0, "中文")).toDF("中文", "unicode")

    val n = 100
    val hasher = new FeatureHasher()
      .setInputCols("中文", "unicode")
      .setOutputCol("features")
      .setNumFeatures(n)

    val features = hasher.transform(df).select("features").as[Vector].collect()
    def idx: Any => Int = murmur3FeatureIdx(n)
    val expected = Seq(
      Vectors.sparse(n, Seq((idx("中文"), 2.0), (idx("unicode=中文"), 1.0)))
    )
    assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 })
    // scalastyle:on nonascii
  }

  test("read/write") {
    val t = new FeatureHasher()
      .setInputCols(Array("myCol1", "myCol2", "myCol3"))
      .setOutputCol("myOutputCol")
      .setNumFeatures(10)
    testDefaultReadWrite(t)
  }
}
