@@ -21,6 +21,7 @@ import (
2121 "fmt"
2222 "net/http"
2323 "net/url"
24+ "strings"
2425 "testing"
2526
2627 . "github.com/onsi/gomega"
@@ -40,14 +41,18 @@ import (
4041// directly managed by Kueue, and asserts successful completion of the training job.
4142
4243func TestMnistRayJobRayClusterCpu (t * testing.T ) {
43- runMnistRayJobRayCluster (t , "cpu" , 0 )
44+ runMnistRayJobRayCluster (t , CPU , 0 , GetRayImage () )
4445}
4546
46- func TestMnistRayJobRayClusterGpu (t * testing.T ) {
47- runMnistRayJobRayCluster (t , "gpu" , 1 )
47+ func TestMnistRayJobRayClusterCudaGpu (t * testing.T ) {
48+ runMnistRayJobRayCluster (t , NVIDIA , 1 , GetRayImage () )
4849}
4950
50- func runMnistRayJobRayCluster (t * testing.T , accelerator string , numberOfGpus int ) {
51+ func TestMnistRayJobRayClusterROCmGpu (t * testing.T ) {
52+ runMnistRayJobRayCluster (t , AMD , 1 , GetRayROCmImage ())
53+ }
54+
55+ func runMnistRayJobRayCluster (t * testing.T , accelerator Accelerator , numberOfGpus int , rayImage string ) {
5156 test := With (t )
5257
5358 // Create a static namespace to ensure a consistent Ray Dashboard hostname entry in /etc/hosts before executing the test.
@@ -58,11 +63,12 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
5863 defer func () {
5964 _ = test .Client ().Kueue ().KueueV1beta1 ().ResourceFlavors ().Delete (test .Ctx (), resourceFlavor .Name , metav1.DeleteOptions {})
6065 }()
61- clusterQueue := createClusterQueue (test , resourceFlavor , numberOfGpus )
66+
67+ clusterQueue := createClusterQueue (test , resourceFlavor , numberOfGpus , accelerator )
6268 defer func () {
6369 _ = test .Client ().Kueue ().KueueV1beta1 ().ClusterQueues ().Delete (test .Ctx (), clusterQueue .Name , metav1.DeleteOptions {})
6470 }()
65- CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name , AsDefaultQueue )
71+ localQueue := CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name , AsDefaultQueue )
6672
6773 // Create MNIST training script
6874 mnist := constructMNISTConfigMap (test , namespace )
@@ -71,7 +77,7 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
7177 test .T ().Logf ("Created ConfigMap %s/%s successfully" , mnist .Namespace , mnist .Name )
7278
7379 // Create RayCluster and assign it to the localqueue
74- rayCluster := constructRayCluster (test , namespace , mnist , numberOfGpus )
80+ rayCluster := constructRayCluster (test , namespace , localQueue . Name , mnist , numberOfGpus , accelerator , rayImage )
7581 rayCluster , err = test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Create (test .Ctx (), rayCluster , metav1.CreateOptions {})
7682 test .Expect (err ).NotTo (HaveOccurred ())
7783 test .T ().Logf ("Created RayCluster %s/%s successfully" , rayCluster .Namespace , rayCluster .Name )
@@ -81,7 +87,7 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
8187 Should (WithTransform (RayClusterState , Equal (rayv1 .Ready )))
8288
8389 // Create RayJob
84- rayJob := constructRayJob (test , namespace , rayCluster , accelerator , numberOfGpus )
90+ rayJob := constructRayJob (test , namespace , rayCluster , accelerator , numberOfGpus , rayImage )
8591 rayJob , err = test .Client ().Ray ().RayV1 ().RayJobs (namespace .Name ).Create (test .Ctx (), rayJob , metav1.CreateOptions {})
8692 test .Expect (err ).NotTo (HaveOccurred ())
8793 test .T ().Logf ("Created RayJob %s/%s successfully" , rayJob .Namespace , rayJob .Name )
@@ -110,15 +116,19 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
110116}
111117
112118func TestMnistRayJobRayClusterAppWrapperCpu (t * testing.T ) {
113- runMnistRayJobRayClusterAppWrapper (t , "cpu" , 0 )
119+ runMnistRayJobRayClusterAppWrapper (t , CPU , 0 , GetRayImage ())
120+ }
121+
122+ func TestMnistRayJobRayClusterAppWrapperCudaGpu (t * testing.T ) {
123+ runMnistRayJobRayClusterAppWrapper (t , NVIDIA , 1 , GetRayImage ())
114124}
115125
116- func TestMnistRayJobRayClusterAppWrapperGpu (t * testing.T ) {
117- runMnistRayJobRayClusterAppWrapper (t , "gpu" , 1 )
126+ func TestMnistRayJobRayClusterAppWrapperROCmGpu (t * testing.T ) {
127+ runMnistRayJobRayClusterAppWrapper (t , AMD , 1 , GetRayROCmImage () )
118128}
119129
120130// Same as runMnistRayJobRayCluster, except the RayCluster is wrapped in an AppWrapper
121- func runMnistRayJobRayClusterAppWrapper (t * testing.T , accelerator string , numberOfGpus int ) {
131+ func runMnistRayJobRayClusterAppWrapper (t * testing.T , accelerator Accelerator , numberOfGpus int , rayImage string ) {
122132 test := With (t )
123133
124134 // Create a static namespace to ensure a consistent Ray Dashboard hostname entry in /etc/hosts before executing the test.
@@ -129,7 +139,7 @@ func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, number
129139 defer func () {
130140 _ = test .Client ().Kueue ().KueueV1beta1 ().ResourceFlavors ().Delete (test .Ctx (), resourceFlavor .Name , metav1.DeleteOptions {})
131141 }()
132- clusterQueue := createClusterQueue (test , resourceFlavor , numberOfGpus )
142+ clusterQueue := createClusterQueue (test , resourceFlavor , numberOfGpus , accelerator )
133143 defer func () {
134144 _ = test .Client ().Kueue ().KueueV1beta1 ().ClusterQueues ().Delete (test .Ctx (), clusterQueue .Name , metav1.DeleteOptions {})
135145 }()
@@ -142,7 +152,7 @@ func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, number
142152 test .T ().Logf ("Created ConfigMap %s/%s successfully" , mnist .Namespace , mnist .Name )
143153
144154 // Create RayCluster, wrap in AppWrapper and assign to localqueue
145- rayCluster := constructRayCluster (test , namespace , mnist , numberOfGpus )
155+ rayCluster := constructRayCluster (test , namespace , localQueue . Name , mnist , numberOfGpus , accelerator , rayImage )
146156 raw := Raw (test , rayCluster )
147157 raw = RemoveCreationTimestamp (test , raw )
148158
@@ -183,7 +193,7 @@ func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, number
183193 Should (WithTransform (RayClusterState , Equal (rayv1 .Ready )))
184194
185195 // Create RayJob
186- rayJob := constructRayJob (test , namespace , rayCluster , accelerator , numberOfGpus )
196+ rayJob := constructRayJob (test , namespace , rayCluster , accelerator , numberOfGpus , rayImage )
187197 rayJob , err = test .Client ().Ray ().RayV1 ().RayJobs (namespace .Name ).Create (test .Ctx (), rayJob , metav1.CreateOptions {})
188198 test .Expect (err ).NotTo (HaveOccurred ())
189199 test .T ().Logf ("Created RayJob %s/%s successfully" , rayJob .Namespace , rayJob .Name )
@@ -223,11 +233,11 @@ func TestRayClusterImagePullSecret(t *testing.T) {
223233 defer func () {
224234 _ = test .Client ().Kueue ().KueueV1beta1 ().ResourceFlavors ().Delete (test .Ctx (), resourceFlavor .Name , metav1.DeleteOptions {})
225235 }()
226- clusterQueue := createClusterQueue (test , resourceFlavor , 0 )
236+ clusterQueue := createClusterQueue (test , resourceFlavor , 0 , CPU )
227237 defer func () {
228238 _ = test .Client ().Kueue ().KueueV1beta1 ().ClusterQueues ().Delete (test .Ctx (), clusterQueue .Name , metav1.DeleteOptions {})
229239 }()
230- CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name , AsDefaultQueue )
240+ localQueue := CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name , AsDefaultQueue )
231241
232242 // Create MNIST training script
233243 mnist := constructMNISTConfigMap (test , namespace )
@@ -236,7 +246,7 @@ func TestRayClusterImagePullSecret(t *testing.T) {
236246 test .T ().Logf ("Created ConfigMap %s/%s successfully" , mnist .Namespace , mnist .Name )
237247
238248 // Create RayCluster with imagePullSecret and assign it to the localqueue
239- rayCluster := constructRayCluster (test , namespace , mnist , 0 )
249+ rayCluster := constructRayCluster (test , namespace , localQueue . Name , mnist , 0 , CPU , GetRayImage () )
240250 rayCluster .Spec .HeadGroupSpec .Template .Spec .ImagePullSecrets = []corev1.LocalObjectReference {{Name : "custom-pull-secret" }}
241251 rayCluster , err = test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Create (test .Ctx (), rayCluster , metav1.CreateOptions {})
242252 test .Expect (err ).NotTo (HaveOccurred ())
@@ -266,15 +276,18 @@ func constructMNISTConfigMap(test Test, namespace *corev1.Namespace) *corev1.Con
266276 }
267277}
268278
269- func constructRayCluster (_ Test , namespace * corev1.Namespace , mnist * corev1.ConfigMap , numberOfGpus int ) * rayv1.RayCluster {
270- return & rayv1.RayCluster {
279+ func constructRayCluster (_ Test , namespace * corev1.Namespace , localQueueName string , mnist * corev1.ConfigMap , numberOfGpus int , accelerator Accelerator , rayImage string ) * rayv1.RayCluster {
280+ raycluster := rayv1.RayCluster {
271281 TypeMeta : metav1.TypeMeta {
272282 APIVersion : rayv1 .GroupVersion .String (),
273283 Kind : "RayCluster" ,
274284 },
275285 ObjectMeta : metav1.ObjectMeta {
276286 Name : "raycluster" ,
277287 Namespace : namespace .Name ,
288+ Labels : map [string ]string {
289+ "kueue.x-k8s.io/queue-name" : localQueueName ,
290+ },
278291 },
279292 Spec : rayv1.RayClusterSpec {
280293 RayVersion : GetRayVersion (),
@@ -287,7 +300,7 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
287300 Containers : []corev1.Container {
288301 {
289302 Name : "ray-head" ,
290- Image : GetRayImage () ,
303+ Image : rayImage ,
291304 Ports : []corev1.ContainerPort {
292305 {
293306 ContainerPort : 6379 ,
@@ -342,7 +355,7 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
342355 Containers : []corev1.Container {
343356 {
344357 Name : "ray-worker" ,
345- Image : GetRayImage () ,
358+ Image : rayImage ,
346359 Lifecycle : & corev1.Lifecycle {
347360 PreStop : & corev1.LifecycleHandler {
348361 Exec : & corev1.ExecAction {
@@ -352,14 +365,14 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
352365 },
353366 Resources : corev1.ResourceRequirements {
354367 Requests : corev1.ResourceList {
355- corev1 .ResourceCPU : resource .MustParse ("250m" ),
356- corev1 .ResourceMemory : resource .MustParse ("1G" ),
357- "nvidia.com/gpu" : resource .MustParse (fmt .Sprint (numberOfGpus )),
368+ corev1 .ResourceCPU : resource .MustParse ("250m" ),
369+ corev1 .ResourceMemory : resource .MustParse ("1G" ),
370+ corev1 . ResourceName ( "nvidia.com/gpu" ): resource .MustParse (fmt .Sprint (numberOfGpus )),
358371 },
359372 Limits : corev1.ResourceList {
360- corev1 .ResourceCPU : resource .MustParse ("2" ),
361- corev1 .ResourceMemory : resource .MustParse ("4G" ),
362- "nvidia.com/gpu" : resource .MustParse (fmt .Sprint (numberOfGpus )),
373+ corev1 .ResourceCPU : resource .MustParse ("2" ),
374+ corev1 .ResourceMemory : resource .MustParse ("4G" ),
375+ corev1 . ResourceName ( "nvidia.com/gpu" ): resource .MustParse (fmt .Sprint (numberOfGpus )),
363376 },
364377 },
365378 VolumeMounts : []corev1.VolumeMount {
@@ -388,9 +401,37 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
388401 },
389402 },
390403 }
404+
405+ if accelerator .ResourceLabel == "amd.com/gpu" {
406+ // Remove the nvidia.com/gpu resource
407+ delete (raycluster .Spec .WorkerGroupSpecs [0 ].Template .Spec .Containers [0 ].Resources .Requests , corev1 .ResourceName ("nvidia.com/gpu" ))
408+ delete (raycluster .Spec .WorkerGroupSpecs [0 ].Template .Spec .Containers [0 ].Resources .Limits , corev1 .ResourceName ("nvidia.com/gpu" ))
409+
410+ // update with amd.com/gpu resource
411+ raycluster .Spec .WorkerGroupSpecs [0 ].Template .Spec .Tolerations [0 ].Key = "amd.com/gpu"
412+ raycluster .Spec .WorkerGroupSpecs [0 ].Template .Spec .Containers [0 ].Resources .Requests [corev1 .ResourceName ("amd.com/gpu" )] = resource .MustParse (fmt .Sprint (numberOfGpus ))
413+ raycluster .Spec .WorkerGroupSpecs [0 ].Template .Spec .Containers [0 ].Resources .Limits [corev1 .ResourceName ("amd.com/gpu" )] = resource .MustParse (fmt .Sprint (numberOfGpus ))
414+ }
415+
416+ return & raycluster
391417}
392418
393- func constructRayJob (_ Test , namespace * corev1.Namespace , rayCluster * rayv1.RayCluster , accelerator string , numberOfGpus int ) * rayv1.RayJob {
419+ func constructRayJob (_ Test , namespace * corev1.Namespace , rayCluster * rayv1.RayCluster , accelerator Accelerator , numberOfGpus int , rayImage string ) * rayv1.RayJob {
420+ pipPackages := []string {
421+ "pytorch_lightning==2.4.0" ,
422+ "torchmetrics==1.6.0" ,
423+ "torchvision==0.19.1" ,
424+ }
425+
426+ // Append AMD-specific packages
427+ if accelerator .ResourceLabel == "amd.com/gpu" {
428+ pipPackages = append (pipPackages ,
429+ "--extra-index-url https://download.pytorch.org/whl/rocm6.1" ,
430+ "torch==2.4.1+rocm6.1" ,
431+ )
432+ }
433+
434+ // Construct RayJob with the final pip list
394435 return & rayv1.RayJob {
395436 TypeMeta : metav1.TypeMeta {
396437 APIVersion : rayv1 .GroupVersion .String (),
@@ -402,17 +443,15 @@ func constructRayJob(_ Test, namespace *corev1.Namespace, rayCluster *rayv1.RayC
402443 },
403444 Spec : rayv1.RayJobSpec {
404445 Entrypoint : "python /home/ray/jobs/mnist.py" ,
405- RuntimeEnvYAML : `
406- pip:
407- - pytorch_lightning==2.4.0
408- - torchmetrics==1.6.0
409- - torchvision==0.20.1
410- env_vars:
411- MNIST_DATASET_URL: "` + GetMnistDatasetURL () + `"
412- PIP_INDEX_URL: "` + GetPipIndexURL () + `"
413- PIP_TRUSTED_HOST: "` + GetPipTrustedHost () + `"
414- ACCELERATOR: "` + accelerator + `"
415- ` ,
446+ RuntimeEnvYAML : fmt .Sprintf (`
447+ pip:
448+ - %s
449+ env_vars:
450+ MNIST_DATASET_URL: "%s"
451+ PIP_INDEX_URL: "%s"
452+ PIP_TRUSTED_HOST: "%s"
453+ ACCELERATOR: "%s"
454+ ` , strings .Join (pipPackages , "\n - " ), GetMnistDatasetURL (), GetPipIndexURL (), GetPipTrustedHost (), accelerator .Type ),
416455 ClusterSelector : map [string ]string {
417456 RayJobDefaultClusterSelectorKey : rayCluster .Name ,
418457 },
@@ -422,7 +461,7 @@ func constructRayJob(_ Test, namespace *corev1.Namespace, rayCluster *rayv1.RayC
422461 RestartPolicy : corev1 .RestartPolicyNever ,
423462 Containers : []corev1.Container {
424463 {
425- Image : GetRayImage () ,
464+ Image : rayImage ,
426465 Name : "rayjob-submitter-pod" ,
427466 },
428467 },
@@ -477,7 +516,7 @@ func getRayDashboardURL(test Test, namespace, rayClusterName string) string {
477516}
478517
479518// Create ClusterQueue
480- func createClusterQueue (test Test , resourceFlavor * v1beta1.ResourceFlavor , numberOfGpus int ) * v1beta1.ClusterQueue {
519+ func createClusterQueue (test Test , resourceFlavor * v1beta1.ResourceFlavor , numberOfGpus int , accelerator Accelerator ) * v1beta1.ClusterQueue {
481520 cqSpec := v1beta1.ClusterQueueSpec {
482521 NamespaceSelector : & metav1.LabelSelector {},
483522 ResourceGroups : []v1beta1.ResourceGroup {
@@ -505,5 +544,11 @@ func createClusterQueue(test Test, resourceFlavor *v1beta1.ResourceFlavor, numbe
505544 },
506545 },
507546 }
547+
548+ if accelerator .ResourceLabel == "amd.com/gpu" {
549+ cqSpec .ResourceGroups [0 ].CoveredResources [2 ] = corev1 .ResourceName (accelerator .ResourceLabel )
550+ cqSpec .ResourceGroups [0 ].Flavors [0 ].Resources [2 ].Name = corev1 .ResourceName (accelerator .ResourceLabel )
551+ }
552+
508553 return CreateKueueClusterQueue (test , cqSpec )
509554}
0 commit comments