@@ -32,6 +32,7 @@ import (
3232 . "github.com/opendatahub-io/distributed-workloads/tests/common"
3333 . "github.com/opendatahub-io/distributed-workloads/tests/common/support"
3434 kfto "github.com/opendatahub-io/distributed-workloads/tests/kfto"
35+ trainerutils "github.com/opendatahub-io/distributed-workloads/tests/trainer/utils"
3536)
3637
3738func TestJobSetWorkflow (t * testing.T ) {
@@ -44,13 +45,15 @@ func TestJobSetWorkflow(t *testing.T) {
4445 // Create PVC for shared storage
4546 pvc := CreatePersistentVolumeClaim (test , namespace , "1Gi" , AccessModes (corev1 .ReadWriteOnce ))
4647
48+ // Get node image from ClusterTrainingRuntime
49+ nodeImage , err := trainerutils .GetImageFromClusterTrainingRuntime (test , trainerutils .DefaultClusterTrainingRuntimeCPU )
50+ test .Expect (err ).NotTo (HaveOccurred ())
51+
4752 // Create TrainingRuntime with initializer jobs
48- trainingRuntime := createTrainingRuntimeWithInitializers (test , namespace , pvc .Name )
49- defer deleteTrainingRuntime (test , namespace , trainingRuntime .Name )
53+ trainingRuntime := createTrainingRuntimeWithInitializers (test , namespace , pvc .Name , nodeImage )
5054
5256 // Create TrainJob referencing the TrainingRuntime
5256 trainJob := createTrainJobWithInitializers (test , namespace , trainingRuntime .Name )
53- defer deleteTrainJob (test , namespace , trainJob .Name )
5457
5558 // Verify JobSet creation
5659 test .Eventually (SingleJobSet (test , namespace ), TestTimeoutMedium ).Should (
@@ -77,13 +80,15 @@ func TestFailedJobSetWorkflow(t *testing.T) {
7780 // Create PVC for shared storage
7881 pvc := CreatePersistentVolumeClaim (test , namespace , "1Gi" , AccessModes (corev1 .ReadWriteOnce ))
7982
83+ // Get node image from ClusterTrainingRuntime
84+ nodeImage , err := trainerutils .GetImageFromClusterTrainingRuntime (test , trainerutils .DefaultClusterTrainingRuntimeCPU )
85+ test .Expect (err ).NotTo (HaveOccurred ())
86+
8087 // Create TrainingRuntime With Initializers
81- trainingRuntime := createTrainingRuntimeWithInitializers (test , namespace , pvc .Name )
82- defer deleteTrainingRuntime (test , namespace , trainingRuntime .Name )
88+ trainingRuntime := createTrainingRuntimeWithInitializers (test , namespace , pvc .Name , nodeImage )
8389
8490 // Create TrainJob
8591 trainJob := createTrainJobWithFailingInitializer (test , namespace , trainingRuntime .Name )
86- defer deleteTrainJob (test , namespace , trainJob .Name )
8792
8893 // Wait for JobSet failure
8994 test .Eventually (SingleJobSet (test , namespace ), TestTimeoutMedium ).Should (
@@ -100,7 +105,7 @@ func TestFailedJobSetWorkflow(t *testing.T) {
100105 test .T ().Log ("TrainJob failed as expected" )
101106}
102107
103- func createTrainingRuntimeWithInitializers (test Test , namespace , pvcName string ) * trainerv1alpha1.TrainingRuntime {
108+ func createTrainingRuntimeWithInitializers (test Test , namespace , pvcName , nodeImage string ) * trainerv1alpha1.TrainingRuntime {
104109 test .T ().Helper ()
105110
106111 trainingRuntime := & trainerv1alpha1.TrainingRuntime {
@@ -288,7 +293,7 @@ func createTrainingRuntimeWithInitializers(test Test, namespace, pvcName string)
288293 Containers : []corev1.Container {
289294 {
290295 Name : "node" ,
291- Image : GetTrainingCudaPyTorch28Image () ,
296+ Image : nodeImage ,
292297 ImagePullPolicy : corev1 .PullIfNotPresent ,
293298 Resources : corev1.ResourceRequirements {
294299 Requests : corev1.ResourceList {
@@ -502,21 +507,6 @@ func createTrainJobWithFailingInitializer(test Test, namespace, runtimeName stri
502507 return createdTrainJob
503508}
504509
505- func deleteTrainingRuntime (test Test , namespace , name string ) {
506- test .T ().Helper ()
507-
508- err := test .Client ().Trainer ().TrainerV1alpha1 ().TrainingRuntimes (namespace ).Delete (
509- test .Ctx (),
510- name ,
511- metav1.DeleteOptions {},
512- )
513- if err != nil {
514- test .T ().Logf ("Warning: Failed to delete TrainingRuntime %s/%s: %v" , namespace , name , err )
515- } else {
516- test .T ().Logf ("Deleted TrainingRuntime %s/%s successfully" , namespace , name )
517- }
518- }
519-
520510func verifySequentialJobExecution (test Test , namespace string ) {
521511 test .T ().Helper ()
522512
0 commit comments