
Commit cd6029d

viirya authored and Robert Kruszewski committed
[SPARK-21052][SQL] Add hash map metrics to join
## What changes were proposed in this pull request?

This adds the average hash map probe metric to join operators such as `BroadcastHashJoin` and `ShuffledHashJoin`. This PR adds an API to `HashedRelation` to get the average hash map probe.

## How was this patch tested?

Related test cases are added.

Author: Liang-Chi Hsieh <[email protected]>

Closes apache#18301 from viirya/SPARK-21052.
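The reported metric is simply total probes divided by total key lookups, accumulated inside each hash map and read out once per task. A minimal self-contained sketch of that bookkeeping (ProbeStats and its methods are illustrative stand-ins, not Spark classes):

// Sketch of the probe bookkeeping this patch adds to its hash maps.
// ProbeStats is a hypothetical stand-in, not a Spark class.
class ProbeStats {
  private var numKeyLookups = 0L
  private var numProbes = 0L

  // Called once per lookup with the number of slots that were inspected.
  def recordLookup(probes: Long): Unit = {
    numKeyLookups += 1
    numProbes += probes
  }

  // Mirrors HashedRelation.getAverageProbesPerLookup(): total probes / total lookups.
  def averageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups
}

An average close to 1.0 means lookups usually hit the right slot on the first probe; larger values indicate hash collisions and longer probe chains.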
1 parent 4120a17 commit cd6029d

File tree: 8 files changed, +296 −60 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 7 additions & 8 deletions
@@ -60,7 +60,7 @@ case class HashAggregateExec(
     "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"),
-    "avgHashmapProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hashmap probe"))
+    "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))

   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

@@ -94,7 +94,7 @@ case class HashAggregateExec(
     val numOutputRows = longMetric("numOutputRows")
     val peakMemory = longMetric("peakMemory")
     val spillSize = longMetric("spillSize")
-    val avgHashmapProbe = longMetric("avgHashmapProbe")
+    val avgHashProbe = longMetric("avgHashProbe")

     child.execute().mapPartitions { iter =>
@@ -119,7 +119,7 @@ case class HashAggregateExec(
         numOutputRows,
         peakMemory,
         spillSize,
-        avgHashmapProbe)
+        avgHashProbe)
       if (!hasInput && groupingExpressions.isEmpty) {
         numOutputRows += 1
         Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
@@ -344,7 +344,7 @@ case class HashAggregateExec(
       sorter: UnsafeKVExternalSorter,
       peakMemory: SQLMetric,
       spillSize: SQLMetric,
-      avgHashmapProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = {
+      avgHashProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = {

     // update peak execution memory
     val mapMemory = hashMap.getPeakMemoryUsedBytes
@@ -355,8 +355,7 @@ case class HashAggregateExec(
     metrics.incPeakExecutionMemory(maxMemory)

     // Update average hashmap probe
-    val avgProbes = hashMap.getAverageProbesPerLookup()
-    avgHashmapProbe.add(avgProbes.ceil.toLong)
+    avgHashProbe.set(hashMap.getAverageProbesPerLookup())

     if (sorter == null) {
       // not spilled
@@ -584,7 +583,7 @@ case class HashAggregateExec(
     val doAgg = ctx.freshName("doAggregateWithKeys")
     val peakMemory = metricTerm(ctx, "peakMemory")
     val spillSize = metricTerm(ctx, "spillSize")
-    val avgHashmapProbe = metricTerm(ctx, "avgHashmapProbe")
+    val avgHashProbe = metricTerm(ctx, "avgHashProbe")

     def generateGenerateCode(): String = {
       if (isFastHashMapEnabled) {
@@ -611,7 +610,7 @@ case class HashAggregateExec(
         s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""}

        $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize,
-          $avgHashmapProbe);
+          $avgHashProbe);
      }
    """)
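Note the change from avgHashmapProbe.add(avgProbes.ceil.toLong) to avgHashProbe.set(...): the old code added a rounded-up value to the metric, while the new code overwrites it with the task's final average and keeps the fractional part. A rough sketch of the distinction, using a hypothetical metric class rather than Spark's SQLMetric:

// AvgMetric is hypothetical, only to illustrate add vs. set semantics.
class AvgMetric {
  private var total = 0.0
  def add(v: Long): Unit = total += v   // old path: accumulates ceil'd averages across calls
  def set(v: Double): Unit = total = v  // new path: records the final average once, unrounded
  def current: Double = total
}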

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 17 additions & 17 deletions
@@ -89,7 +89,7 @@ class TungstenAggregationIterator(
     numOutputRows: SQLMetric,
     peakMemory: SQLMetric,
     spillSize: SQLMetric,
-    avgHashmapProbe: SQLMetric)
+    avgHashProbe: SQLMetric)
   extends AggregationIterator(
     groupingExpressions,
     originalInputAttributes,
@@ -367,6 +367,22 @@ class TungstenAggregationIterator(
     }
   }

+  TaskContext.get().addTaskCompletionListener(_ => {
+    // At the end of the task, update the task's peak memory usage. Since we destroy
+    // the map to create the sorter, their memory usages should not overlap, so it is safe
+    // to just use the max of the two.
+    val mapMemory = hashMap.getPeakMemoryUsedBytes
+    val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
+    val maxMemory = Math.max(mapMemory, sorterMemory)
+    val metrics = TaskContext.get().taskMetrics()
+    peakMemory.set(maxMemory)
+    spillSize.set(metrics.memoryBytesSpilled - spillSizeBefore)
+    metrics.incPeakExecutionMemory(maxMemory)
+
+    // Updating average hashmap probe
+    avgHashProbe.set(hashMap.getAverageProbesPerLookup())
+  })
+
   ///////////////////////////////////////////////////////////////////////////
   // Part 7: Iterator's public methods.
   ///////////////////////////////////////////////////////////////////////////
@@ -409,22 +425,6 @@ class TungstenAggregationIterator(
         }
       }

-      // If this is the last record, update the task's peak memory usage. Since we destroy
-      // the map to create the sorter, their memory usages should not overlap, so it is safe
-      // to just use the max of the two.
-      if (!hasNext) {
-        val mapMemory = hashMap.getPeakMemoryUsedBytes
-        val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
-        val maxMemory = Math.max(mapMemory, sorterMemory)
-        val metrics = TaskContext.get().taskMetrics()
-        peakMemory += maxMemory
-        spillSize += metrics.memoryBytesSpilled - spillSizeBefore
-        metrics.incPeakExecutionMemory(maxMemory)
-
-        // Update average hashmap probe if this is the last record.
-        val averageProbes = hashMap.getAverageProbesPerLookup()
-        avgHashmapProbe.add(averageProbes.ceil.toLong)
-      }
       numOutputRows += 1
       res
     } else {
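The key refactor here: the metric updates move out of the per-record next() path (previously guarded by if (!hasNext)) into a task completion listener registered once up front, so they also run when the iterator is not fully consumed. A self-contained sketch of the pattern, with a hypothetical stand-in for TaskContext:

// DemoTaskContext is a hypothetical stand-in for Spark's TaskContext,
// shown only to illustrate the completion-listener pattern.
class DemoTaskContext {
  private var listeners = List.empty[() => Unit]
  def addTaskCompletionListener(f: () => Unit): Unit = listeners ::= f
  def markTaskCompleted(): Unit = listeners.foreach(_.apply())
}

object ListenerDemo extends App {
  val ctx = new DemoTaskContext
  var lookups = 0L
  var probes = 0L
  // Register once; the per-record path no longer needs a hasNext check.
  ctx.addTaskCompletionListener { () =>
    println(s"avg hash probe: ${probes.toDouble / lookups}")
  }
  lookups += 3; probes += 4
  ctx.markTaskCompleted() // prints: avg hash probe: 1.3333333333333333
}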

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala

Lines changed: 28 additions & 2 deletions
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Dist
 import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.LongType
+import org.apache.spark.util.TaskCompletionListener

 /**
  * Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -46,7 +47,8 @@ case class BroadcastHashJoinExec(
   extends BinaryExecNode with HashJoin with CodegenSupport {

   override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))

   override def requiredChildDistribution: Seq[Distribution] = {
     val mode = HashedRelationBroadcastMode(buildKeys)
@@ -60,12 +62,13 @@ case class BroadcastHashJoinExec(

   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val avgHashProbe = longMetric("avgHashProbe")

     val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
     streamedPlan.execute().mapPartitions { streamedIter =>
       val hashed = broadcastRelation.value.asReadOnlyCopy()
       TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
-      join(streamedIter, hashed, numOutputRows)
+      join(streamedIter, hashed, numOutputRows, avgHashProbe)
     }
   }

@@ -90,6 +93,23 @@ case class BroadcastHashJoinExec(
     }
   }

+  /**
+   * Returns the codes used to add a task completion listener to update avg hash probe
+   * at the end of the task.
+   */
+  private def genTaskListener(avgHashProbe: String, relationTerm: String): String = {
+    val listenerClass = classOf[TaskCompletionListener].getName
+    val taskContextClass = classOf[TaskContext].getName
+    s"""
+       | $taskContextClass$$.MODULE$$.get().addTaskCompletionListener(new $listenerClass() {
+       |   @Override
+       |   public void onTaskCompletion($taskContextClass context) {
+       |     $avgHashProbe.set($relationTerm.getAverageProbesPerLookup());
+       |   }
+       | });
+     """.stripMargin
+  }
+
   /**
    * Returns a tuple of Broadcast of HashedRelation and the variable name for it.
    */
@@ -99,10 +119,16 @@ case class BroadcastHashJoinExec(
     val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
     val relationTerm = ctx.freshName("relation")
     val clsName = broadcastRelation.value.getClass.getName
+
+    // At the end of the task, we update the avg hash probe.
+    val avgHashProbe = metricTerm(ctx, "avgHashProbe")
+    val addTaskListener = genTaskListener(avgHashProbe, relationTerm)
+
     ctx.addMutableState(clsName, relationTerm,
       s"""
          | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
          | incPeakExecutionMemory($relationTerm.estimatedSize());
+         | $addTaskListener
       """.stripMargin)
     (broadcastRelation, relationTerm)
   }
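genTaskListener emits Java source, so it reaches the TaskContext companion object through its MODULE$ field and escapes $ as $$ inside the Scala interpolated string. A reduced sketch of the same templating idea (the parameters stand for names produced by the codegen context, not real Spark API names):

// Reduced sketch of the codegen templating; metricTerm and relationTerm are
// placeholders for identifiers generated elsewhere.
def genListenerCode(metricTerm: String, relationTerm: String): String =
  s"""
     |org.apache.spark.TaskContext$$.MODULE$$.get().addTaskCompletionListener(
     |  new org.apache.spark.util.TaskCompletionListener() {
     |    @Override
     |    public void onTaskCompletion(org.apache.spark.TaskContext context) {
     |      $metricTerm.set($relationTerm.getAverageProbesPerLookup());
     |    }
     |  });
   """.stripMargin

In the interpolated string, each $$ renders as a literal $ in the generated Java.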

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 7 additions & 1 deletion
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.execution.joins

+import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
@@ -193,7 +194,8 @@ trait HashJoin {
   protected def join(
       streamedIter: Iterator[InternalRow],
       hashed: HashedRelation,
-      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+      numOutputRows: SQLMetric,
+      avgHashProbe: SQLMetric): Iterator[InternalRow] = {

     val joinedIter = joinType match {
       case _: InnerLike =>
@@ -211,6 +213,10 @@ trait HashJoin {
         s"BroadcastHashJoin should not take $x as the JoinType")
     }

+    // At the end of the task, we update the avg hash probe.
+    TaskContext.get().addTaskCompletionListener(_ =>
+      avgHashProbe.set(hashed.getAverageProbesPerLookup()))
+
     val resultProj = createResultProjection
     joinedIter.map { r =>
       numOutputRows += 1

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 40 additions & 3 deletions
@@ -79,6 +79,11 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
    * Release any used resources.
    */
   def close(): Unit
+
+  /**
+   * Returns the average number of probes per key lookup.
+   */
+  def getAverageProbesPerLookup(): Double
 }

 private[execution] object HashedRelation {
@@ -242,7 +247,8 @@ private[joins] class UnsafeHashedRelation(
     binaryMap = new BytesToBytesMap(
       taskMemoryManager,
       (nKeys * 1.5 + 1).toInt, // reduce hash collision
-      pageSizeBytes)
+      pageSizeBytes,
+      true)

     var i = 0
     var keyBuffer = new Array[Byte](1024)
@@ -273,6 +279,8 @@ private[joins] class UnsafeHashedRelation(
   override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
     read(in.readInt, in.readLong, in.readBytes)
   }
+
+  override def getAverageProbesPerLookup(): Double = binaryMap.getAverageProbesPerLookup()
 }

 private[joins] object UnsafeHashedRelation {
@@ -290,7 +298,8 @@ private[joins] object UnsafeHashedRelation {
       taskMemoryManager,
       // Only 70% of the slots can be used before growing, more capacity help to reduce collision
       (sizeEstimate * 1.5 + 1).toInt,
-      pageSizeBytes)
+      pageSizeBytes,
+      true)

     // Create a mapping of buildKeys -> rows
     val keyGenerator = UnsafeProjection.create(key)
@@ -344,7 +353,7 @@ private[joins] object UnsafeHashedRelation {
  * determined by `key1 - minKey`.
  *
  * The map is created as sparse mode, then key-value could be appended into it. Once finish
- * appending, caller could all optimize() to try to turn the map into dense mode, which is faster
+ * appending, caller could call optimize() to try to turn the map into dense mode, which is faster
  * to probe.
  *
  * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/
@@ -385,6 +394,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   // The number of unique keys.
   private var numKeys = 0L

+  // Tracking average number of probes per key lookup.
+  private var numKeyLookups = 0L
+  private var numProbes = 0L
+
   // needed by serializer
   def this() = {
     this(
@@ -469,6 +482,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
     if (isDense) {
+      numKeyLookups += 1
+      numProbes += 1
       if (key >= minKey && key <= maxKey) {
         val value = array((key - minKey).toInt)
         if (value > 0) {
@@ -477,11 +492,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       }
     } else {
       var pos = firstSlot(key)
+      numKeyLookups += 1
+      numProbes += 1
       while (array(pos + 1) != 0) {
         if (array(pos) == key) {
           return getRow(array(pos + 1), resultRow)
         }
         pos = nextSlot(pos)
+        numProbes += 1
       }
     }
     null
@@ -509,6 +527,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
     if (isDense) {
+      numKeyLookups += 1
+      numProbes += 1
       if (key >= minKey && key <= maxKey) {
         val value = array((key - minKey).toInt)
         if (value > 0) {
@@ -517,11 +537,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       }
     } else {
       var pos = firstSlot(key)
+      numKeyLookups += 1
+      numProbes += 1
       while (array(pos + 1) != 0) {
         if (array(pos) == key) {
           return valueIter(array(pos + 1), resultRow)
         }
         pos = nextSlot(pos)
+        numProbes += 1
       }
     }
     null
@@ -573,8 +596,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   private def updateIndex(key: Long, address: Long): Unit = {
     var pos = firstSlot(key)
     assert(numKeys < array.length / 2)
+    numKeyLookups += 1
+    numProbes += 1
     while (array(pos) != key && array(pos + 1) != 0) {
       pos = nextSlot(pos)
+      numProbes += 1
     }
     if (array(pos + 1) == 0) {
       // this is the first value for this key, put the address in array.
@@ -686,6 +712,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     writeLong(maxKey)
     writeLong(numKeys)
     writeLong(numValues)
+    writeLong(numKeyLookups)
+    writeLong(numProbes)

     writeLong(array.length)
     writeLongArray(writeBuffer, array, array.length)
@@ -727,6 +755,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     maxKey = readLong()
     numKeys = readLong()
     numValues = readLong()
+    numKeyLookups = readLong()
+    numProbes = readLong()

     val length = readLong().toInt
     mask = length - 2
@@ -742,6 +772,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   override def read(kryo: Kryo, in: Input): Unit = {
     read(in.readBoolean, in.readLong, in.readBytes)
   }
+
+  /**
+   * Returns the average number of probes per key lookup.
+   */
+  def getAverageProbesPerLookup(): Double = numProbes.toDouble / numKeyLookups
 }

 private[joins] class LongHashedRelation(
@@ -793,6 +828,8 @@ private[joins] class LongHashedRelation(
     resultRow = new UnsafeRow(nFields)
     map = in.readObject().asInstanceOf[LongToUnsafeRowMap]
   }
+
+  override def getAverageProbesPerLookup(): Double = map.getAverageProbesPerLookup()
 }

 /**
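The counting convention in LongToUnsafeRowMap: every lookup costs at least one probe (dense mode is always exactly one), and each extra slot visited while chaining in sparse mode adds one more. A self-contained sketch of the same convention on a simplified open-addressing map (a stand-in, not LongToUnsafeRowMap itself):

// Simplified open-addressing map that counts probes the way this patch does.
// Hypothetical stand-in; assumes a power-of-two capacity that is never filled.
final class ProbeCountingMap(capacity: Int) {
  require((capacity & (capacity - 1)) == 0, "capacity must be a power of two")
  private val keys = new Array[Long](capacity)
  private val occupied = new Array[Boolean](capacity)
  private val values = new Array[Long](capacity)
  private var numKeyLookups = 0L
  private var numProbes = 0L

  private def firstSlot(key: Long): Int = key.## & (capacity - 1)
  private def nextSlot(pos: Int): Int = (pos + 1) & (capacity - 1)

  def put(key: Long, value: Long): Unit = {
    var pos = firstSlot(key)
    while (occupied(pos) && keys(pos) != key) pos = nextSlot(pos)
    keys(pos) = key; values(pos) = value; occupied(pos) = true
  }

  def get(key: Long): Option[Long] = {
    numKeyLookups += 1
    numProbes += 1                  // the first slot inspected counts as one probe
    var pos = firstSlot(key)
    while (occupied(pos)) {
      if (keys(pos) == key) return Some(values(pos))
      pos = nextSlot(pos)
      numProbes += 1                // each extra slot visited is another probe
    }
    None                            // miss: the empty slot ends the chain
  }

  // NaN before the first lookup, matching a straight probes / lookups division.
  def averageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups
}

With a low load factor the average stays near 1.0; as the table fills, probe chains lengthen and the average climbs, which is exactly the signal the new SQL metric surfaces.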
