@@ -20,18 +20,24 @@ import scala.collection.mutable.ArrayBuffer
20
20
21
21
import io .fabric8 .kubernetes .api .model .{Pod , PodSpec , PodStatus }
22
22
import org .mockito .Mockito ._
23
+ import org .scalatest .BeforeAndAfter
23
24
24
25
import org .apache .spark .{SparkContext , SparkFunSuite }
26
+ import org .apache .spark .deploy .kubernetes .config ._
25
27
import org .apache .spark .scheduler .{FakeTask , FakeTaskScheduler , HostTaskLocation , TaskLocation }
26
28
27
- class KubernetesTaskSetManagerSuite extends SparkFunSuite {
29
+ class KubernetesTaskSetManagerSuite extends SparkFunSuite with BeforeAndAfter {
28
30
29
31
val sc = new SparkContext (" local" , " test" )
30
32
val sched = new FakeTaskScheduler (sc,
31
33
(" execA" , " 10.0.0.1" ), (" execB" , " 10.0.0.2" ), (" execC" , " 10.0.0.3" ))
32
34
val backend = mock(classOf [KubernetesClusterSchedulerBackend ])
33
35
sched.backend = backend
34
36
37
+ before {
38
+ sc.conf.remove(KUBERNETES_DRIVER_CLUSTER_NODENAME_DNS_LOOKUP_ENABLED )
39
+ }
40
+
35
41
test(" Find pending tasks for executors using executor pod IP addresses" ) {
36
42
val taskSet = FakeTask .createTaskSet(3 ,
37
43
Seq (TaskLocation (" 10.0.0.1" , " execA" )), // Task 0 runs on executor pod 10.0.0.1.
@@ -76,7 +82,33 @@ class KubernetesTaskSetManagerSuite extends SparkFunSuite {
76
82
assert(manager.getPendingTasksForHost(" 10.0.0.1" ) == ArrayBuffer (1 , 0 ))
77
83
}
78
84
85
+ test(" Test DNS lookup is disabled by default for cluster node full hostnames" ) {
86
+ assert(! sc.conf.get(KUBERNETES_DRIVER_CLUSTER_NODENAME_DNS_LOOKUP_ENABLED ))
87
+ }
88
+
89
+ test(" Find pending tasks for executors, but avoid looking up cluster node FQDNs from DNS" ) {
90
+ sc.conf.set(KUBERNETES_DRIVER_CLUSTER_NODENAME_DNS_LOOKUP_ENABLED , false )
91
+ val taskSet = FakeTask .createTaskSet(2 ,
92
+ Seq (HostTaskLocation (" kube-node1.domain1" )), // Task 0's partition belongs to datanode here.
93
+ Seq (HostTaskLocation (" kube-node1.domain1" )) // task 1's partition belongs to datanode here.
94
+ )
95
+ val spec1 = mock(classOf [PodSpec ])
96
+ when(spec1.getNodeName).thenReturn(" kube-node1" )
97
+ val pod1 = mock(classOf [Pod ])
98
+ when(pod1.getSpec).thenReturn(spec1)
99
+ val status1 = mock(classOf [PodStatus ])
100
+ when(status1.getHostIP).thenReturn(" 196.0.0.5" )
101
+ when(pod1.getStatus).thenReturn(status1)
102
+ val inetAddressUtil = mock(classOf [InetAddressUtil ])
103
+ when(inetAddressUtil.getFullHostName(" 196.0.0.5" )).thenReturn(" kube-node1.domain1" )
104
+ when(backend.getExecutorPodByIP(" 10.0.0.1" )).thenReturn(Some (pod1))
105
+
106
+ val manager = new KubernetesTaskSetManager (sched, taskSet, maxTaskFailures = 2 , inetAddressUtil)
107
+ assert(manager.getPendingTasksForHost(" 10.0.0.1" ) == ArrayBuffer ())
108
+ }
109
+
79
110
test(" Find pending tasks for executors using cluster node FQDNs that executor pods run on" ) {
111
+ sc.conf.set(KUBERNETES_DRIVER_CLUSTER_NODENAME_DNS_LOOKUP_ENABLED , true )
80
112
val taskSet = FakeTask .createTaskSet(2 ,
81
113
Seq (HostTaskLocation (" kube-node1.domain1" )), // Task 0's partition belongs to datanode here.
82
114
Seq (HostTaskLocation (" kube-node1.domain1" )) // task 1's partition belongs to datanode here.
0 commit comments