@@ -88,7 +88,7 @@ class ArmadaDynamicAllocationIntegrationSuite extends AnyFunSuite with BeforeAnd
8888 sc = createMockSparkContext(conf)
8989 val taskScheduler = createMockTaskScheduler(sc)
9090
91- backend = new ArmadaClusterManagerBackend (
91+ backend = new TestableArmadaClusterManagerBackend (
9292 taskScheduler,
9393 sc,
9494 java.util.concurrent.Executors .newScheduledThreadPool(1 ),
@@ -113,6 +113,13 @@ class ArmadaDynamicAllocationIntegrationSuite extends AnyFunSuite with BeforeAnd
113113 backend.getPendingExecutorCount === 1 ,
114114 " Executor should stay pending until registered with Spark"
115115 )
116+
117+ backend.asInstanceOf [TestableArmadaClusterManagerBackend ].simulateExecutorRegistration(execId)
118+
119+ assert(
120+ backend.getPendingExecutorCount === 0 ,
121+ " Executor should be removed from pending after registering with Spark"
122+ )
116123 }
117124
118125 test(" concurrent executor submissions are thread-safe" ) {
@@ -183,6 +190,26 @@ class ArmadaDynamicAllocationIntegrationSuite extends AnyFunSuite with BeforeAnd
183190
184191 // Helper methods
185192
193+ private class TestableArmadaClusterManagerBackend (
194+ scheduler : TaskSchedulerImpl ,
195+ sc : SparkContext ,
196+ executorService : java.util.concurrent.ScheduledExecutorService ,
197+ masterURL : String
198+ ) extends ArmadaClusterManagerBackend (scheduler, sc, executorService, masterURL) {
199+
200+ private val testRegisteredExecutors = scala.collection.mutable.Set .empty[String ]
201+
202+ /** Simulate executor registration by adding it to the test set */
203+ def simulateExecutorRegistration (executorId : String ): Unit = {
204+ testRegisteredExecutors += executorId
205+ }
206+
207+ /** Include our test executors */
208+ override def getExecutorIds (): Seq [String ] = synchronized {
209+ super .getExecutorIds() ++ testRegisteredExecutors.toSeq
210+ }
211+ }
212+
186213 private def createMockSparkContext (sparkConf : SparkConf ): SparkContext = {
187214 val sc = org.mockito.Mockito .mock(classOf [SparkContext ])
188215 val env = org.mockito.Mockito .mock(classOf [SparkEnv ])
0 commit comments