Skip to content

Commit 3eda9d6

Browse files
committed
simulate executor registration and test executor lifecycle properly
Signed-off-by: Sudipto Baral <sudiptobaral.me@gmail.com>
1 parent 865ed08 commit 3eda9d6

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

src/test/scala/org/apache/spark/scheduler/cluster/armada/ArmadaDynamicAllocationIntegrationSuite.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class ArmadaDynamicAllocationIntegrationSuite extends AnyFunSuite with BeforeAnd
8888
sc = createMockSparkContext(conf)
8989
val taskScheduler = createMockTaskScheduler(sc)
9090

91-
backend = new ArmadaClusterManagerBackend(
91+
backend = new TestableArmadaClusterManagerBackend(
9292
taskScheduler,
9393
sc,
9494
java.util.concurrent.Executors.newScheduledThreadPool(1),
@@ -113,6 +113,13 @@ class ArmadaDynamicAllocationIntegrationSuite extends AnyFunSuite with BeforeAnd
113113
backend.getPendingExecutorCount === 1,
114114
"Executor should stay pending until registered with Spark"
115115
)
116+
117+
backend.asInstanceOf[TestableArmadaClusterManagerBackend].simulateExecutorRegistration(execId)
118+
119+
assert(
120+
backend.getPendingExecutorCount === 0,
121+
"Executor should be removed from pending after registering with Spark"
122+
)
116123
}
117124

118125
test("concurrent executor submissions are thread-safe") {
@@ -183,6 +190,26 @@ class ArmadaDynamicAllocationIntegrationSuite extends AnyFunSuite with BeforeAnd
183190

184191
// Helper methods
185192

193+
private class TestableArmadaClusterManagerBackend(
194+
scheduler: TaskSchedulerImpl,
195+
sc: SparkContext,
196+
executorService: java.util.concurrent.ScheduledExecutorService,
197+
masterURL: String
198+
) extends ArmadaClusterManagerBackend(scheduler, sc, executorService, masterURL) {
199+
200+
private val testRegisteredExecutors = scala.collection.mutable.Set.empty[String]
201+
202+
/** Simulate executor registration by adding it to the test set */
203+
def simulateExecutorRegistration(executorId: String): Unit = {
204+
testRegisteredExecutors += executorId
205+
}
206+
207+
/** Include our test executors */
208+
override def getExecutorIds(): Seq[String] = synchronized {
209+
super.getExecutorIds() ++ testRegisteredExecutors.toSeq
210+
}
211+
}
212+
186213
private def createMockSparkContext(sparkConf: SparkConf): SparkContext = {
187214
val sc = org.mockito.Mockito.mock(classOf[SparkContext])
188215
val env = org.mockito.Mockito.mock(classOf[SparkEnv])

0 commit comments

Comments
 (0)