Skip to content

Commit 9983670

Browse files
MrBagoRobert Kruszewski
authored and committed
[SPARK-23859][ML] Initial PR for Instrumentation improvements: UUID and logging levels
## What changes were proposed in this pull request? Initial PR for Instrumentation improvements: UUID and logging levels. This PR takes over apache#20837. Closes apache#20837. ## How was this patch tested? Manual. Author: Bago Amirbekian <[email protected]> Author: WeichenXu <[email protected]> Closes apache#20982 from WeichenXu123/better-instrumentation.
1 parent d15dbfb commit 9983670

File tree

2 files changed

+41
-14
lines changed

2 files changed

+41
-14
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -517,6 +517,9 @@ class LogisticRegression @Since("1.2.0") (
517517
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
518518
)(seqOp, combOp, $(aggregationDepth))
519519
}
520+
instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count)
521+
instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
522+
instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
520523

521524
val histogram = labelSummarizer.histogram
522525
val numInvalid = labelSummarizer.countInvalid
@@ -560,15 +563,15 @@ class LogisticRegression @Since("1.2.0") (
560563
if (numInvalid != 0) {
561564
val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
562565
s"Found $numInvalid invalid labels."
563-
logError(msg)
566+
instr.logError(msg)
564567
throw new SparkException(msg)
565568
}
566569

567570
val isConstantLabel = histogram.count(_ != 0.0) == 1
568571

569572
if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) {
570-
logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " +
571-
s"will be zeros. Training is not needed.")
573+
instr.logWarning(s"All labels are the same value and fitIntercept=true, so the " +
574+
s"coefficients will be zeros. Training is not needed.")
572575
val constantLabelIndex = Vectors.dense(histogram).argmax
573576
val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures,
574577
new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double],
@@ -581,7 +584,7 @@ class LogisticRegression @Since("1.2.0") (
581584
(coefMatrix, interceptVec, Array.empty[Double])
582585
} else {
583586
if (!$(fitIntercept) && isConstantLabel) {
584-
logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
587+
instr.logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
585588
s"dangerous ground, so the algorithm may not converge.")
586589
}
587590

@@ -590,7 +593,7 @@ class LogisticRegression @Since("1.2.0") (
590593

591594
if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
592595
featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
593-
logWarning("Fitting LogisticRegressionModel without intercept on dataset with " +
596+
instr.logWarning("Fitting LogisticRegressionModel without intercept on dataset with " +
594597
"constant nonzero column, Spark MLlib outputs zero coefficients for constant " +
595598
"nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.")
596599
}
@@ -708,7 +711,7 @@ class LogisticRegression @Since("1.2.0") (
708711
(_initialModel.interceptVector.size == numCoefficientSets) &&
709712
(_initialModel.getFitIntercept == $(fitIntercept))
710713
if (!modelIsValid) {
711-
logWarning(s"Initial coefficients will be ignored! Its dimensions " +
714+
instr.logWarning(s"Initial coefficients will be ignored! Its dimensions " +
712715
s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " +
713716
s"expected size ($numCoefficientSets, $numFeatures)")
714717
}

mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala

Lines changed: 32 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.ml.util
1919

20-
import java.util.concurrent.atomic.AtomicLong
20+
import java.util.UUID
2121

2222
import org.json4s._
2323
import org.json4s.JsonDSL._
@@ -42,7 +42,7 @@ import org.apache.spark.sql.Dataset
4242
private[spark] class Instrumentation[E <: Estimator[_]] private (
4343
estimator: E, dataset: RDD[_]) extends Logging {
4444

45-
private val id = Instrumentation.counter.incrementAndGet()
45+
private val id = UUID.randomUUID()
4646
private val prefix = {
4747
val className = estimator.getClass.getSimpleName
4848
s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
@@ -56,12 +56,31 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
5656
}
5757

5858
/**
59-
* Logs a message with a prefix that uniquely identifies the training session.
59+
* Logs a warning message with a prefix that uniquely identifies the training session.
6060
*/
61-
def log(msg: String): Unit = {
62-
logInfo(prefix + msg)
61+
override def logWarning(msg: => String): Unit = {
62+
super.logWarning(prefix + msg)
6363
}
6464

65+
/**
66+
* Logs an error message with a prefix that uniquely identifies the training session.
67+
*/
68+
override def logError(msg: => String): Unit = {
69+
super.logError(prefix + msg)
70+
}
71+
72+
/**
73+
* Logs an info message with a prefix that uniquely identifies the training session.
74+
*/
75+
override def logInfo(msg: => String): Unit = {
76+
super.logInfo(prefix + msg)
77+
}
78+
79+
/**
80+
* Alias for logInfo, see above.
81+
*/
82+
def log(msg: String): Unit = logInfo(msg)
83+
6584
/**
6685
* Logs the value of the given parameters for the estimator being used in this session.
6786
*/
@@ -77,11 +96,11 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
7796
}
7897

7998
def logNumFeatures(num: Long): Unit = {
80-
log(compact(render("numFeatures" -> num)))
99+
logNamedValue(Instrumentation.loggerTags.numFeatures, num)
81100
}
82101

83102
def logNumClasses(num: Long): Unit = {
84-
log(compact(render("numClasses" -> num)))
103+
logNamedValue(Instrumentation.loggerTags.numClasses, num)
85104
}
86105

87106
/**
@@ -107,7 +126,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
107126
* Some common methods for logging information about a training session.
108127
*/
109128
private[spark] object Instrumentation {
110-
private val counter = new AtomicLong(0)
129+
130+
object loggerTags {
131+
val numFeatures = "numFeatures"
132+
val numClasses = "numClasses"
133+
val numExamples = "numExamples"
134+
}
111135

112136
/**
113137
* Creates an instrumentation object for a training session.

0 commit comments

Comments
 (0)