Skip to content

Commit c0df748

Browse files
feat: Add network policy for TrainJobs
Signed-off-by: abhijeet-dhumal <abhijeetdhumal652@gmail.com>
1 parent e99c261 commit c0df748

File tree

5 files changed

+690
-3
lines changed

5 files changed

+690
-3
lines changed

manifests/rhoai/rbac_progression_patch.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,21 @@
1111
verbs:
1212
- get
1313
- list
14+
# RHAI-specific: Permissions for NetworkPolicy management
15+
# Required to create/update NetworkPolicies that restrict metrics endpoint access
16+
# to controller pods only (security hardening for progression tracking)
17+
# Note: list/watch needed for controller-runtime cache, delete not needed (OwnerReference cleanup)
18+
- op: add
19+
path: /rules/-
20+
value:
21+
apiGroups:
22+
- networking.k8s.io
23+
resources:
24+
- networkpolicies
25+
verbs:
26+
- get
27+
- list
28+
- watch
29+
- create
30+
- update
31+
- patch

pkg/controller/trainjob_controller.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import (
4343

4444
trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
4545
"github.com/kubeflow/trainer/v2/pkg/constants"
46+
"github.com/kubeflow/trainer/v2/pkg/rhai"
4647
"github.com/kubeflow/trainer/v2/pkg/rhai/progression"
4748
jobruntimes "github.com/kubeflow/trainer/v2/pkg/runtime"
4849
)
@@ -158,6 +159,10 @@ func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobru
158159
return err
159160
}
160161
}
162+
// Reconcile NetworkPolicy for pod isolation
163+
if err := rhai.ReconcileNetworkPolicy(ctx, r.client, trainJob); err != nil {
164+
return err
165+
}
161166
return nil
162167
}
163168

