@@ -32,15 +32,15 @@ import (
3232)
3333
3434func TestMnistPyTorchAppWrapperCpu (t * testing.T ) {
35- runMnistPyTorchAppWrapper (t , "cpu" , 0 )
35+ runMnistPyTorchAppWrapper (t , CPU )
3636}
3737
3838func TestMnistPyTorchAppWrapperGpu (t * testing.T ) {
39- runMnistPyTorchAppWrapper (t , "gpu" , 1 )
39+ runMnistPyTorchAppWrapper (t , NVIDIA )
4040}
4141
4242// Trains the MNIST dataset as a batch Job in an AppWrapper, and asserts successful completion of the training job.
43- func runMnistPyTorchAppWrapper (t * testing.T , accelerator string , numberOfGpus int ) {
43+ func runMnistPyTorchAppWrapper (t * testing.T , accelerator Accelerator ) {
4444 test := With (t )
4545
4646 // Create a namespace
@@ -51,7 +51,7 @@ func runMnistPyTorchAppWrapper(t *testing.T, accelerator string, numberOfGpus in
5151 defer func () {
5252 _ = test .Client ().Kueue ().KueueV1beta1 ().ResourceFlavors ().Delete (test .Ctx (), resourceFlavor .Name , metav1.DeleteOptions {})
5353 }()
54- clusterQueue := createClusterQueue (test , resourceFlavor , numberOfGpus )
54+ clusterQueue := createClusterQueue (test , resourceFlavor , accelerator )
5555 defer func () {
5656 _ = test .Client ().Kueue ().KueueV1beta1 ().ClusterQueues ().Delete (test .Ctx (), clusterQueue .Name , metav1.DeleteOptions {})
5757 }()
@@ -109,7 +109,7 @@ func runMnistPyTorchAppWrapper(t *testing.T, accelerator string, numberOfGpus in
109109 {Name : "MNIST_DATASET_URL" , Value : GetMnistDatasetURL ()},
110110 {Name : "PIP_INDEX_URL" , Value : GetPipIndexURL ()},
111111 {Name : "PIP_TRUSTED_HOST" , Value : GetPipTrustedHost ()},
112- {Name : "ACCELERATOR" , Value : accelerator },
112+ {Name : "ACCELERATOR" , Value : accelerator . Type },
113113 },
114114 Command : []string {"/bin/sh" , "-c" , "pip install -r /test/requirements.txt && torchrun /test/mnist.py" },
115115 VolumeMounts : []corev1.VolumeMount {
0 commit comments