ChangeFeedMicroBatchStream.scala

@@ -2,7 +2,9 @@
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.implementation.guava25.collect.{HashBiMap, Maps}
import com.azure.cosmos.implementation.{SparkBridgeImplementationInternal, UUIDs}
import com.azure.cosmos.changeFeedMetrics.{ChangeFeedMetricsListener, ChangeFeedMetricsTracker}
import com.azure.cosmos.spark.CosmosPredicates.{assertNotNull, assertNotNullOrEmpty, assertOnSparkDriver}
import com.azure.cosmos.spark.diagnostics.{DiagnosticsContext, LoggerHelper}
import org.apache.spark.broadcast.Broadcast
@@ -12,7 +14,12 @@ import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFacto
import org.apache.spark.sql.types.StructType

import java.time.Duration
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

// scala style rule flaky - even complaining on partial log messages
// scalastyle:off multiple.string.literals
@@ -57,6 +64,18 @@ private class ChangeFeedMicroBatchStream

private var latestOffsetSnapshot: Option[ChangeFeedOffset] = None

private val partitionIndex = new AtomicLong(0)
private val partitionIndexMap = Maps.synchronizedBiMap(HashBiMap.create[NormalizedRange, Long]())
private val partitionMetricsMap = new ConcurrentHashMap[NormalizedRange, ChangeFeedMetricsTracker]()

// Register metrics listener
if (CosmosConstants.ChangeFeedMetricsListenerConfig.metricsListenerEnabled) {
log.logInfo("Register ChangeFeedMetricsListener")
session.sparkContext.addSparkListener(new ChangeFeedMetricsListener(partitionIndexMap, partitionMetricsMap))
} else {
log.logInfo("ChangeFeedMetricsListener is disabled")
}

override def latestOffset(): Offset = {
// For Spark data streams implementing SupportsAdmissionControl trait
// latestOffset(Offset, ReadLimit) is called instead
@@ -99,11 +118,15 @@ private class ChangeFeedMicroBatchStream
end
.inputPartitions
.get
.map(partition => partition
.withContinuationState(
SparkBridgeImplementationInternal
.extractChangeFeedStateForRange(start.changeFeedState, partition.feedRange),
clearEndLsn = false))
.map(partition => {
val index = partitionIndexMap.asScala.getOrElseUpdate(partition.feedRange, partitionIndex.incrementAndGet())
partition
.withContinuationState(
SparkBridgeImplementationInternal
.extractChangeFeedStateForRange(start.changeFeedState, partition.feedRange),
clearEndLsn = false)
.withIndex(index)
})
}

/**
@@ -150,7 +173,8 @@ private class ChangeFeedMicroBatchStream
this.containerConfig,
this.partitioningConfig,
this.defaultParallelism,
this.container
this.container,
Some(this.partitionMetricsMap)
)

if (offset.changeFeedState != startChangeFeedOffset.changeFeedState) {
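The bookkeeping added above gives every observed NormalizedRange a stable numeric index: an AtomicLong hands out indexes, a synchronized Guava BiMap records the assignment, and planInputPartitions stamps each partition with its index via withIndex. The BiMap matters because lookups run in both directions, range to index during planning, and index back to range when the ChangeFeedMetricsListener decodes metrics emitted by tasks. A minimal sketch of the pattern, using plain Guava instead of the connector's shaded guava25 package and String standing in for NormalizedRange:

```scala
import com.google.common.collect.{BiMap, HashBiMap, Maps}
import java.util.concurrent.atomic.AtomicLong
import scala.collection.JavaConverters._

object PartitionIndexBookkeeping {
  private val nextIndex = new AtomicLong(0)
  // Bidirectional map: range -> index for planning, index -> range for the listener.
  private val indexMap: BiMap[String, Long] =
    Maps.synchronizedBiMap(HashBiMap.create[String, Long]())

  // Assign a stable index per feed range, reusing an existing assignment.
  def indexOf(range: String): Long =
    indexMap.asScala.getOrElseUpdate(range, nextIndex.incrementAndGet())

  // Inverse lookup, e.g. from a metrics listener decoding task metrics.
  def rangeOf(index: Long): Option[String] =
    Option(indexMap.inverse().get(index))
}
```

As in the diff, getOrElseUpdate on the wrapped view is check-then-act rather than atomic, which holds up only because index assignment happens during driver-side planning.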
SparkInternalsBridge.scala

@@ -6,6 +6,7 @@ import com.azure.cosmos.implementation.guava25.base.MoreObjects.firstNonNull
import com.azure.cosmos.implementation.guava25.base.Strings.emptyToNull
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import org.apache.spark.TaskContext
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.util.AccumulatorV2

@@ -41,20 +42,23 @@ object SparkInternalsBridge extends BasicLoggingTrait {
private final lazy val reflectionAccessAllowed = new AtomicBoolean(getSparkReflectionAccessAllowed)

def getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames: Set[String]): Map[String, SQLMetric] = {
Option.apply(TaskContext.get()) match {
case Some(taskCtx) => getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames, taskCtx.taskMetrics())
case None => Map.empty[String, SQLMetric]
}
}

def getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames: Set[String], taskMetrics: TaskMetrics): Map[String, SQLMetric] = {

if (!reflectionAccessAllowed.get) {
Map.empty[String, SQLMetric]
} else {
Option.apply(TaskContext.get()) match {
case Some(taskCtx) => getInternalCustomTaskMetricsAsSQLMetricInternal(knownCosmosMetricNames, taskCtx)
case None => Map.empty[String, SQLMetric]
}
getInternalCustomTaskMetricsAsSQLMetricInternal(knownCosmosMetricNames, taskMetrics)
}
}

private def getAccumulators(taskCtx: TaskContext): Option[ArrayBuffer[AccumulatorV2[_, _]]] = {
private def getAccumulators(taskMetrics: TaskMetrics): Option[ArrayBuffer[AccumulatorV2[_, _]]] = {
try {
val taskMetrics: Object = taskCtx.taskMetrics()
val method = Option(accumulatorsMethod.get) match {
case Some(existing) => existing
case None =>
@@ -79,8 +83,8 @@ object SparkInternalsBridge extends BasicLoggingTrait {

private def getInternalCustomTaskMetricsAsSQLMetricInternal(
knownCosmosMetricNames: Set[String],
taskCtx: TaskContext): Map[String, SQLMetric] = {
getAccumulators(taskCtx) match {
taskMetrics: TaskMetrics): Map[String, SQLMetric] = {
getAccumulators(taskMetrics) match {
case Some(accumulators) => accumulators
.filter(accumulable => accumulable.isInstanceOf[SQLMetric]
&& accumulable.name.isDefined
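The change to SparkInternalsBridge swaps TaskContext for TaskMetrics throughout: the reflective accumulator lookup only ever needed taskCtx.taskMetrics(), so accepting TaskMetrics directly lets callers that never see a TaskContext, in particular a driver-side SparkListener handling completed tasks, reuse the same code path, while the parameterless overload keeps the old behavior for code running inside a task. A hedged sketch of such a caller; the listener class and metric names below are illustrative, not the connector's actual ones:

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Sketch only: assumes the SparkInternalsBridge object from the diff above
// is visible (it lives in package com.azure.cosmos.spark).
class MetricsHarvestingListener extends SparkListener {
  // Hypothetical metric names; the connector's real names are not shown in this diff.
  private val knownNames = Set("cosmos.changeFeed.changesFetched", "cosmos.changeFeed.lsnRange")

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    // TaskContext.get() returns null on the driver, so the TaskMetrics
    // overload is the only way to read the finished task's accumulators.
    // taskMetrics can be null for failed tasks, hence the Option guard.
    Option(taskEnd.taskMetrics).foreach { tm =>
      SparkInternalsBridge
        .getInternalCustomTaskMetricsAsSQLMetric(knownNames, tm)
        .foreach { case (name, metric) => println(s"$name -> ${metric.value}") }
    }
  }
}
```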
ChangeFeedMetricsTest.scala

@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.changeFeedMetrics.ChangeFeedMetricsTracker

class ChangeFeedMetricsTest extends UnitSpec {

"ChangeFeedMetricsTracker" should "track weighted changes per lsn" in {
val testRange = NormalizedRange("0", "FF")
val metricsTracker = new ChangeFeedMetricsTracker(1, testRange)
metricsTracker.track(10000, 0)
metricsTracker.getWeightedAvgChangesPerLsn shouldBe Some(10000.0)
}

"ChangeFeedMetricsTracker" should "return none when no metrics tracked" in {
val testRange = NormalizedRange("0", "FF")
val metricsTracker = new ChangeFeedMetricsTracker(1, testRange)

metricsTracker.getWeightedAvgChangesPerLsn shouldBe None
}

"ChangeFeedMetricsTracker" should "track limited metrics history" in {
val testRange = NormalizedRange("0", "FF")
val metricsTracker = new ChangeFeedMetricsTracker(1, testRange)

metricsTracker.track(10000, 1)
for (_ <- 1 to 5) {
metricsTracker.track(1, 2000)
}

metricsTracker.getWeightedAvgChangesPerLsn shouldBe Some(1.toDouble / 2000)
}
}
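Taken together, the three tests pin down the tracker's contract: it turns (changes, LSN span) observations into changes-per-LSN samples, keeps only a bounded history, and returns None until the first sample arrives. A standalone sketch consistent with those assertions; the constructor, the track signature, the zero-gap handling, the window size of five, and the linear recency weighting are all assumptions, not the shipped implementation:

```scala
import scala.collection.mutable

// Standalone sketch, not the shipped ChangeFeedMetricsTracker: the
// signature and weighting scheme are assumptions chosen to satisfy
// the three tests above.
class WeightedChangesPerLsnTracker(maxHistory: Int = 5) {
  private val samples = mutable.Queue.empty[Double]

  def track(changesCount: Long, lsnGap: Long): Unit = {
    val gap = math.max(1L, lsnGap) // assume a zero gap counts as one LSN
    samples.enqueue(changesCount.toDouble / gap)
    if (samples.size > maxHistory) samples.dequeue() // bounded history
  }

  // None until the first sample; otherwise a linearly recency-weighted mean.
  def getWeightedAvgChangesPerLsn: Option[Double] =
    if (samples.isEmpty) None
    else {
      val weights = (1 to samples.size).map(_.toDouble)
      Some(samples.zip(weights).map { case (s, w) => s * w }.sum / weights.sum)
    }
}
```

Under these assumptions the first test sees a single sample of 10000, and the sixth call in the third test evicts the initial 10000 sample, leaving five identical samples whose weighted mean is exactly 1/2000.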