
Commit a3dccd2

mgaido91 authored and srowen committed

[SPARK-10697][ML] Add lift to Association rules

## What changes were proposed in this pull request?

The PR adds the lift measure to Association rules.

## How was this patch tested?

existing and modified UTs

Closes apache#22236 from mgaido91/SPARK-10697.

Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Sean Owen <[email protected]>

1 parent 6ad8d4c commit a3dccd2
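For context: lift measures how much more likely the consequent becomes when the antecedent is present, compared to its baseline frequency, so a lift above 1.0 marks a positively associated rule. In the terms used by AssociationRules.Rule below, where freq(A) is the number of training transactions containing itemset A and N is the total number of transactions (now persisted as numTrainingRecords):

    confidence(X => Y) = freq(X union Y) / freq(X)
    lift(X => Y)       = confidence(X => Y) / support(Y)
                       = (freq(X union Y) / freq(X)) / (freq(Y) / N)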

File tree

9 files changed, +108 -40 lines changed

R/pkg/R/mllib_fpm.R

Lines changed: 3 additions & 2 deletions

@@ -116,10 +116,11 @@ setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"),
 
 # Get association rules.
 
 #' @return A \code{SparkDataFrame} with association rules.
-#'         The \code{SparkDataFrame} contains three columns:
+#'         The \code{SparkDataFrame} contains four columns:
 #'         \code{antecedent} (an array of the same type as the input column),
 #'         \code{consequent} (an array of the same type as the input column),
-#'         and \code{condfidence} (confidence).
+#'         \code{condfidence} (confidence for the rule)
+#'         and \code{lift} (lift for the rule)
 #' @rdname spark.fpGrowth
 #' @aliases associationRules,FPGrowthModel-method
 #' @note spark.associationRules(FPGrowthModel) since 2.2.0

R/pkg/tests/fulltests/test_mllib_fpm.R

Lines changed: 2 additions & 1 deletion

@@ -44,7 +44,8 @@ test_that("spark.fpGrowth", {
   expected_association_rules <- data.frame(
     antecedent = I(list(list("2"), list("3"))),
     consequent = I(list(list("1"), list("1"))),
-    confidence = c(1, 1)
+    confidence = c(1, 1),
+    lift = c(1, 1)
   )
 
   expect_equivalent(expected_association_rules, collect(spark.associationRules(model)))

mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 45 additions & 16 deletions

@@ -20,6 +20,8 @@ package org.apache.spark.ml.fpm
 import scala.reflect.ClassTag
 
 import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
@@ -34,6 +36,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.VersionUtils
 
 /**
  * Common params for FPGrowth and FPGrowthModel
@@ -175,7 +178,8 @@ class FPGrowth @Since("2.2.0") (
     if (handlePersistence) {
       items.persist(StorageLevel.MEMORY_AND_DISK)
     }
-
+    val inputRowCount = items.count()
+    instr.logNumExamples(inputRowCount)
     val parentModel = mllibFP.run(items)
     val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
     val schema = StructType(Seq(
@@ -187,7 +191,8 @@
       items.unpersist()
     }
 
-    copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
+    copyValues(new FPGrowthModel(uid, frequentItems, parentModel.itemSupport, inputRowCount))
+      .setParent(this)
   }
 
   @Since("2.2.0")
@@ -217,7 +222,9 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
 @Experimental
 class FPGrowthModel private[ml] (
     @Since("2.2.0") override val uid: String,
-    @Since("2.2.0") @transient val freqItemsets: DataFrame)
+    @Since("2.2.0") @transient val freqItemsets: DataFrame,
+    private val itemSupport: scala.collection.Map[Any, Double],
+    private val numTrainingRecords: Long)
   extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {
 
   /** @group setParam */
@@ -241,17 +248,17 @@
   @transient private var _cachedRules: DataFrame = _
 
   /**
-   * Get association rules fitted using the minConfidence. Returns a dataframe
-   * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
-   * "consequent" are Array[T] and "confidence" is Double.
+   * Get association rules fitted using the minConfidence. Returns a dataframe with four fields,
+   * "antecedent", "consequent", "confidence" and "lift", where "antecedent" and "consequent" are
+   * Array[T], whereas "confidence" and "lift" are Double.
    */
   @Since("2.2.0")
   @transient def associationRules: DataFrame = {
     if ($(minConfidence) == _cachedMinConf) {
       _cachedRules
     } else {
       _cachedRules = AssociationRules
-        .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
+        .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence), itemSupport)
       _cachedMinConf = $(minConfidence)
       _cachedRules
     }
@@ -301,7 +308,7 @@
 
   @Since("2.2.0")
   override def copy(extra: ParamMap): FPGrowthModel = {
-    val copied = new FPGrowthModel(uid, freqItemsets)
+    val copied = new FPGrowthModel(uid, freqItemsets, itemSupport, numTrainingRecords)
     copyValues(copied, extra).setParent(this.parent)
   }
 
@@ -323,7 +330,8 @@
   class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
+      DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata = Some(extraMetadata))
       val dataPath = new Path(path, "data").toString
       instance.freqItemsets.write.parquet(dataPath)
     }
@@ -335,10 +343,28 @@
     private val className = classOf[FPGrowthModel].getName
 
     override def load(path: String): FPGrowthModel = {
+      implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+      val numTrainingRecords = if (major.toInt < 2 || (major.toInt == 2 && minor.toInt < 4)) {
+        // 2.3 and before don't store the count
+        0L
+      } else {
+        // 2.4+
+        (metadata.metadata \ "numTrainingRecords").extract[Long]
+      }
       val dataPath = new Path(path, "data").toString
       val frequentItems = sparkSession.read.parquet(dataPath)
-      val model = new FPGrowthModel(metadata.uid, frequentItems)
+      val itemSupport = if (numTrainingRecords == 0L) {
+        Map.empty[Any, Double]
+      } else {
+        frequentItems.rdd.flatMap {
+          case Row(items: Seq[_], count: Long) if items.length == 1 =>
+            Some(items.head -> count.toDouble / numTrainingRecords)
+          case _ => None
+        }.collectAsMap()
+      }
+      val model = new FPGrowthModel(metadata.uid, frequentItems, itemSupport, numTrainingRecords)
       metadata.getAndSetParams(model)
       model
     }
@@ -354,27 +380,30 @@ private[fpm] object AssociationRules {
    * @param itemsCol column name for frequent itemsets
    * @param freqCol column name for appearance count of the frequent itemsets
    * @param minConfidence minimum confidence for generating the association rules
-   * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double])
-   *         containing the association rules.
+   * @param itemSupport map containing an item and its support
+   * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double],
+   *         "lift" [Double]) containing the association rules.
    */
   def getAssociationRulesFromFP[T: ClassTag](
       dataset: Dataset[_],
       itemsCol: String,
       freqCol: String,
-      minConfidence: Double): DataFrame = {
+      minConfidence: Double,
+      itemSupport: scala.collection.Map[T, Double]): DataFrame = {
 
     val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd
       .map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1)))
     val rows = new MLlibAssociationRules()
       .setMinConfidence(minConfidence)
-      .run(freqItemSetRdd)
-      .map(r => Row(r.antecedent, r.consequent, r.confidence))
+      .run(freqItemSetRdd, itemSupport)
+      .map(r => Row(r.antecedent, r.consequent, r.confidence, r.lift.orNull))
 
     val dt = dataset.schema(itemsCol).dataType
     val schema = StructType(Seq(
      StructField("antecedent", dt, nullable = false),
      StructField("consequent", dt, nullable = false),
-      StructField("confidence", DoubleType, nullable = false)))
+      StructField("confidence", DoubleType, nullable = false),
+      StructField("lift", DoubleType)))
    val rules = dataset.sparkSession.createDataFrame(rows, schema)
    rules
  }
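To see the new column end to end, here is a minimal usage sketch of the spark.ml API changed above (assumes a spark-shell session so that `spark`, `sc` and the implicits are in scope; the toy transactions are illustrative, not taken from the patch):

import org.apache.spark.ml.fpm.FPGrowth
import spark.implicits._

// Three illustrative transactions, one array of items per row.
val dataset = Seq("1 2 5", "1 2 3 5", "1 2")
  .map(_.split(" "))
  .toDF("items")

val model = new FPGrowth()
  .setItemsCol("items")
  .setMinSupport(0.5)
  .setMinConfidence(0.6)
  .fit(dataset)

// The rules DataFrame now has four columns: antecedent, consequent,
// confidence and lift. For a model saved by Spark < 2.4 the loader above
// reads numTrainingRecords back as 0, leaves itemSupport empty, and the
// lift column comes back null.
model.associationRules.show()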

mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala

Lines changed: 31 additions & 6 deletions

@@ -56,11 +56,24 @@ class AssociationRules private[fpm] (
   /**
    * Computes the association rules with confidence above `minConfidence`.
    * @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
-   * @return a `Set[Rule[Item]]` containing the association rules.
+   * @return a `RDD[Rule[Item]]` containing the association rules.
    *
    */
   @Since("1.5.0")
   def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
+    run(freqItemsets, Map.empty[Item, Double])
+  }
+
+  /**
+   * Computes the association rules with confidence above `minConfidence`.
+   * @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
+   * @param itemSupport map containing an item and its support
+   * @return a `RDD[Rule[Item]]` containing the association rules. The rules will be able to
+   *         compute also the lift metric.
+   */
+  @Since("2.4.0")
+  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]],
+      itemSupport: scala.collection.Map[Item, Double]): RDD[Rule[Item]] = {
     // For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
     val candidates = freqItemsets.flatMap { itemset =>
       val items = itemset.items
@@ -76,8 +89,13 @@
     // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence
     candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
       .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) =>
-        new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent)
-      }.filter(_.confidence >= minConfidence)
+        new Rule(antecendent.toArray,
+          consequent.toArray,
+          freqUnion,
+          freqAntecedent,
+          // the consequent contains always only one element
+          itemSupport.get(consequent.head))
+      }.filter(_.confidence >= minConfidence)
   }
 
   /**
@@ -107,14 +125,21 @@ object AssociationRules {
     @Since("1.5.0") val antecedent: Array[Item],
     @Since("1.5.0") val consequent: Array[Item],
     freqUnion: Double,
-    freqAntecedent: Double) extends Serializable {
+    freqAntecedent: Double,
+    freqConsequent: Option[Double]) extends Serializable {
 
   /**
    * Returns the confidence of the rule.
    *
    */
   @Since("1.5.0")
-  def confidence: Double = freqUnion.toDouble / freqAntecedent
+  def confidence: Double = freqUnion / freqAntecedent
+
+  /**
+   * Returns the lift of the rule.
+   */
+  @Since("2.4.0")
+  def lift: Option[Double] = freqConsequent.map(fCons => confidence / fCons)
 
   require(antecedent.toSet.intersect(consequent.toSet).isEmpty, {
     val sharedItems = antecedent.toSet.intersect(consequent.toSet)
@@ -142,7 +167,7 @@
 
     override def toString: String = {
       s"${antecedent.mkString("{", ",", "}")} => " +
-        s"${consequent.mkString("{", ",", "}")}: ${confidence}"
+        s"${consequent.mkString("{", ",", "}")}: (confidence: $confidence; lift: $lift)"
     }
   }
 }
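A small sketch of the new `run` overload in action (hypothetical item counts over four transactions; assumes a spark-shell `sc`). Passing the per-item support map is what lets each Rule report a lift:

import org.apache.spark.mllib.fpm.AssociationRules
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset

// freq(a) = 4, freq(b) = 3, freq(a,b) = 3 over N = 4 transactions.
val freqItemsets = sc.parallelize(Seq(
  new FreqItemset(Array("a"), 4L),
  new FreqItemset(Array("b"), 3L),
  new FreqItemset(Array("a", "b"), 3L)))

// support(item) = freq(item) / N
val itemSupport = Map("a" -> 1.0, "b" -> 0.75)

new AssociationRules().setMinConfidence(0.7)
  .run(freqItemsets, itemSupport)
  .collect().foreach(println)
// {a} => {b}: (confidence: 0.75; lift: Some(1.0))
// {b} => {a}: (confidence: 1.0; lift: Some(1.0))

Calling the pre-existing single-argument `run` is equivalent to passing an empty map, so `lift` comes back as `None` on every rule.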

mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala

Lines changed: 16 additions & 9 deletions

@@ -48,17 +48,22 @@ import org.apache.spark.storage.StorageLevel
  * @tparam Item item type
  */
 @Since("1.3.0")
-class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
-    @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]])
+class FPGrowthModel[Item: ClassTag] @Since("2.4.0") (
+    @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]],
+    @Since("2.4.0") val itemSupport: Map[Item, Double])
   extends Saveable with Serializable {
+
+  @Since("1.3.0")
+  def this(freqItemsets: RDD[FreqItemset[Item]]) = this(freqItemsets, Map.empty)
+
   /**
    * Generates association rules for the `Item`s in [[freqItemsets]].
    * @param confidence minimal confidence of the rules produced
    */
   @Since("1.5.0")
   def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = {
     val associationRules = new AssociationRules(confidence)
-    associationRules.run(freqItemsets)
+    associationRules.run(freqItemsets, itemSupport)
   }
 
   /**
@@ -213,9 +218,12 @@ class FPGrowth private[spark] (
     val minCount = math.ceil(minSupport * count).toLong
     val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
     val partitioner = new HashPartitioner(numParts)
-    val freqItems = genFreqItems(data, minCount, partitioner)
-    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
-    new FPGrowthModel(freqItemsets)
+    val freqItemsCount = genFreqItems(data, minCount, partitioner)
+    val freqItemsets = genFreqItemsets(data, minCount, freqItemsCount.map(_._1), partitioner)
+    val itemSupport = freqItemsCount.map {
+      case (item, cnt) => item -> cnt.toDouble / count
+    }.toMap
+    new FPGrowthModel(freqItemsets, itemSupport)
   }
 
   /**
@@ -231,12 +239,12 @@
    * Generates frequent items by filtering the input data using minimal support level.
    * @param minCount minimum count for frequent itemsets
    * @param partitioner partitioner used to distribute items
-   * @return array of frequent pattern ordered by their frequencies
+   * @return array of frequent patterns and their frequencies ordered by their frequencies
    */
   private def genFreqItems[Item: ClassTag](
       data: RDD[Array[Item]],
       minCount: Long,
-      partitioner: Partitioner): Array[Item] = {
+      partitioner: Partitioner): Array[(Item, Long)] = {
     data.flatMap { t =>
       val uniq = t.toSet
       if (t.length != uniq.size) {
@@ -248,7 +256,6 @@
       .filter(_._2 >= minCount)
       .collect()
       .sortBy(-_._2)
-      .map(_._1)
   }
 
   /**
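And the same data flowing through the RDD-based FPGrowth, whose model now exposes `itemSupport` and forwards it to `generateAssociationRules` (toy transactions again; a spark-shell `sc` is assumed):

import org.apache.spark.mllib.fpm.FPGrowth

val transactions = sc.parallelize(Seq(
  Array("a", "b"), Array("a", "b"), Array("a", "b", "c"), Array("a")))

val model = new FPGrowth().setMinSupport(0.5).run(transactions)

// "c" appears in 1 of 4 transactions, below minSupport, so it is absent.
model.itemSupport.foreach(println)   // e.g. (a,1.0), (b,0.75)

// Rules produced through this path can all compute lift, because the
// model's itemSupport is handed to AssociationRules.run.
model.generateAssociationRules(0.7).collect().foreach(println)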

mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -39,9 +39,9 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     val model = new FPGrowth().setMinSupport(0.5).fit(data)
     val generatedRules = model.setMinConfidence(0.5).associationRules
     val expectedRules = spark.createDataFrame(Seq(
-      (Array("2"), Array("1"), 1.0),
-      (Array("1"), Array("2"), 0.75)
-    )).toDF("antecedent", "consequent", "confidence")
+      (Array("2"), Array("1"), 1.0, 1.0),
+      (Array("1"), Array("2"), 0.75, 1.0)
+    )).toDF("antecedent", "consequent", "confidence", "lift")
       .withColumn("antecedent", col("antecedent").cast(ArrayType(dt)))
       .withColumn("consequent", col("consequent").cast(ArrayType(dt)))
     assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
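Both expected lifts being 1.0 is consistent with the definition: lift is confidence divided by the consequent's support, and the supports implied by these expectations are support("1") = 1.0 and support("2") = 0.75. So lift(2 => 1) = 1.0 / 1.0 = 1.0 and lift(1 => 2) = 0.75 / 0.75 = 1.0; in this dataset neither antecedent shifts its consequent away from its baseline frequency.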

project/MimaExcludes.scala

Lines changed: 4 additions & 0 deletions

@@ -36,6 +36,10 @@ object MimaExcludes {
 
   // Exclude rules for 2.4.x
   lazy val v24excludes = v23excludes ++ Seq(
+    // [SPARK-10697][ML] Add lift to Association rules
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.AssociationRules#Rule.this"),
+
     // [SPARK-25044] Address translation of LMF closure primitive args to Object in Scala 2.12
     ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"),

python/pyspark/ml/fpm.py

Lines changed: 2 additions & 1 deletion

@@ -145,10 +145,11 @@ def freqItemsets(self):
     @since("2.2.0")
     def associationRules(self):
         """
-        DataFrame with three columns:
+        DataFrame with four columns:
         * `antecedent`  - Array of the same type as the input column.
         * `consequent`  - Array of the same type as the input column.
         * `confidence`  - Confidence for the rule (`DoubleType`).
+        * `lift`        - Lift for the rule (`DoubleType`).
         """
         return self._call_java("associationRules")
 

python/pyspark/ml/tests.py

Lines changed: 2 additions & 2 deletions

@@ -2158,8 +2158,8 @@ def test_association_rules(self):
         fpm = fp.fit(self.data)
 
         expected_association_rules = self.spark.createDataFrame(
-            [([3], [1], 1.0), ([2], [1], 1.0)],
-            ["antecedent", "consequent", "confidence"]
+            [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)],
+            ["antecedent", "consequent", "confidence", "lift"]
         )
         actual_association_rules = fpm.associationRules
 
