@@ -21,9 +21,10 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.lang.Thread.UncaughtExceptionHandler
import java.nio.ByteBuffer
import java.util.Properties
- import java.util.concurrent.{CountDownLatch, TimeUnit}
+ import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

+ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.Map
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -33,22 +34,25 @@ import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.{inOrder, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
+ import org.scalatest.PrivateMethodTester
import org.scalatest.concurrent.Eventually
import org.scalatest.mockito.MockitoSugar

import org.apache.spark._
import org.apache.spark.TaskState.TaskState
- import org.apache.spark.memory.MemoryManager
+ import org.apache.spark.internal.config._
+ import org.apache.spark.memory.TestMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rdd.RDD
- import org.apache.spark.rpc.RpcEnv
- import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
+ import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcTimeout}
+ import org.apache.spark.scheduler.{FakeTask, ResultTask, Task, TaskDescription}
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.shuffle.FetchFailedException
- import org.apache.spark.storage.BlockManagerId
- import org.apache.spark.util.UninterruptibleThread
+ import org.apache.spark.storage.{BlockManager, BlockManagerId}
+ import org.apache.spark.util.{LongAccumulator, UninterruptibleThread}

- class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
+ class ExecutorSuite extends SparkFunSuite
+   with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {

  test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
    // mock some objects to make Executor.launchTask() happy
@@ -252,18 +256,112 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
    }
  }

+ test(" Heartbeat should drop zero accumulator updates" ) {
260
+ heartbeatZeroAccumulatorUpdateTest(true )
261
+ }
262
+
263
+ test(" Heartbeat should not drop zero accumulator updates when the conf is disabled" ) {
264
+ heartbeatZeroAccumulatorUpdateTest(false )
265
+ }
266
+
267
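+   // Test helper: runs `f` against an Executor whose private heartbeat receiver has been
+   // swapped for a mock that records every Heartbeat message it is asked to deliver.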
+   private def withHeartbeatExecutor(confs: (String, String)*)
+       (f: (Executor, ArrayBuffer[Heartbeat]) => Unit): Unit = {
+     val conf = new SparkConf
+     confs.foreach { case (k, v) => conf.set(k, v) }
+     val serializer = new JavaSerializer(conf)
+     val env = createMockEnv(conf, serializer)
+     val executor =
+       new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true)
+     val executorClass = classOf[Executor]
+
+     // Save all heartbeats sent into an ArrayBuffer for verification
+     val heartbeats = ArrayBuffer[Heartbeat]()
+     val mockReceiver = mock[RpcEndpointRef]
+     when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any))
+       .thenAnswer(new Answer[HeartbeatResponse] {
+         override def answer(invocation: InvocationOnMock): HeartbeatResponse = {
+           val args = invocation.getArguments()
+           heartbeats += args(0).asInstanceOf[Heartbeat]
+           HeartbeatResponse(false)
+         }
+       })
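+     // heartbeatReceiverRef is private to Executor, so install the mock via reflection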
+     val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef")
+     receiverRef.setAccessible(true)
+     receiverRef.set(executor, mockReceiver)
+
+     f(executor, heartbeats)
+   }
+
+   private def heartbeatZeroAccumulatorUpdateTest(dropZeroMetrics: Boolean): Unit = {
+     val c = EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES.key -> dropZeroMetrics.toString
+     withHeartbeatExecutor(c) { (executor, heartbeats) =>
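+       // reportHeartBeat is private, so invoke it through ScalaTest's PrivateMethodTester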
+       val reportHeartbeat = PrivateMethod[Unit]('reportHeartBeat)
+
+       // When no tasks are running, there should be no accumulators sent in heartbeat
+       executor.invokePrivate(reportHeartbeat())
+       assert(heartbeats.length == 1)
+       assert(heartbeats(0).accumUpdates.length == 0,
+         "No updates should be sent when no tasks are running")
+
+       // When we start a task with a nonzero accumulator, that should end up in the heartbeat
+       val metrics = new TaskMetrics()
+       val nonZeroAccumulator = new LongAccumulator()
+       nonZeroAccumulator.add(1)
+       metrics.registerAccumulator(nonZeroAccumulator)
+
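+       // The private runningTasks map is only reachable through its JVM-mangled field name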
+       val executorClass = classOf[Executor]
+       val tasksMap = {
+         val field =
+           executorClass.getDeclaredField("org$apache$spark$executor$Executor$$runningTasks")
+         field.setAccessible(true)
+         field.get(executor).asInstanceOf[ConcurrentHashMap[Long, executor.TaskRunner]]
+       }
+       val mockTaskRunner = mock[executor.TaskRunner]
+       val mockTask = mock[Task[Any]]
+       when(mockTask.metrics).thenReturn(metrics)
+       when(mockTaskRunner.taskId).thenReturn(6)
+       when(mockTaskRunner.task).thenReturn(mockTask)
+       when(mockTaskRunner.startGCTime).thenReturn(1)
+       tasksMap.put(6, mockTaskRunner)
+
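+       // The second heartbeat should now carry accumulator updates for the mocked task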
+       executor.invokePrivate(reportHeartbeat())
+       assert(heartbeats.length == 2)
+       val updates = heartbeats(1).accumUpdates
+       assert(updates.length == 1 && updates(0)._1 == 6,
+         "Heartbeat should only send update for the one task running")
+       val accumsSent = updates(0)._2.length
+       assert(accumsSent > 0, "The nonzero accumulator we added should be sent")
+       if (dropZeroMetrics) {
+         assert(accumsSent == metrics.accumulators().count(!_.isZero),
+           "The number of accumulators sent should match the number of nonzero accumulators")
+       } else {
+         assert(accumsSent == metrics.accumulators().length,
+           "The number of accumulators sent should match the number of total accumulators")
+       }
+     }
+   }
+
  private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
    val mockEnv = mock[SparkEnv]
    val mockRpcEnv = mock[RpcEnv]
    val mockMetricsSystem = mock[MetricsSystem]
-     val mockMemoryManager = mock[MemoryManager]
+     val mockBlockManager = mock[BlockManager]
    when(mockEnv.conf).thenReturn(conf)
    when(mockEnv.serializer).thenReturn(serializer)
    when(mockEnv.serializerManager).thenReturn(mock[SerializerManager])
    when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
    when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
-     when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
+     when(mockEnv.memoryManager).thenReturn(new TestMemoryManager(conf))
    when(mockEnv.closureSerializer).thenReturn(serializer)
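+     // Each heartbeat reports the executor's block manager ID, so the mocked env needs one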
+     when(mockBlockManager.blockManagerId).thenReturn(BlockManagerId("1", "hostA", 1234))
+     when(mockEnv.blockManager).thenReturn(mockBlockManager)
    SparkEnv.set(mockEnv)
    mockEnv
  }