@@ -184,29 +184,14 @@ func (f *Flux) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) e
184184 apply .UpsertVolumes (& info .TemplateSpec .PodSets [psIdx ].Volumes , * curveVolume )
185185
186186 // Important! We have to add this to the JobSet to actually take effect.
187- // Does the initContainer exist?
188- found := false
189- for _ , ic := range ps .InitContainers {
190- for idx , existingIC := range jobSetSpec .ReplicatedJobs [psIdx ].Template .Spec .Template .Spec .InitContainers {
191- if existingIC .Name != nil && * existingIC .Name == ic .Name {
192- jobSetSpec .ReplicatedJobs [psIdx ].Template .Spec .Template .Spec .InitContainers [idx ] = * fluxInstaller
193- found = true
194- break
195- }
196- }
197- }
198-
199- if ! found {
200- jobSetSpec .ReplicatedJobs [psIdx ].Template .Spec .Template .Spec .InitContainers = append (
201- jobSetSpec .ReplicatedJobs [psIdx ].Template .Spec .Template .Spec .InitContainers ,
202- * fluxInstaller ,
203- )
204- }
187+ jobSetSpec .ReplicatedJobs [psIdx ].Template .Spec .Template .Spec .InitContainers = append (
188+ jobSetSpec .ReplicatedJobs [psIdx ].Template .Spec .Template .Spec .InitContainers ,
189+ * fluxInstaller ,
190+ )
205191
206192 // Update Containers in the PodSet
207193 for cIdx , container := range ps .Containers {
208194 if container .Name == constants .Node {
209- // jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Containers[cIdx].Command = []string{"/bin/bash", "/etc/flux-config/entrypoint.sh"}
210195 apply .UpsertVolumeMounts (
211196 & info .TemplateSpec .PodSets [psIdx ].Containers [cIdx ].VolumeMounts ,
212197 * corev1ac .VolumeMount ().WithName ("flux-install" ).WithMountPath ("/mnt/flux" ),
@@ -223,12 +208,8 @@ func (f *Flux) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) e
223208// Build creates the extra config map (configuration) and curve secret for Flux.
224209func (f * Flux ) Build (ctx context.Context , info * runtime.Info , trainJob * trainer.TrainJob ) ([]apiruntime.ApplyConfiguration , error ) {
225210
226- // policy defines the Flux HPC cluster setup
227- // Many configuration params cannot be represented in JobSet alone.
228- policy := info .RuntimePolicy .FluxPolicySource
229-
230211 // If the user's chosen runtime does not have the flux policy enabled, skip this plugin
231- if policy == nil {
212+ if info == nil || info . RuntimePolicy . FluxPolicySource == nil {
232213 return nil , nil
233214 }
234215
@@ -287,15 +268,12 @@ func (f *Flux) brokerSettingsFromEnvironment(trainJob *trainer.TrainJob, info *r
287268 // Look through the envars in the runtime spec.
288269 // We only care about the environment defined for the main workers/nodes
289270 if info != nil {
290- for _ , ps := range info .TemplateSpec .PodSets {
291- if ps .Name == constants .Node {
292- for _ , container := range ps .Containers {
293- for _ , envar := range container .Env {
294- if envar .Name != nil && envar .Value != nil {
295- if _ , ok := settings [* envar .Name ]; ok {
296- settings [* envar .Name ] = * envar .Value
297- }
298- }
271+ trainerContainer := info .FindContainerByPodSetAncestorContainerName (constants .AncestorTrainer , constants .Node )
272+ if trainerContainer != nil {
273+ for _ , envar := range trainerContainer .Env {
274+ if envar .Name != nil && envar .Value != nil {
275+ if _ , ok := settings [* envar .Name ]; ok {
276+ settings [* envar .Name ] = * envar .Value
299277 }
300278 }
301279 }
@@ -424,15 +402,9 @@ func getOriginalCommand(trainJob *trainer.TrainJob, info *runtime.Info) string {
424402 var args []string
425403
426404 // check PodSets first
427- for _ , ps := range info .TemplateSpec .PodSets {
428- if ps .Name == constants .Node {
429- for _ , container := range ps .Containers {
430- // Assume for now entire entrypoint logic is in command (with args)
431- if container .Name == constants .Node {
432- command = container .Command
433- }
434- }
435- }
405+ trainerContainer := info .FindContainerByPodSetAncestorContainerName (constants .AncestorTrainer , constants .Node )
406+ if trainerContainer != nil {
407+ command = trainerContainer .Command
436408 }
437409
438410 // Override if user defined them in the top-level Trainer spec
0 commit comments