Commit 710ddab

[SPARK-29914][ML] ML models attach metadata in transform/transformSchema
### What changes were proposed in this pull request?

1. `predictionCol` in `ml.classification` & `ml.clustering`: add a `NominalAttribute`
2. `rawPredictionCol` in `ml.classification`: add an `AttributeGroup` with vector size = `numClasses`
3. `probabilityCol` in `ml.classification` & `ml.clustering`: add an `AttributeGroup` with vector size = `numClasses`/`k`
4. `leafCol` in GBT/RF: add an `AttributeGroup` with vector size = `numTrees`
5. `leafCol` in DecisionTree: add a `NominalAttribute`
6. `outputCol` of models in `ml.feature`: add an `AttributeGroup` with the vector size
7. `outputCol` of `UnaryTransformer`s in `ml.feature`: add an `AttributeGroup` with the vector size

### Why are the changes needed?

The appended metadata can be used in downstream ops, like `Classifier.getNumClasses`.

Many implementations in `.ml` (like `Binarizer`/`Bucketizer`/`VectorAssembler`/`OneHotEncoder`/`FeatureHasher`/`HashingTF`/`VectorSlicer`/...) already append appropriate metadata in the `transform`/`transformSchema` method. However, many other implementations return no metadata from transformation, even when metadata like `vector.size`/`numAttrs`/`attrs` can be easily inferred.

### Does this PR introduce any user-facing change?

Yes, some metadata is added to the transformed dataset.

### How was this patch tested?

Existing test suites and added test suites.

Closes apache#26547 from zhengruifeng/add_output_vecSize.

Authored-by: zhengruifeng <[email protected]>
Signed-off-by: zhengruifeng <[email protected]>
Parent: 55132ae

51 files changed: +593 −105 lines
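A minimal sketch (not part of this commit) of what the change enables downstream: reading the newly attached metadata instead of re-scanning the data. The fitted `model`, input `df`, and default column names here are illustrative assumptions.

```scala
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}

val predictions = model.transform(df)  // assumes a fitted classification model

// predictionCol now carries a NominalAttribute with numValues = numClasses.
val numClasses = Attribute.fromStructField(predictions.schema("prediction")) match {
  case nominal: NominalAttribute => nominal.getNumValues.getOrElse(-1)
  case _ => -1  // no usable metadata attached
}

// rawPredictionCol now carries an AttributeGroup with a known size
// (size returns -1 when the vector length is not declared).
val rawSize = AttributeGroup.fromStructField(predictions.schema("rawPrediction")).size
```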

mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

Lines changed: 12 additions & 5 deletions

```diff
@@ -88,8 +88,9 @@ private[ml] trait PredictorParams extends Params
    * and put it in an RDD with strong types.
    * Validate the output instances with the given function.
    */
-  protected def extractInstances(dataset: Dataset[_],
-    validateInstance: Instance => Unit): RDD[Instance] = {
+  protected def extractInstances(
+      dataset: Dataset[_],
+      validateInstance: Instance => Unit): RDD[Instance] = {
     extractInstances(dataset).map { instance =>
       validateInstance(instance)
       instance
@@ -222,7 +223,11 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
   protected def featuresDataType: DataType = new VectorUDT
 
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = false, featuresDataType)
+    var outputSchema = validateAndTransformSchema(schema, fitting = false, featuresDataType)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
+    }
+    outputSchema
   }
 
   /**
@@ -244,10 +249,12 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
   }
 
   protected def transformImpl(dataset: Dataset[_]): DataFrame = {
-    val predictUDF = udf { (features: Any) =>
+    val outputSchema = transformSchema(dataset.schema, logging = true)
+    val predictUDF = udf { features: Any =>
       predict(features.asInstanceOf[FeaturesType])
     }
-    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))),
+      outputSchema($(predictionCol)).metadata)
   }
 
   /**
```
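`SchemaUtils.updateNumeric` is a private helper introduced alongside this change. A plausible sketch of what such a helper does, under the assumption that it simply stamps a default `NumericAttribute` onto the named column:

```scala
// Hedged sketch: an illustrative stand-in for the private SchemaUtils.updateNumeric,
// assuming it replaces the named column's field with one carrying NumericAttribute
// metadata. Not the authoritative Spark implementation.
import org.apache.spark.ml.attribute.NumericAttribute
import org.apache.spark.sql.types.StructType

def updateNumeric(schema: StructType, colName: String): StructType = {
  val field = NumericAttribute.defaultAttr
    .withName(colName)
    .toStructField()  // a DoubleType field carrying ML attribute metadata
  StructType(schema.map(f => if (f.name == colName) field else f))
}
```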

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 3 additions & 2 deletions

```diff
@@ -117,9 +117,10 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   }
 
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val transformUDF = udf(this.createTransformFunc, outputDataType)
-    dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
+    dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }
 
   override def copy(extra: ParamMap): T = defaultCopy(extra)
```
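The invariant this hunk establishes is that the metadata materialized by `transform` matches what `transformSchema` declares. A quick check, with `transformer`, `df`, and the output column name `"out"` assumed for illustration:

```scala
// Sketch: transform's output metadata should now agree with transformSchema.
val declared = transformer.transformSchema(df.schema)("out").metadata
val actual = transformer.transform(df).schema("out").metadata
assert(declared == actual)
```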

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 24 additions & 8 deletions

```diff
@@ -48,8 +48,9 @@ private[spark] trait ClassifierParams
    * and put it in an RDD with strong types.
    * Validates the label on the classifier is a valid integer in the range [0, numClasses).
    */
-  protected def extractInstances(dataset: Dataset[_],
-    numClasses: Int): RDD[Instance] = {
+  protected def extractInstances(
+      dataset: Dataset[_],
+      numClasses: Int): RDD[Instance] = {
     val validateInstance = (instance: Instance) => {
       val label = instance.label
       require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
@@ -183,6 +184,19 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
   /** Number of classes (values which the label can take). */
   def numClasses: Int
 
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(schema,
+        $(predictionCol), numClasses)
+    }
+    if ($(rawPredictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(rawPredictionCol), numClasses)
+    }
+    outputSchema
+  }
+
   /**
    * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
    * parameters:
@@ -193,29 +207,31 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
    * @return transformed dataset
    */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
 
     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
     var outputData = dataset
     var numColsOutput = 0
     if (getRawPredictionCol != "") {
-      val predictRawUDF = udf { (features: Any) =>
+      val predictRawUDF = udf { features: Any =>
         predictRaw(features.asInstanceOf[FeaturesType])
       }
-      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
+      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
+        outputSchema($(rawPredictionCol)).metadata)
       numColsOutput += 1
     }
     if (getPredictionCol != "") {
-      val predUDF = if (getRawPredictionCol != "") {
+      val predCol = if (getRawPredictionCol != "") {
         udf(raw2prediction _).apply(col(getRawPredictionCol))
       } else {
-        val predictUDF = udf { (features: Any) =>
+        val predictUDF = udf { features: Any =>
           predict(features.asInstanceOf[FeaturesType])
         }
         predictUDF(col(getFeaturesCol))
       }
-      outputData = outputData.withColumn(getPredictionCol, predUDF)
+      outputData = outputData.withColumn(getPredictionCol, predCol,
+        outputSchema($(predictionCol)).metadata)
       numColsOutput += 1
     }
```
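A sketch of the payoff: with `updateAttributeGroupSize` applied, downstream code can recover `numClasses` from the raw-prediction column's `AttributeGroup` instead of re-scanning the data (`predictions`, `model`, and the column name are assumed).

```scala
import org.apache.spark.ml.attribute.AttributeGroup

val group = AttributeGroup.fromStructField(predictions.schema("rawPrediction"))
assert(group.size == model.numClasses)  // size was -1 before this change
```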

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 13 additions & 2 deletions

```diff
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType
 
 /**
  * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -202,13 +203,23 @@ class DecisionTreeClassificationModel private[ml] (
     rootNode.predictImpl(features).prediction
   }
 
+  @Since("3.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
 
     val outputData = super.transform(dataset)
     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => predictLeaf(features) }
-      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+        outputSchema($(leafCol)).metadata)
     } else {
       outputData
     }
```
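For a single tree, `leafCol` gets a `NominalAttribute`, so a categorical transformer can consume the leaf indices directly. A sketch with `dtModel`, `df`, and column names assumed:

```scala
// Sketch: one-hot encode the nominal leaf column produced by a decision tree.
import org.apache.spark.ml.feature.OneHotEncoder

val leaves = dtModel.setLeafCol("leaf").transform(df)
val encoded = new OneHotEncoder()
  .setInputCol("leaf")
  .setOutputCol("leafVec")
  .fit(leaves)
  .transform(leaves)
```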

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 13 additions & 2 deletions

```diff
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
 
 /**
  * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting)
@@ -291,13 +292,23 @@ class GBTClassificationModel private[ml](
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
+  @Since("1.6.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
 
     val outputData = super.transform(dataset)
     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => predictLeaf(features) }
-      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+        outputSchema($(leafCol)).metadata)
     } else {
       outputData
     }
```
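For an ensemble, `leafCol` is a vector with one entry per tree, and its `AttributeGroup` now declares that size. A sketch with `gbtModel` and `df` assumed:

```scala
// Sketch: the leaf column's AttributeGroup size equals the number of trees.
import org.apache.spark.ml.attribute.AttributeGroup

val output = gbtModel.setLeafCol("leaf").transform(df)
val leafGroup = AttributeGroup.fromStructField(output.schema("leaf"))
assert(leafGroup.size == gbtModel.trees.length)
```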

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 13 additions & 2 deletions

```diff
@@ -161,13 +161,23 @@ final class OneVsRestModel private[ml] (
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
+    var outputSchema = validateAndTransformSchema(schema, fitting = false,
+      getClassifier.featuresDataType)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(outputSchema,
+        $(predictionCol), numClasses)
+    }
+    if ($(rawPredictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(rawPredictionCol), numClasses)
+    }
+    outputSchema
   }
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     // Check schema
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
 
     if (getPredictionCol.isEmpty && getRawPredictionCol.isEmpty) {
       logWarning(s"$uid: OneVsRestModel.transform() does nothing" +
@@ -230,6 +240,7 @@ final class OneVsRestModel private[ml] (
 
       predictionColNames :+= getRawPredictionCol
       predictionColumns :+= rawPredictionUDF(col(accColName))
+        .as($(rawPredictionCol), outputSchema($(rawPredictionCol)).metadata)
     }
 
     if (getPredictionCol.nonEmpty) {
```
mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala

Lines changed: 21 additions & 9 deletions

```diff
@@ -90,6 +90,15 @@ abstract class ProbabilisticClassificationModel[
     set(thresholds, value).asInstanceOf[M]
   }
 
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(probabilityCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(probabilityCol), numClasses)
+    }
+    outputSchema
+  }
+
   /**
    * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
    * parameters:
@@ -101,7 +110,7 @@ abstract class ProbabilisticClassificationModel[
    * @return transformed dataset
    */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
         ".transform() called with non-matching numClasses and thresholds.length." +
@@ -113,36 +122,39 @@ abstract class ProbabilisticClassificationModel[
     var outputData = dataset
     var numColsOutput = 0
     if ($(rawPredictionCol).nonEmpty) {
-      val predictRawUDF = udf { (features: Any) =>
+      val predictRawUDF = udf { features: Any =>
         predictRaw(features.asInstanceOf[FeaturesType])
       }
-      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
+      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
+        outputSchema($(rawPredictionCol)).metadata)
       numColsOutput += 1
     }
     if ($(probabilityCol).nonEmpty) {
-      val probUDF = if ($(rawPredictionCol).nonEmpty) {
+      val probCol = if ($(rawPredictionCol).nonEmpty) {
         udf(raw2probability _).apply(col($(rawPredictionCol)))
       } else {
-        val probabilityUDF = udf { (features: Any) =>
+        val probabilityUDF = udf { features: Any =>
           predictProbability(features.asInstanceOf[FeaturesType])
         }
         probabilityUDF(col($(featuresCol)))
       }
-      outputData = outputData.withColumn($(probabilityCol), probUDF)
+      outputData = outputData.withColumn($(probabilityCol), probCol,
+        outputSchema($(probabilityCol)).metadata)
       numColsOutput += 1
     }
     if ($(predictionCol).nonEmpty) {
-      val predUDF = if ($(rawPredictionCol).nonEmpty) {
+      val predCol = if ($(rawPredictionCol).nonEmpty) {
         udf(raw2prediction _).apply(col($(rawPredictionCol)))
       } else if ($(probabilityCol).nonEmpty) {
         udf(probability2prediction _).apply(col($(probabilityCol)))
       } else {
-        val predictUDF = udf { (features: Any) =>
+        val predictUDF = udf { features: Any =>
           predict(features.asInstanceOf[FeaturesType])
         }
         predictUDF(col($(featuresCol)))
       }
-      outputData = outputData.withColumn($(predictionCol), predUDF)
+      outputData = outputData.withColumn($(predictionCol), predCol,
+        outputSchema($(predictionCol)).metadata)
       numColsOutput += 1
     }
```
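With the probability column's size declared in metadata, downstream code can expand the vector into per-class columns without a pass over the data to discover the length. A sketch with `predictions` and column names assumed:

```scala
// Sketch: expand the probability vector using its AttributeGroup size.
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.sql.functions.col

val k = AttributeGroup.fromStructField(predictions.schema("probability")).size
val perClass = predictions.select(
  (0 until k).map(i => vector_to_array(col("probability")).getItem(i).as(s"p_$i")): _*
)
```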

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 13 additions & 2 deletions

```diff
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType
 
 /**
  * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
@@ -210,13 +211,23 @@ class RandomForestClassificationModel private[ml] (
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
+  @Since("1.4.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
 
     val outputData = super.transform(dataset)
     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => predictLeaf(features) }
-      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+        outputSchema($(leafCol)).metadata)
     } else {
       outputData
     }
```
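A sketch of inspecting the forest's leaf metadata, under the assumption that `getLeafField` populates the group with one attribute per tree (`rfModel` and `df` assumed):

```scala
// Sketch: list the per-tree attributes inside the leaf column's AttributeGroup.
import org.apache.spark.ml.attribute.AttributeGroup

val output = rfModel.setLeafCol("leaf").transform(df)
val leafGroup = AttributeGroup.fromStructField(output.schema("leaf"))
leafGroup.attributes.foreach { attrs =>  // attributes is an Option[Array[Attribute]]
  println(s"leaf vector has ${attrs.length} per-tree attributes")
}
```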

mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala

Lines changed: 9 additions & 3 deletions

```diff
@@ -110,15 +110,21 @@ class BisectingKMeansModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol),
-      predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
+      predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)),
+      outputSchema($(predictionCol)).metadata)
   }
 
   @Since("2.0.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(outputSchema,
+        $(predictionCol), parentModel.k)
+    }
+    outputSchema
   }
 
   @Since("3.0.0")
```
