@@ -6,6 +6,7 @@ import com.azure.cosmos.implementation.guava25.base.MoreObjects.firstNonNull
6
6
import com .azure .cosmos .implementation .guava25 .base .Strings .emptyToNull
7
7
import com .azure .cosmos .spark .diagnostics .BasicLoggingTrait
8
8
import org .apache .spark .TaskContext
9
+ import org .apache .spark .executor .TaskMetrics
9
10
import org .apache .spark .sql .execution .metric .SQLMetric
10
11
import org .apache .spark .util .AccumulatorV2
11
12
@@ -40,20 +41,23 @@ object SparkInternalsBridge extends BasicLoggingTrait {
40
41
private final lazy val reflectionAccessAllowed = new AtomicBoolean (getSparkReflectionAccessAllowed)
41
42
42
43
def getInternalCustomTaskMetricsAsSQLMetric (knownCosmosMetricNames : Set [String ]) : Map [String , SQLMetric ] = {
44
+ Option .apply(TaskContext .get()) match {
45
+ case Some (taskCtx) => getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames, taskCtx.taskMetrics())
46
+ case None => Map .empty[String , SQLMetric ]
47
+ }
48
+ }
49
+
50
+ def getInternalCustomTaskMetricsAsSQLMetric (knownCosmosMetricNames : Set [String ], taskMetrics : TaskMetrics ) : Map [String , SQLMetric ] = {
43
51
44
52
if (! reflectionAccessAllowed.get) {
45
53
Map .empty[String , SQLMetric ]
46
54
} else {
47
- Option .apply(TaskContext .get()) match {
48
- case Some (taskCtx) => getInternalCustomTaskMetricsAsSQLMetricInternal(knownCosmosMetricNames, taskCtx)
49
- case None => Map .empty[String , SQLMetric ]
50
- }
55
+ getInternalCustomTaskMetricsAsSQLMetricInternal(knownCosmosMetricNames, taskMetrics)
51
56
}
52
57
}
53
58
54
- private def getAccumulators (taskCtx : TaskContext ): Option [Seq [AccumulatorV2 [_, _]]] = {
59
+ private def getAccumulators (taskMetrics : TaskMetrics ): Option [Seq [AccumulatorV2 [_, _]]] = {
55
60
try {
56
- val taskMetrics : Object = taskCtx.taskMetrics()
57
61
val method = Option (accumulatorsMethod.get) match {
58
62
case Some (existing) => existing
59
63
case None =>
@@ -78,8 +82,8 @@ object SparkInternalsBridge extends BasicLoggingTrait {
78
82
79
83
private def getInternalCustomTaskMetricsAsSQLMetricInternal (
80
84
knownCosmosMetricNames : Set [String ],
81
- taskCtx : TaskContext ): Map [String , SQLMetric ] = {
82
- getAccumulators(taskCtx ) match {
85
+ taskMetrics : TaskMetrics ): Map [String , SQLMetric ] = {
86
+ getAccumulators(taskMetrics ) match {
83
87
case Some (accumulators) => accumulators
84
88
.filter(accumulable => accumulable.isInstanceOf [SQLMetric ]
85
89
&& accumulable.name.isDefined
0 commit comments