@@ -84,10 +84,6 @@ func AddOrOverrideTFClientMissingAnnotationsBeforePatch(pod *v1.Pod, tfInfo Tens
 	if pod.Annotations == nil {
 		pod.Annotations = map[string]string{}
 	}
-	// add workload to pod annotations just for additional information
-	// so that users will know which GPU workload this pod binds to
-	pod.Annotations[constants.WorkloadKey] = tfInfo.WorkloadName
-
 	// When it's worker, set workload key to label for triggering workload reconcile
 	if tfInfo.Profile.IsLocalGPU {
 		if pod.Labels == nil {
@@ -116,7 +112,11 @@ func AddOrOverrideTFClientMissingAnnotationsBeforePatch(pod *v1.Pod, tfInfo Tens
 	pod.Annotations[constants.InjectContainerAnnotation] = strings.Join(tfInfo.ContainerNames, ",")
 }
 
-func AppendTFWorkerLabelsAndAnnotationsAfterTemplate(podTmpl *v1.PodTemplate, workload *tfv1.TensorFusionWorkload) (map[string]string, map[string]string) {
+func AppendTFWorkerLabelsAndAnnotationsAfterTemplate(
+	podTmpl *v1.PodTemplate,
+	workload *tfv1.TensorFusionWorkload,
+	containerName string,
+) (map[string]string, map[string]string) {
 	labels := maps.Clone(podTmpl.Template.Labels)
 	if labels == nil {
 		labels = map[string]string{}
@@ -132,6 +132,7 @@ func AppendTFWorkerLabelsAndAnnotationsAfterTemplate(podTmpl *v1.PodTemplate, wo
 	annotations[constants.VRAMLimitAnnotation] = res.Limits.Vram.String()
 	annotations[constants.TFLOPSRequestAnnotation] = res.Requests.Tflops.String()
 	annotations[constants.VRAMRequestAnnotation] = res.Requests.Vram.String()
+	annotations[constants.InjectContainerAnnotation] = containerName
 	if workload.Spec.Qos == "" {
 		annotations[constants.QoSLevelAnnotation] = string(tfv1.QoSMedium)
 	} else {
@@ -595,7 +596,7 @@ func AddTFNodeDiscoveryConfAfterTemplate(ctx context.Context, tmpl *v1.PodTempla
 	}
 }
 
-func AddWorkerConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, workerConfig *tfv1.WorkerConfig, hypervisorConfig *tfv1.HypervisorConfig, workload *tfv1.TensorFusionWorkload) {
+func AddWorkerConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, workerConfig *tfv1.WorkerConfig, hypervisorConfig *tfv1.HypervisorConfig, workload *tfv1.TensorFusionWorkload) string {
 	// NOTE: need to set environment variable to make all GPUs visible to the worker,
 	// vgpu.rs limiter will limit to specific devices after Pod started
 	spec.Containers[0].Name = constants.TFContainerNameWorker
@@ -689,4 +690,6 @@ func AddWorkerConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, workerCon
 	if len(spec.Containers[0].Resources.Requests) == 0 {
 		spec.Containers[0].Resources.Requests = workerDefaultRequests
 	}
+
+	return spec.Containers[0].Name
 }