Skip to content

Commit 355c63a

Browse files
committed
generalize tests to cover both driver and executor builders
1 parent 2534ed9 commit 355c63a

File tree

5 files changed

+188
-165
lines changed

5 files changed

+188
-165
lines changed

resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,3 @@ private[spark] case class KubernetesDriverSpec(
2222
pod: SparkPod,
2323
driverKubernetesResources: Seq[HasMetadata],
2424
systemProperties: Map[String, String])
25-
26-
private[spark] object KubernetesDriverSpec {
27-
def initialSpec(initialConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec =
28-
KubernetesDriverSpec(
29-
SparkPod.initialPod(),
30-
Seq.empty,
31-
initialConf.sparkConf.getAll.toMap)
32-
}

resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ private[spark] class BasicDriverFeatureStep(
119119
.addToLabels(conf.roleLabels.asJava)
120120
.addToAnnotations(conf.roleAnnotations.asJava)
121121
.endMetadata()
122-
.withNewSpec()
122+
.editOrNewSpec()
123123
.withRestartPolicy("Never")
124124
.addToNodeSelector(conf.nodeSelector().asJava)
125125
.addToImagePullSecrets(conf.imagePullSecrets(): _*)

resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,10 @@ import java.io.File
2020

2121
import com.google.common.base.Charsets
2222
import com.google.common.io.Files
23-
import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList}
23+
import io.fabric8.kubernetes.api.model.PodBuilder
2424
import io.fabric8.kubernetes.client.KubernetesClient
25-
import io.fabric8.kubernetes.client.dsl.{MixedOperation, PodResource}
2625
import org.mockito.Matchers._
2726
import org.mockito.Mockito
28-
import org.mockito.Mockito._
2927

3028
import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
3129
import org.apache.spark.deploy.k8s._
@@ -325,48 +323,42 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
325323
}
326324

