Merged
18 changes: 18 additions & 0 deletions manifests/rhoai/rbac_progression_patch.yaml
@@ -11,3 +11,21 @@
    verbs:
      - get
      - list
# RHAI-specific: Permissions for NetworkPolicy management
# Required to create/update NetworkPolicies that restrict metrics endpoint access
# to controller pods only (security hardening for progression tracking)
# Note: list/watch needed for controller-runtime cache; delete not needed (OwnerReference cleanup)
- op: add
  path: /rules/-
  value:
    apiGroups:
      - networking.k8s.io
    resources:
      - networkpolicies
    verbs:
      - get
      - list
      - watch
      - create
      - update
      - patch
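
For reference, once this JSON 6902 patch is applied, the appended ClusterRole rule would render roughly as follows (a sketch of the merged output; the surrounding rules come from the base manifest and are unchanged):

```yaml
# Sketch of the rule appended by the patch above.
- apiGroups:
    - networking.k8s.io
  resources:
    - networkpolicies
  verbs:
    - get
    - list
    - watch
    - create
    - update
    - patch
```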
5 changes: 5 additions & 0 deletions pkg/controller/trainjob_controller.go
@@ -43,6 +43,7 @@ import (

	trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
	"github.com/kubeflow/trainer/v2/pkg/constants"
	"github.com/kubeflow/trainer/v2/pkg/rhai"
	"github.com/kubeflow/trainer/v2/pkg/rhai/progression"
	jobruntimes "github.com/kubeflow/trainer/v2/pkg/runtime"
)
@@ -158,6 +159,10 @@ func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobru
			return err
		}
	}
	// Reconcile NetworkPolicy for pod isolation
	if err := rhai.ReconcileNetworkPolicy(ctx, r.client, trainJob); err != nil {
		return err
	}
Comment on lines +162 to +165
@coderabbitai coderabbitai bot Dec 11, 2025

⚠️ Potential issue | 🟠 Major


Add conditional check for progression-tracking annotation before NetworkPolicy reconciliation

The ReconcileNetworkPolicy call is unconditional, but the PR description states the NetworkPolicy is added "when progression-tracking is enabled." Since IsProgressionTrackingEnabled() already exists, wrap the call:

if progression.IsProgressionTrackingEnabled(trainJob) {
    if err := progression.ReconcileNetworkPolicy(ctx, r.client, trainJob); err != nil {
        return err
    }
}

Otherwise, NetworkPolicy is created for all TrainJobs regardless of the progression-tracking setting.



nit: I think this is a good suggestion. Can we ensure that, for any existing TrainJob with progression tracking disabled, the pod behaviour does not change?
cc. @abhijeet-dhumal


Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Member Author

@kapil27 Actually, this was an intentional design decision. The NetworkPolicy serves two purposes, as described in the PR description:

- Pod isolation: only same-job pods can communicate; cross-job traffic is blocked (primary purpose, not dependent on progression tracking).
- Progression metrics port restriction: controller-only access.

These benefit all TrainJobs, regardless of progression tracking. Without the policy, any pod in the namespace could access training pods.

Member Author

Ah, come to think of it, you are right 🤔
We should make the progression-based rule conditional on IsProgressionTrackingEnabled(), since the metrics server only runs when progression tracking is enabled 👀
Thanks @kapil27, on it!

	return nil
}

31 changes: 28 additions & 3 deletions pkg/rhai/constants/constants.go
@@ -21,8 +21,8 @@ const (
// Progression tracking feature annotations

// AnnotationProgressionTracking enables/disables progression tracking for a TrainJob.
// Value: "enabled" to enable tracking, any other value or absence disables it.
// Example: trainer.opendatahub.io/progression-tracking: "enabled"
// Value: "true" to enable tracking, any other value or absence disables it.
// Example: trainer.opendatahub.io/progression-tracking: "true"
AnnotationProgressionTracking string = "trainer.opendatahub.io/progression-tracking"

// AnnotationTrainerStatus stores the JSON-encoded training status/progress.
@@ -31,7 +31,9 @@ const (
AnnotationTrainerStatus string = "trainer.opendatahub.io/trainerStatus"

// AnnotationMetricsPort specifies the port where the training pod exposes metrics.
// Default: 28080
// Default: 28080. Valid range: 1024-65535 (non-privileged ports).
// Ports 0-1023 require root privileges and are incompatible with OpenShift
// restricted SCCs and Kubernetes non-root security policies.
// Example: trainer.opendatahub.io/metrics-port: "8080"
AnnotationMetricsPort string = "trainer.opendatahub.io/metrics-port"

@@ -60,4 +62,27 @@
// TerminationGraceBufferSecs is added to preStop duration for pod termination grace period.
// This allows time for graceful process shutdown after preStop hook completes.
TerminationGraceBufferSecs int = 30

// NetworkPolicy constants for metrics endpoint security

// DefaultControllerNamespace is the fallback when SA namespace file is unavailable.
DefaultControllerNamespace string = "opendatahub"

// ControllerPodLabelName is the label key used to identify the controller pod.
// NetworkPolicy uses this to allow controller access to training pod metrics.
ControllerPodLabelName string = "app.kubernetes.io/name"

// ControllerPodLabelNameValue is the expected value for the controller name label.
// Must match the label applied to controller pods in deployment manifests.
// RHOAI: Set via kustomization.yaml labels overlay.
ControllerPodLabelNameValue string = "trainer"

// ControllerPodLabelComponent is the label key for component identification.
ControllerPodLabelComponent string = "app.kubernetes.io/component"

// ControllerPodLabelComponentValue is the expected value for the controller component label.
// RHOAI uses "controller" (set via kustomization.yaml).
// Upstream Kubeflow uses "manager" (set in base/manager/manager.yaml).
// This value must match your deployment's controller pod labels.
ControllerPodLabelComponentValue string = "controller"
)
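
As a hypothetical illustration of the RHOAI overlay these comments reference, a kustomization.yaml labels stanza producing matching controller pod labels might look like this (a sketch; the actual overlay file and its contents are assumptions, not quoted from the repo):

```yaml
# Hypothetical kustomization.yaml fragment; the real RHOAI overlay may differ.
labels:
  - includeSelectors: true
    pairs:
      app.kubernetes.io/name: trainer
      app.kubernetes.io/component: controller
```

These pairs must line up with ControllerPodLabelNameValue and ControllerPodLabelComponentValue above, or the NetworkPolicy's controller-ingress rule will not select the controller pods.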
180 changes: 180 additions & 0 deletions pkg/rhai/networkpolicy.go
Collaborator

@robert-bell robert-bell Dec 11, 2025

nit: does it make sense to move this netpol code into the rhai package, rather than progression package?

I'm happy for this to be merged as is though :)

Member Author

Yeah, definitely. Since this netpol is not scoped only to progression, it would be good to refactor accordingly.
Thanks Rob, on it!!

@@ -0,0 +1,180 @@
/*
Copyright 2024 The Kubeflow Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package rhai

import (
	"context"
	"fmt"
	"os"
	"strconv"
	"strings"

	corev1 "k8s.io/api/core/v1"
	networkingv1 "k8s.io/api/networking/v1"
	apierrors "k8s.io/apimachinery/pkg/api/errors"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/util/intstr"
	"k8s.io/klog/v2"
	"sigs.k8s.io/controller-runtime/pkg/client"

	trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
	"github.com/kubeflow/trainer/v2/pkg/rhai/constants"
	"github.com/kubeflow/trainer/v2/pkg/rhai/progression"
)

const serviceAccountNamespaceFile = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"

// getControllerNamespace returns the controller's namespace from the SA mount.
func getControllerNamespace() string {
	if data, err := os.ReadFile(serviceAccountNamespaceFile); err == nil {
		if ns := strings.TrimSpace(string(data)); ns != "" {
			return ns
		}
	}
	return constants.DefaultControllerNamespace
}

func getNetworkPolicyName(trainJob *trainer.TrainJob) string {
	return trainJob.Name
}

// buildNetworkPolicy creates a NetworkPolicy for the TrainJob's pods.
// Rule 1 (same-job pods → all ports) is always added for pod isolation.
// Rule 2 (controller → metrics port) is only added when progression tracking is enabled.
func buildNetworkPolicy(trainJob *trainer.TrainJob) *networkingv1.NetworkPolicy {
	ingressRules := []networkingv1.NetworkPolicyIngressRule{}

	// Rule 1: Same-job pods → all ports (always, for NCCL/MPI/gRPC)
	ingressRules = append(ingressRules, networkingv1.NetworkPolicyIngressRule{
		From: []networkingv1.NetworkPolicyPeer{
			{
				PodSelector: &metav1.LabelSelector{
					MatchLabels: map[string]string{
						"jobset.sigs.k8s.io/jobset-name": trainJob.Name,
					},
				},
			},
		},
	})

	// Rule 2: Controller → metrics port (only when progression tracking enabled)
	if progression.IsProgressionTrackingEnabled(trainJob) {
		metricsPort := progression.GetMetricsPort(trainJob)
		portNum, err := strconv.Atoi(metricsPort)
		if err != nil {
			klog.Warningf("Invalid metrics port %q for TrainJob %s/%s, falling back to default %s",
				metricsPort, trainJob.Namespace, trainJob.Name, constants.DefaultMetricsPort)
			portNum, _ = strconv.Atoi(constants.DefaultMetricsPort)
		}
		port := intstr.FromInt(portNum)
		controllerNamespace := getControllerNamespace()

		ingressRules = append(ingressRules, networkingv1.NetworkPolicyIngressRule{
			From: []networkingv1.NetworkPolicyPeer{
				{
					NamespaceSelector: &metav1.LabelSelector{
						MatchLabels: map[string]string{
							"kubernetes.io/metadata.name": controllerNamespace,
						},
					},
					PodSelector: &metav1.LabelSelector{
						MatchLabels: map[string]string{
							constants.ControllerPodLabelName:      constants.ControllerPodLabelNameValue,
							constants.ControllerPodLabelComponent: constants.ControllerPodLabelComponentValue,
						},
					},
				},
			},
			Ports: []networkingv1.NetworkPolicyPort{
				{
					Protocol: protocolPtr(corev1.ProtocolTCP),
					Port:     &port,
				},
			},
		})
	}

	return &networkingv1.NetworkPolicy{
		ObjectMeta: metav1.ObjectMeta{
			Name:      getNetworkPolicyName(trainJob),
			Namespace: trainJob.Namespace,
			Labels: map[string]string{
				"trainer.kubeflow.org/trainjob-name": trainJob.Name,
				"trainer.kubeflow.org/component":     "network-policy",
			},
			OwnerReferences: []metav1.OwnerReference{
				{
					APIVersion:         trainer.SchemeGroupVersion.String(),
					Kind:               "TrainJob",
					Name:               trainJob.Name,
					UID:                trainJob.UID,
					Controller:         boolPtr(true),
					BlockOwnerDeletion: boolPtr(true),
				},
			},
		},
		Spec: networkingv1.NetworkPolicySpec{
			PodSelector: metav1.LabelSelector{
				MatchLabels: map[string]string{
					"jobset.sigs.k8s.io/jobset-name": trainJob.Name,
				},
			},
			PolicyTypes: []networkingv1.PolicyType{
				networkingv1.PolicyTypeIngress,
			},
			Ingress: ingressRules,
		},
	}
}

func boolPtr(b bool) *bool {
	return &b
}

func protocolPtr(p corev1.Protocol) *corev1.Protocol {
	return &p
}

// ReconcileNetworkPolicy creates/updates the NetworkPolicy for the TrainJob.
// Uses an OwnerReference for automatic cleanup.
func ReconcileNetworkPolicy(ctx context.Context, c client.Client, trainJob *trainer.TrainJob) error {
	desiredPolicy := buildNetworkPolicy(trainJob)
	existingPolicy := &networkingv1.NetworkPolicy{}
	err := c.Get(ctx, client.ObjectKey{
		Namespace: trainJob.Namespace,
		Name:      getNetworkPolicyName(trainJob),
	}, existingPolicy)

	if apierrors.IsNotFound(err) {
		if createErr := c.Create(ctx, desiredPolicy); createErr != nil {
			return fmt.Errorf("failed to create NetworkPolicy: %w", createErr)
		}
		return nil
	}

	if err != nil {
		return fmt.Errorf("failed to get NetworkPolicy: %w", err)
	}

	existingPolicy.Spec = desiredPolicy.Spec
	existingPolicy.Labels = desiredPolicy.Labels
	if updateErr := c.Update(ctx, existingPolicy); updateErr != nil {
		return fmt.Errorf("failed to update NetworkPolicy: %w", updateErr)
	}

	return nil
}
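
To make the builder's output concrete: for a hypothetical TrainJob named `my-job` in namespace `team-a` with progression tracking enabled and the default metrics port (28080), buildNetworkPolicy would produce roughly the object below. This is a sketch: the job/namespace names are invented, the `apiVersion` is assumed from the v1alpha1 API import, the `uid` is omitted, and the controller namespace shown is the `opendatahub` fallback.

```yaml
apiVersion: trainer.kubeflow.org/v1alpha1   # assumed from SchemeGroupVersion
kind: NetworkPolicy
metadata:
  name: my-job
  namespace: team-a
  labels:
    trainer.kubeflow.org/trainjob-name: my-job
    trainer.kubeflow.org/component: network-policy
  ownerReferences:
    - apiVersion: trainer.kubeflow.org/v1alpha1
      kind: TrainJob
      name: my-job
      controller: true
      blockOwnerDeletion: true
spec:
  podSelector:
    matchLabels:
      jobset.sigs.k8s.io/jobset-name: my-job
  policyTypes:
    - Ingress
  ingress:
    # Rule 1: same-job pods may reach all ports (NCCL/MPI/gRPC)
    - from:
        - podSelector:
            matchLabels:
              jobset.sigs.k8s.io/jobset-name: my-job
    # Rule 2: controller pods may reach the metrics port (progression tracking only)
    - from:
        - namespaceSelector:
            matchLabels:
              kubernetes.io/metadata.name: opendatahub
          podSelector:
            matchLabels:
              app.kubernetes.io/name: trainer
              app.kubernetes.io/component: controller
      ports:
        - protocol: TCP
          port: 28080
```

With progression tracking disabled, only Rule 1 would be emitted, so the policy's kind line should of course read `kind: NetworkPolicy` with `apiVersion: networking.k8s.io/v1` when applied directly; the trainer apiVersion above appears only in the ownerReference.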