pkg/rhai/constants/constants.go

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ const (
2121
// Progression tracking feature annotations
2222

2323
// AnnotationProgressionTracking enables/disables progression tracking for a TrainJob.
24-
// Value: "enabled" to enable tracking, any other value or absence disables it.
25-
// Example: trainer.opendatahub.io/progression-tracking: "enabled"
24+
// Value: "true" to enable tracking, any other value or absence disables it.
25+
// Example: trainer.opendatahub.io/progression-tracking: "true"
2626
AnnotationProgressionTracking string = "trainer.opendatahub.io/progression-tracking"
2727

2828
// AnnotationTrainerStatus stores the JSON-encoded training status/progress.
@@ -31,7 +31,9 @@ const (
3131
AnnotationTrainerStatus string = "trainer.opendatahub.io/trainerStatus"
3232

3333
// AnnotationMetricsPort specifies the port where the training pod exposes metrics.
34-
// Default: 28080
34+
// Default: 28080. Valid range: 1024-65535 (non-privileged ports).
35+
// Ports 0-1023 require root privileges and are incompatible with OpenShift
36+
// restricted SCCs and Kubernetes non-root security policies.
3537
// Example: trainer.opendatahub.io/metrics-port: "8080"
3638
AnnotationMetricsPort string = "trainer.opendatahub.io/metrics-port"
3739

@@ -60,4 +62,27 @@ const (
6062
// TerminationGraceBufferSecs is added to preStop duration for pod termination grace period.
6163
// This allows time for graceful process shutdown after preStop hook completes.
6264
TerminationGraceBufferSecs int = 30
65+
66+
// NetworkPolicy constants for metrics endpoint security
67+
68+
// DefaultControllerNamespace is the fallback when SA namespace file is unavailable.
69+
DefaultControllerNamespace string = "opendatahub"
70+
71+
// ControllerPodLabelName is the label key used to identify the controller pod.
72+
// NetworkPolicy uses this to allow controller access to training pod metrics.
73+
ControllerPodLabelName string = "app.kubernetes.io/name"
74+
75+
// ControllerPodLabelNameValue is the expected value for the controller name label.
76+
// Must match the label applied to controller pods in deployment manifests.
77+
// RHOAI: Set via kustomization.yaml labels overlay.
78+
ControllerPodLabelNameValue string = "trainer"
79+
80+
// ControllerPodLabelComponent is the label key for component identification.
81+
ControllerPodLabelComponent string = "app.kubernetes.io/component"
82+
83+
// ControllerPodLabelComponentValue is the expected value for the controller component label.
84+
// RHOAI uses "controller" (set via kustomization.yaml).
85+
// Upstream Kubeflow uses "manager" (set in base/manager/manager.yaml).
86+
// This value must match your deployment's controller pod labels.
87+
ControllerPodLabelComponentValue string = "controller"
6388
)

pkg/rhai/networkpolicy.go

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/*
2+
Copyright 2024 The Kubeflow Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package rhai
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"os"
23+
"strconv"
24+
"strings"
25+
26+
corev1 "k8s.io/api/core/v1"
27+
networkingv1 "k8s.io/api/networking/v1"
28+
apierrors "k8s.io/apimachinery/pkg/api/errors"
29+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
30+
"k8s.io/apimachinery/pkg/util/intstr"
31+
"k8s.io/klog/v2"
32+
"sigs.k8s.io/controller-runtime/pkg/client"
33+
34+
trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
35+
"github.com/kubeflow/trainer/v2/pkg/rhai/constants"
36+
"github.com/kubeflow/trainer/v2/pkg/rhai/progression"
37+
)
38+
39+
const serviceAccountNamespaceFile = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
40+
41+
// getControllerNamespace returns the controller's namespace from SA mount.
42+
func getControllerNamespace() string {
43+
if data, err := os.ReadFile(serviceAccountNamespaceFile); err == nil {
44+
if ns := strings.TrimSpace(string(data)); ns != "" {
45+
return ns
46+
}
47+
}
48+
return constants.DefaultControllerNamespace
49+
}
50+
51+
func getNetworkPolicyName(trainJob *trainer.TrainJob) string {
52+
return trainJob.Name
53+
}
54+
55+
// buildNetworkPolicy creates a NetworkPolicy for the TrainJob's pods.
56+
// Rule 1 (same-job pods → all ports) is always added for pod isolation.
57+
// Rule 2 (controller → metrics port) is only added when progression tracking is enabled.
58+
func buildNetworkPolicy(trainJob *trainer.TrainJob) *networkingv1.NetworkPolicy {
59+
ingressRules := []networkingv1.NetworkPolicyIngressRule{}
60+
61+
// Rule 1: Same-job pods → all ports (always, for NCCL/MPI/gRPC)
62+
ingressRules = append(ingressRules, networkingv1.NetworkPolicyIngressRule{
63+
From: []networkingv1.NetworkPolicyPeer{
64+
{
65+
PodSelector: &metav1.LabelSelector{
66+
MatchLabels: map[string]string{
67+
"jobset.sigs.k8s.io/jobset-name": trainJob.Name,
68+
},
69+
},
70+
},
71+
},
72+
})
73+
74+
// Rule 2: Controller → metrics port (only when progression tracking enabled)
75+
if progression.IsProgressionTrackingEnabled(trainJob) {
76+
metricsPort := progression.GetMetricsPort(trainJob)
77+
portNum, err := strconv.Atoi(metricsPort)
78+
if err != nil {
79+
klog.Warningf("Invalid metrics port %q for TrainJob %s/%s, falling back to default %s",
80+
metricsPort, trainJob.Namespace, trainJob.Name, constants.DefaultMetricsPort)
81+
portNum, _ = strconv.Atoi(constants.DefaultMetricsPort)
82+
}
83+
port := intstr.FromInt(portNum)
84+
controllerNamespace := getControllerNamespace()
85+
86+
ingressRules = append(ingressRules, networkingv1.NetworkPolicyIngressRule{
87+
From: []networkingv1.NetworkPolicyPeer{
88+
{
89+
NamespaceSelector: &metav1.LabelSelector{
90+
MatchLabels: map[string]string{
91+
"kubernetes.io/metadata.name": controllerNamespace,
92+
},
93+
},
94+
PodSelector: &metav1.LabelSelector{
95+
MatchLabels: map[string]string{
96+
constants.ControllerPodLabelName: constants.ControllerPodLabelNameValue,
97+
constants.ControllerPodLabelComponent: constants.ControllerPodLabelComponentValue,
98+
},
99+
},
100+
},
101+
},
102+
Ports: []networkingv1.NetworkPolicyPort{
103+
{
104+
Protocol: protocolPtr(corev1.ProtocolTCP),
105+
Port: &port,
106+
},
107+
},
108+
})
109+
}
110+
111+
return &networkingv1.NetworkPolicy{
112+
ObjectMeta: metav1.ObjectMeta{
113+
Name: getNetworkPolicyName(trainJob),
114+
Namespace: trainJob.Namespace,
115+
Labels: map[string]string{
116+
"trainer.kubeflow.org/trainjob-name": trainJob.Name,
117+
"trainer.kubeflow.org/component": "network-policy",
118+
},
119+
OwnerReferences: []metav1.OwnerReference{
120+
{
121+
APIVersion: trainer.SchemeGroupVersion.String(),
122+
Kind: "TrainJob",
123+
Name: trainJob.Name,
124+
UID: trainJob.UID,
125+
Controller: boolPtr(true),
126+
BlockOwnerDeletion: boolPtr(true),
127+
},
128+
},
129+
},
130+
Spec: networkingv1.NetworkPolicySpec{
131+
PodSelector: metav1.LabelSelector{
132+
MatchLabels: map[string]string{
133+
"jobset.sigs.k8s.io/jobset-name": trainJob.Name,
134+
},
135+
},
136+
PolicyTypes: []networkingv1.PolicyType{
137+
networkingv1.PolicyTypeIngress,
138+
},
139+
Ingress: ingressRules,
140+
},
141+
}
142+
}
143+
144+
func boolPtr(b bool) *bool {
145+
return &b
146+
}
147+
148+
func protocolPtr(p corev1.Protocol) *corev1.Protocol {
149+
return &p
150+
}
151+
152+
// ReconcileNetworkPolicy creates/updates NetworkPolicy for the TrainJob.
153+
// Uses OwnerReference for automatic cleanup.
154+
func ReconcileNetworkPolicy(ctx context.Context, c client.Client, trainJob *trainer.TrainJob) error {
155+
desiredPolicy := buildNetworkPolicy(trainJob)
156+
existingPolicy := &networkingv1.NetworkPolicy{}
157+
err := c.Get(ctx, client.ObjectKey{
158+
Namespace: trainJob.Namespace,
159+
Name: getNetworkPolicyName(trainJob),
160+
}, existingPolicy)
161+
162+
if apierrors.IsNotFound(err) {
163+
if createErr := c.Create(ctx, desiredPolicy); createErr != nil {
164+
return fmt.Errorf("failed to create NetworkPolicy: %w", createErr)
165+
}
166+
return nil
167+
}
168+
169+
if err != nil {
170+
return fmt.Errorf("failed to get NetworkPolicy: %w", err)
171+
}
172+
173+
existingPolicy.Spec = desiredPolicy.Spec
174+
existingPolicy.Labels = desiredPolicy.Labels
175+
if updateErr := c.Update(ctx, existingPolicy); updateErr != nil {
176+
return fmt.Errorf("failed to update NetworkPolicy: %w", updateErr)
177+
}
178+
179+
return nil
180+
}

0 commit comments

Comments
 (0)