327325
test("Starts with template if specified") {
328-
val spec = constructSpecWithPodTemplate(
326+
val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient()
327+
val sparkConf = new SparkConf(false)
328+
.set(CONTAINER_IMAGE, "spark-driver:latest")
329+
.set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml")
330+
val kubernetesConf = new KubernetesConf(
331+
sparkConf,
332+
KubernetesDriverSpecificConf(
333+
Some(JavaMainAppResource("example.jar")),
334+
"test-app",
335+
"main",
336+
Seq.empty),
337+
"prefix",
338+
"appId",
339+
Map.empty,
340+
Map.empty,
341+
Map.empty,
342+
Map.empty,
343+
Map.empty,
344+
Nil,
345+
Seq.empty[String])
346+
val driverSpec = KubernetesDriverBuilder
347+
.apply(kubernetesClient, sparkConf)
348+
.buildFromFeatures(kubernetesConf)
349+
PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(driverSpec.pod)
350+
}
351+
352+
test("Throws on misconfigured pod template") {
353+
val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient(
329354
new PodBuilder()
330355
.withNewMetadata()
331356
.addToLabels("test-label-key", "test-label-value")
332357
.endMetadata()
333-
.withNewSpec()
334-
.addNewContainer()
335-
.withName("test-driver-container")
336-
.endContainer()
337-
.endSpec()
338358
.build())
339-
340-
assert(spec.pod.pod.getMetadata.getLabels.containsKey("test-label-key"))
341-
assert(spec.pod.pod.getMetadata.getLabels.get("test-label-key") === "test-label-value")
342-
assert(spec.pod.container.getName === "test-driver-container")
343-
}
344-
345-
test("Throws on misconfigured pod template") {
346-
val exception = intercept[SparkException] {
347-
constructSpecWithPodTemplate(
348-
new PodBuilder()
349-
.withNewMetadata()
350-
.addToLabels("test-label-key", "test-label-value")
351-
.endMetadata()
352-
.build())
353-
}
354-
assert(exception.getMessage.contains("Could not load pod from template file."))
355-
}
356-
357-
private def constructSpecWithPodTemplate(pod: Pod) : KubernetesDriverSpec = {
358-
val kubernetesClient = mock(classOf[KubernetesClient])
359-
val pods =
360-
mock(classOf[MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]])
361-
val podResource = mock(classOf[PodResource[Pod, DoneablePod]])
362-
when(kubernetesClient.pods()).thenReturn(pods)
363-
when(pods.load(any(classOf[File]))).thenReturn(podResource)
364-
when(podResource.get()).thenReturn(pod)
365-
366359
val sparkConf = new SparkConf(false)
367360
.set(CONTAINER_IMAGE, "spark-driver:latest")
368361
.set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml")
369-
370362
val kubernetesConf = new KubernetesConf(
371363
sparkConf,
372364
KubernetesDriverSpecificConf(
@@ -384,7 +376,11 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
384376
Map.empty,
385377
Nil,
386378
Seq.empty[String])
387-
388-
KubernetesDriverBuilder.apply(kubernetesClient, sparkConf).buildFromFeatures(kubernetesConf)
379+
val exception = intercept[SparkException] {
380+
KubernetesDriverBuilder
381+
.apply(kubernetesClient, sparkConf)
382+
.buildFromFeatures(kubernetesConf)
383+
}
384+
assert(exception.getMessage.contains("Could not load pod from template file."))
389385
}
390386
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.deploy.k8s.submit
18+
19+
import java.io.File
20+
21+
import io.fabric8.kubernetes.api.model._
22+
import io.fabric8.kubernetes.client.KubernetesClient
23+
import io.fabric8.kubernetes.client.dsl.{MixedOperation, PodResource}
24+
import org.mockito.Matchers.any
25+
import org.mockito.Mockito.{mock, when}
26+
import org.scalatest.FlatSpec
27+
import scala.collection.JavaConverters._
28+
29+
import org.apache.spark.deploy.k8s.SparkPod
30+
31+
object PodBuilderSuiteUtils extends FlatSpec {
32+
33+
def loadingMockKubernetesClient(pod: Pod = podWithSupportedFeatures()): KubernetesClient = {
34+
val kubernetesClient = mock(classOf[KubernetesClient])
35+
val pods =
36+
mock(classOf[MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]])
37+
val podResource = mock(classOf[PodResource[Pod, DoneablePod]])
38+
when(kubernetesClient.pods()).thenReturn(pods)
39+
when(pods.load(any(classOf[File]))).thenReturn(podResource)
40+
when(podResource.get()).thenReturn(pod)
41+
kubernetesClient
42+
}
43+
44+
def verifyPodWithSupportedFeatures(pod: SparkPod): Unit = {
45+
val metadata = pod.pod.getMetadata
46+
assert(metadata.getLabels.containsKey("test-label-key"))
47+
assert(metadata.getAnnotations.containsKey("test-annotation-key"))
48+
assert(metadata.getNamespace === "namespace")
49+
assert(metadata.getOwnerReferences.asScala.exists(_.getName == "owner-reference"))
50+
val spec = pod.pod.getSpec
51+
assert(!spec.getContainers.asScala.exists(_.getName == "executor-container"))
52+
assert(spec.getDnsPolicy === "dns-policy")
53+
assert(spec.getHostAliases.asScala.exists(_.getHostnames.asScala.exists(_ == "hostname")))
54+
assert(spec.getImagePullSecrets.asScala.exists(_.getName == "local-reference"))
55+
assert(spec.getInitContainers.asScala.exists(_.getName == "init-container"))
56+
assert(spec.getNodeName == "node-name")
57+
assert(spec.getNodeSelector.get("node-selector-key") === "node-selector-value")
58+
assert(spec.getSchedulerName === "scheduler")
59+
assert(spec.getSecurityContext.getRunAsUser === 1000L)
60+
assert(spec.getServiceAccount === "service-account")
61+
assert(spec.getSubdomain === "subdomain")
62+
assert(spec.getTolerations.asScala.exists(_.getKey == "toleration-key"))
63+
assert(spec.getVolumes.asScala.exists(_.getName == "test-volume"))
64+
val container = pod.container
65+
assert(container.getName === "executor-container")
66+
assert(container.getArgs.contains("arg"))
67+
assert(container.getCommand.equals(List("command").asJava))
68+
assert(container.getEnv.asScala.exists(_.getName == "env-key"))
69+
assert(container.getResources.getLimits.get("gpu") ===
70+
new QuantityBuilder().withAmount("1").build())
71+
assert(container.getSecurityContext.getRunAsNonRoot)
72+
assert(container.getStdin)
73+
assert(container.getTerminationMessagePath === "termination-message-path")
74+
assert(container.getTerminationMessagePolicy === "termination-message-policy")
75+
assert(pod.container.getVolumeMounts.asScala.exists(_.getName == "test-volume"))
76+
77+
}
78+
79+
80+
def podWithSupportedFeatures(): Pod = new PodBuilder()
81+
.withNewMetadata()
82+
.addToLabels("test-label-key", "test-label-value")
83+
.addToAnnotations("test-annotation-key", "test-annotation-value")
84+
.withNamespace("namespace")
85+
.addNewOwnerReference()
86+
.withController(true)
87+
.withName("owner-reference")
88+
.endOwnerReference()
89+
.endMetadata()
90+
.withNewSpec()
91+
.withDnsPolicy("dns-policy")
92+
.withHostAliases(new HostAliasBuilder().withHostnames("hostname").build())
93+
.withImagePullSecrets(
94+
new LocalObjectReferenceBuilder().withName("local-reference").build())
95+
.withInitContainers(new ContainerBuilder().withName("init-container").build())
96+
.withNodeName("node-name")
97+
.withNodeSelector(Map("node-selector-key" -> "node-selector-value").asJava)
98+
.withSchedulerName("scheduler")
99+
.withNewSecurityContext()
100+
.withRunAsUser(1000L)
101+
.endSecurityContext()
102+
.withServiceAccount("service-account")
103+
.withSubdomain("subdomain")
104+
.withTolerations(new TolerationBuilder()
105+
.withKey("toleration-key")
106+
.withOperator("Equal")
107+
.withEffect("NoSchedule")
108+
.build())
109+
.addNewVolume()
110+
.withNewHostPath()
111+
.withPath("/test")
112+
.endHostPath()
113+
.withName("test-volume")
114+
.endVolume()
115+
.addNewContainer()
116+
.withArgs("arg")
117+
.withCommand("command")
118+
.addNewEnv()
119+
.withName("env-key")
120+
.withValue("env-value")
121+
.endEnv()
122+
.withImagePullPolicy("Always")
123+
.withName("executor-container")
124+
.withNewResources()
125+
.withLimits(Map("gpu" -> new QuantityBuilder().withAmount("1").build()).asJava)
126+
.endResources()
127+
.withNewSecurityContext()
128+
.withRunAsNonRoot(true)
129+
.endSecurityContext()
130+
.withStdin(true)
131+
.withTerminationMessagePath("termination-message-path")
132+
.withTerminationMessagePolicy("termination-message-policy")
133+
.addToVolumeMounts(
134+
new VolumeMountBuilder()
135+
.withName("test-volume")
136+
.withMountPath("/test")
137+
.build())
138+
.endContainer()
139+
.endSpec()
140+
.build()
141+
142+
}

0 commit comments

Comments
 (0)