@@ -21,6 +21,7 @@ import (
2121 "fmt"
2222 "net/http"
2323 "net/url"
24+ "strings"
2425 "testing"
2526
2627 . "github.com/onsi/gomega"
@@ -40,14 +41,18 @@ import (
4041// directly managed by Kueue, and asserts successful completion of the training job.
4142
4243func TestMnistRayJobRayClusterCpu (t * testing.T ) {
43- runMnistRayJobRayCluster (t , "cpu" , 0 )
44+ runMnistRayJobRayCluster (t , "cpu" , 0 , "nvidia.com/gpu" , GetRayImage () )
4445}
4546
46- func TestMnistRayJobRayClusterGpu (t * testing.T ) {
47- runMnistRayJobRayCluster (t , "gpu" , 1 )
47+ func TestMnistRayJobRayClusterCudaGpu (t * testing.T ) {
48+ runMnistRayJobRayCluster (t , "gpu" , 1 , "nvidia.com/gpu" , GetRayImage () )
4849}
4950
50- func runMnistRayJobRayCluster (t * testing.T , accelerator string , numberOfGpus int ) {
51+ func TestMnistRayJobRayClusterROCmGpu (t * testing.T ) {
52+ runMnistRayJobRayCluster (t , "gpu" , 1 , "amd.com/gpu" , GetRayROCmImage ())
53+ }
54+
55+ func runMnistRayJobRayCluster (t * testing.T , accelerator string , numberOfGpus int , gpuResourceName string , rayImage string ) {
5156 test := With (t )
5257
5358 // Create a static namespace to ensure a consistent Ray Dashboard hostname entry in /etc/hosts before executing the test.
@@ -58,11 +63,11 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
5863 defer func () {
5964 _ = test .Client ().Kueue ().KueueV1beta1 ().ResourceFlavors ().Delete (test .Ctx (), resourceFlavor .Name , metav1.DeleteOptions {})
6065 }()
61- clusterQueue := createClusterQueue (test , resourceFlavor , numberOfGpus )
66+ clusterQueue := createClusterQueue (test , resourceFlavor , numberOfGpus , gpuResourceName )
6267 defer func () {
6368 _ = test .Client ().Kueue ().KueueV1beta1 ().ClusterQueues ().Delete (test .Ctx (), clusterQueue .Name , metav1.DeleteOptions {})
6469 }()
65- CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name , AsDefaultQueue )
70+ localQueue := CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name , AsDefaultQueue )
6671
6772 // Create MNIST training script
6873 mnist := constructMNISTConfigMap (test , namespace )
@@ -71,7 +76,7 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
7176 test .T ().Logf ("Created ConfigMap %s/%s successfully" , mnist .Namespace , mnist .Name )
7277
7378 // Create RayCluster and assign it to the localqueue
74- rayCluster := constructRayCluster (test , namespace , mnist , numberOfGpus )
79+ rayCluster := constructRayCluster (test , namespace , localQueue . Name , mnist , numberOfGpus , gpuResourceName , rayImage )
7580 rayCluster , err = test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Create (test .Ctx (), rayCluster , metav1.CreateOptions {})
7681 test .Expect (err ).NotTo (HaveOccurred ())
7782 test .T ().Logf ("Created RayCluster %s/%s successfully" , rayCluster .Namespace , rayCluster .Name )
@@ -81,7 +86,7 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
8186 Should (WithTransform (RayClusterState , Equal (rayv1 .Ready )))
8287
8388 // Create RayJob
84- rayJob := constructRayJob (test , namespace , rayCluster , accelerator , numberOfGpus )
89+ rayJob := constructRayJob (test , namespace , rayCluster , accelerator , numberOfGpus , gpuResourceName , rayImage )
8590 rayJob , err = test .Client ().Ray ().RayV1 ().RayJobs (namespace .Name ).Create (test .Ctx (), rayJob , metav1.CreateOptions {})
8691 test .Expect (err ).NotTo (HaveOccurred ())
8792 test .T ().Logf ("Created RayJob %s/%s successfully" , rayJob .Namespace , rayJob .Name )
@@ -110,15 +115,19 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
110115}
111116
112117func TestMnistRayJobRayClusterAppWrapperCpu (t * testing.T ) {
113- runMnistRayJobRayClusterAppWrapper (t , "cpu" , 0 )
118+ runMnistRayJobRayClusterAppWrapper (t , "cpu" , 0 , "nvidia.com/gpu" , GetRayImage () )
114119}
115120
116- func TestMnistRayJobRayClusterAppWrapperGpu (t * testing.T ) {
117- runMnistRayJobRayClusterAppWrapper (t , "gpu" , 1 )
121+ func TestMnistRayJobRayClusterAppWrapperCudaGpu (t * testing.T ) {
122+ runMnistRayJobRayClusterAppWrapper (t , "gpu" , 1 , "nvidia.com/gpu" , GetRayImage ())
123+ }
124+
125+ func TestMnistRayJobRayClusterAppWrapperROCmGpu (t * testing.T ) {
126+ runMnistRayJobRayClusterAppWrapper (t , "gpu" , 1 , "amd.com/gpu" , GetRayROCmImage ())
118127}
119128
120129// Same as TestMNISTRayJobRayCluster, except the RayCluster is wrapped in an AppWrapper
121- func runMnistRayJobRayClusterAppWrapper (t * testing.T , accelerator string , numberOfGpus int ) {
130+ func runMnistRayJobRayClusterAppWrapper (t * testing.T , accelerator string , numberOfGpus int , gpuResourceName string , rayImage string ) {
122131 test := With (t )
123132
124133 // Create a static namespace to ensure a consistent Ray Dashboard hostname entry in /etc/hosts before executing the test.
@@ -129,7 +138,7 @@ func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, number
129138 defer func () {
130139 _ = test .Client ().Kueue ().KueueV1beta1 ().ResourceFlavors ().Delete (test .Ctx (), resourceFlavor .Name , metav1.DeleteOptions {})
131140 }()
132- clusterQueue := createClusterQueue (test , resourceFlavor , numberOfGpus )
141+ clusterQueue := createClusterQueue (test , resourceFlavor , numberOfGpus , gpuResourceName )
133142 defer func () {
134143 _ = test .Client ().Kueue ().KueueV1beta1 ().ClusterQueues ().Delete (test .Ctx (), clusterQueue .Name , metav1.DeleteOptions {})
135144 }()
@@ -142,7 +151,7 @@ func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, number
142151 test .T ().Logf ("Created ConfigMap %s/%s successfully" , mnist .Namespace , mnist .Name )
143152
144153 // Create RayCluster, wrap in AppWrapper and assign to localqueue
145- rayCluster := constructRayCluster (test , namespace , mnist , numberOfGpus )
154+ rayCluster := constructRayCluster (test , namespace , localQueue . Name , mnist , numberOfGpus , gpuResourceName , rayImage )
146155 raw := Raw (test , rayCluster )
147156 raw = RemoveCreationTimestamp (test , raw )
148157
@@ -183,7 +192,7 @@ func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, number
183192 Should (WithTransform (RayClusterState , Equal (rayv1 .Ready )))
184193
185194 // Create RayJob
186- rayJob := constructRayJob (test , namespace , rayCluster , accelerator , numberOfGpus )
195+ rayJob := constructRayJob (test , namespace , rayCluster , accelerator , numberOfGpus , gpuResourceName , rayImage )
187196 rayJob , err = test .Client ().Ray ().RayV1 ().RayJobs (namespace .Name ).Create (test .Ctx (), rayJob , metav1.CreateOptions {})
188197 test .Expect (err ).NotTo (HaveOccurred ())
189198 test .T ().Logf ("Created RayJob %s/%s successfully" , rayJob .Namespace , rayJob .Name )
@@ -223,11 +232,11 @@ func TestRayClusterImagePullSecret(t *testing.T) {
223232 defer func () {
224233 _ = test .Client ().Kueue ().KueueV1beta1 ().ResourceFlavors ().Delete (test .Ctx (), resourceFlavor .Name , metav1.DeleteOptions {})
225234 }()
226- clusterQueue := createClusterQueue (test , resourceFlavor , 0 )
235+ clusterQueue := createClusterQueue (test , resourceFlavor , 0 , "nvidia.com/gpu" )
227236 defer func () {
228237 _ = test .Client ().Kueue ().KueueV1beta1 ().ClusterQueues ().Delete (test .Ctx (), clusterQueue .Name , metav1.DeleteOptions {})
229238 }()
230- CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name , AsDefaultQueue )
239+ localQueue := CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name , AsDefaultQueue )
231240
232241 // Create MNIST training script
233242 mnist := constructMNISTConfigMap (test , namespace )
@@ -236,7 +245,7 @@ func TestRayClusterImagePullSecret(t *testing.T) {
236245 test .T ().Logf ("Created ConfigMap %s/%s successfully" , mnist .Namespace , mnist .Name )
237246
238247 // Create RayCluster with imagePullSecret and assign it to the localqueue
239- rayCluster := constructRayCluster (test , namespace , mnist , 0 )
248+ rayCluster := constructRayCluster (test , namespace , localQueue . Name , mnist , 0 , "nvidia.com/gpu" , GetRayImage () )
240249 rayCluster .Spec .HeadGroupSpec .Template .Spec .ImagePullSecrets = []corev1.LocalObjectReference {{Name : "custom-pull-secret" }}
241250 rayCluster , err = test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Create (test .Ctx (), rayCluster , metav1.CreateOptions {})
242251 test .Expect (err ).NotTo (HaveOccurred ())
@@ -266,7 +275,7 @@ func constructMNISTConfigMap(test Test, namespace *corev1.Namespace) *corev1.Con
266275 }
267276}
268277
269- func constructRayCluster (_ Test , namespace * corev1.Namespace , mnist * corev1.ConfigMap , numberOfGpus int ) * rayv1.RayCluster {
278+ func constructRayCluster (_ Test , namespace * corev1.Namespace , localQueueName string , mnist * corev1.ConfigMap , numberOfGpus int , gpuResourceName string , rayImage string ) * rayv1.RayCluster {
270279 return & rayv1.RayCluster {
271280 TypeMeta : metav1.TypeMeta {
272281 APIVersion : rayv1 .GroupVersion .String (),
@@ -275,6 +284,9 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
275284 ObjectMeta : metav1.ObjectMeta {
276285 Name : "raycluster" ,
277286 Namespace : namespace .Name ,
287+ Labels : map [string ]string {
288+ "kueue.x-k8s.io/queue-name" : localQueueName ,
289+ },
278290 },
279291 Spec : rayv1.RayClusterSpec {
280292 RayVersion : GetRayVersion (),
@@ -287,7 +299,7 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
287299 Containers : []corev1.Container {
288300 {
289301 Name : "ray-head" ,
290- Image : GetRayImage () ,
302+ Image : rayImage ,
291303 Ports : []corev1.ContainerPort {
292304 {
293305 ContainerPort : 6379 ,
@@ -335,14 +347,14 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
335347 Spec : corev1.PodSpec {
336348 Tolerations : []corev1.Toleration {
337349 {
338- Key : "nvidia.com/gpu" ,
350+ Key : gpuResourceName ,
339351 Operator : corev1 .TolerationOpExists ,
340352 },
341353 },
342354 Containers : []corev1.Container {
343355 {
344356 Name : "ray-worker" ,
345- Image : GetRayImage () ,
357+ Image : rayImage ,
346358 Lifecycle : & corev1.Lifecycle {
347359 PreStop : & corev1.LifecycleHandler {
348360 Exec : & corev1.ExecAction {
@@ -352,14 +364,14 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
352364 },
353365 Resources : corev1.ResourceRequirements {
354366 Requests : corev1.ResourceList {
355- corev1 .ResourceCPU : resource .MustParse ("250m" ),
356- corev1 .ResourceMemory : resource .MustParse ("1G" ),
357- "nvidia.com/gpu" : resource .MustParse (fmt .Sprint (numberOfGpus )),
367+ corev1 .ResourceCPU : resource .MustParse ("250m" ),
368+ corev1 .ResourceMemory : resource .MustParse ("1G" ),
369+ corev1 . ResourceName ( gpuResourceName ): resource .MustParse (fmt .Sprint (numberOfGpus )),
358370 },
359371 Limits : corev1.ResourceList {
360- corev1 .ResourceCPU : resource .MustParse ("2" ),
361- corev1 .ResourceMemory : resource .MustParse ("4G" ),
362- "nvidia.com/gpu" : resource .MustParse (fmt .Sprint (numberOfGpus )),
372+ corev1 .ResourceCPU : resource .MustParse ("2" ),
373+ corev1 .ResourceMemory : resource .MustParse ("4G" ),
374+ corev1 . ResourceName ( gpuResourceName ): resource .MustParse (fmt .Sprint (numberOfGpus )),
363375 },
364376 },
365377 VolumeMounts : []corev1.VolumeMount {
@@ -390,7 +402,22 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
390402 }
391403}
392404
393- func constructRayJob (_ Test , namespace * corev1.Namespace , rayCluster * rayv1.RayCluster , accelerator string , numberOfGpus int ) * rayv1.RayJob {
405+ func constructRayJob (_ Test , namespace * corev1.Namespace , rayCluster * rayv1.RayCluster , accelerator string , numberOfGpus int , gpuResourceName string , rayImage string ) * rayv1.RayJob {
406+ pipPackages := []string {
407+ "pytorch_lightning==2.4.0" ,
408+ "torchmetrics==1.6.0" ,
409+ "torchvision==0.19.1" ,
410+ }
411+
412+ // Append AMD-specific packages
413+ if gpuResourceName == "amd.com/gpu" {
414+ pipPackages = append (pipPackages ,
415+ "--extra-index-url https://download.pytorch.org/whl/rocm6.1" ,
416+ "torch==2.4.1+rocm6.1" ,
417+ )
418+ }
419+
420+ // Construct RayJob with the final pip list
394421 return & rayv1.RayJob {
395422 TypeMeta : metav1.TypeMeta {
396423 APIVersion : rayv1 .GroupVersion .String (),
@@ -402,17 +429,15 @@ func constructRayJob(_ Test, namespace *corev1.Namespace, rayCluster *rayv1.RayC
402429 },
403430 Spec : rayv1.RayJobSpec {
404431 Entrypoint : "python /home/ray/jobs/mnist.py" ,
405- RuntimeEnvYAML : `
406- pip:
407- - pytorch_lightning==2.4.0
408- - torchmetrics==1.6.0
409- - torchvision==0.20.1
410- env_vars:
411- MNIST_DATASET_URL: "` + GetMnistDatasetURL () + `"
412- PIP_INDEX_URL: "` + GetPipIndexURL () + `"
413- PIP_TRUSTED_HOST: "` + GetPipTrustedHost () + `"
414- ACCELERATOR: "` + accelerator + `"
415- ` ,
432+ RuntimeEnvYAML : fmt .Sprintf (`
433+ pip:
434+ - %s
435+ env_vars:
436+ MNIST_DATASET_URL: "%s"
437+ PIP_INDEX_URL: "%s"
438+ PIP_TRUSTED_HOST: "%s"
439+ ACCELERATOR: "%s"
440+ ` , strings .Join (pipPackages , "\n - " ), GetMnistDatasetURL (), GetPipIndexURL (), GetPipTrustedHost (), accelerator ),
416441 ClusterSelector : map [string ]string {
417442 RayJobDefaultClusterSelectorKey : rayCluster .Name ,
418443 },
@@ -422,7 +447,7 @@ func constructRayJob(_ Test, namespace *corev1.Namespace, rayCluster *rayv1.RayC
422447 RestartPolicy : corev1 .RestartPolicyNever ,
423448 Containers : []corev1.Container {
424449 {
425- Image : GetRayImage () ,
450+ Image : rayImage ,
426451 Name : "rayjob-submitter-pod" ,
427452 },
428453 },
@@ -477,12 +502,12 @@ func getRayDashboardURL(test Test, namespace, rayClusterName string) string {
477502}
478503
479504// Create ClusterQueue
480- func createClusterQueue (test Test , resourceFlavor * v1beta1.ResourceFlavor , numberOfGpus int ) * v1beta1.ClusterQueue {
505+ func createClusterQueue (test Test , resourceFlavor * v1beta1.ResourceFlavor , numberOfGpus int , gpuResourceName string ) * v1beta1.ClusterQueue {
481506 cqSpec := v1beta1.ClusterQueueSpec {
482507 NamespaceSelector : & metav1.LabelSelector {},
483508 ResourceGroups : []v1beta1.ResourceGroup {
484509 {
485- CoveredResources : []corev1.ResourceName {corev1 .ResourceName ("cpu" ), corev1 .ResourceName ("memory" ), corev1 .ResourceName ("nvidia.com/gpu" )},
510+ CoveredResources : []corev1.ResourceName {corev1 .ResourceName ("cpu" ), corev1 .ResourceName ("memory" ), corev1 .ResourceName (gpuResourceName )},
486511 Flavors : []v1beta1.FlavorQuotas {
487512 {
488513 Name : v1beta1 .ResourceFlavorReference (resourceFlavor .Name ),
@@ -496,7 +521,7 @@ func createClusterQueue(test Test, resourceFlavor *v1beta1.ResourceFlavor, numbe
496521 NominalQuota : resource .MustParse ("12Gi" ),
497522 },
498523 {
499- Name : corev1 .ResourceName ("nvidia.com/gpu" ),
524+ Name : corev1 .ResourceName (gpuResourceName ),
500525 NominalQuota : resource .MustParse (fmt .Sprint (numberOfGpus )),
501526 },
502527 },
0 commit comments