Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions slice/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (

"tpu-slice-controller/api/v1beta1"
"tpu-slice-controller/internal/controller"
"tpu-slice-controller/internal/core"
"tpu-slice-controller/internal/util/cert"
"tpu-slice-controller/internal/webhooks"

Expand Down Expand Up @@ -74,6 +75,7 @@ func main() {
var probeAddr string
var secureMetrics bool
var enableHTTP2 bool
var sliceHealthNodeAffinityMode string
var tlsOpts []func(*tls.Config)
flag.StringVar(&metricsAddr, "metrics-bind-address", "0", "The address the metrics endpoint binds to. "+
"Use :8443 for HTTPS or :8080 for HTTP, or leave as 0 to disable the metrics service.")
Expand All @@ -95,12 +97,25 @@ func main() {
flag.StringVar(&metricsCertKey, "metrics-cert-key", "tls.key", "The name of the metrics server key file.")
flag.BoolVar(&enableHTTP2, "enable-http2", false,
"If set, HTTP/2 will be enabled for the metrics and webhook servers")
flag.StringVar(&sliceHealthNodeAffinityMode, "default-slice-health-node-affinity", "HEALTHY",
"Default slice health node affinity. Possible values are HEALTHY or HEALTHY_AND_DEGRADED.")
opts := zap.Options{
Development: true,
}
opts.BindFlags(flag.CommandLine)
flag.Parse()

var sliceHealthValues []string
switch sliceHealthNodeAffinityMode {
case "HEALTHY":
sliceHealthValues = []string{core.TPUSliceHealthNodeSelectorHealthy}
case "HEALTHY_AND_DEGRADED":
sliceHealthValues = []string{core.TPUSliceHealthNodeSelectorHealthy, core.TPUSliceHealthNodeSelectorDegraded}
default:
setupLog.Error(errors.New("invalid flag value"), "Invalid value for default-slice-health-node-affinity", "value", sliceHealthNodeAffinityMode)
os.Exit(1)
}

ctrl.SetLogger(zap.New(zap.UseFlagOptions(&opts)))

// if the enable-http2 flag is false (the default), http/2 should be disabled
Expand Down Expand Up @@ -244,7 +259,7 @@ func main() {
os.Exit(1)
}

go setupControllers(mgr, certsReady, activationTimeout, retryDelayOnSliceFailure)
go setupControllers(mgr, certsReady, activationTimeout, retryDelayOnSliceFailure, sliceHealthValues)

setupProbeEndpoints(mgr, certsReady)

Expand All @@ -255,17 +270,17 @@ func main() {
}
}

func setupControllers(mgr ctrl.Manager, certsReady chan struct{}, activationTimeout time.Duration, retryDelay time.Duration) {
func setupControllers(mgr ctrl.Manager, certsReady chan struct{}, activationTimeout time.Duration, retryDelay time.Duration, sliceHealthValues []string) {
// The controllers won't work until the webhooks are operating, and the webhook won't work until the
// certs are all in place.
cert.WaitForCertsReady(setupLog, certsReady)

// Register the webhooks
if err := webhooks.SetupWebhookWithManager(mgr); err != nil {
if err := webhooks.SetupWebhookWithManager(mgr, sliceHealthValues); err != nil {
setupLog.Error(err, "Unable to create webhook", "webhook", "JobSet")
os.Exit(1)
}
if err := webhooks.SetupJobWebhookWithManager(mgr); err != nil {
if err := webhooks.SetupJobWebhookWithManager(mgr, sliceHealthValues); err != nil {
setupLog.Error(err, "Unable to create webhook", "webhook", "Job")
os.Exit(1)
}
Expand Down
4 changes: 2 additions & 2 deletions slice/internal/webhooks/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func getTPUsRequestedPerPod(spec corev1.PodSpec) int64 {
return totalTPUs
}

func annotatePodTemplateSpecWithSliceHealth(template *corev1.PodTemplateSpec) {
func annotatePodTemplateSpecWithSliceHealth(template *corev1.PodTemplateSpec, defaultSliceHealthValues []string) {
// 1. If there is NodeSelector with TPUSliceHealthNodeSelectorKey, we do nothing.
if _, ok := template.Spec.NodeSelector[core.TPUSliceHealthNodeSelectorKey]; ok {
return
Expand All @@ -58,7 +58,7 @@ func annotatePodTemplateSpecWithSliceHealth(template *corev1.PodTemplateSpec) {
}

// 3. If neither of these, we add a NodeAffinity.
core.AddNodeAffinity(template, core.TPUSliceHealthNodeSelectorKey, []string{core.TPUSliceHealthNodeSelectorHealthy})
core.AddNodeAffinity(template, core.TPUSliceHealthNodeSelectorKey, defaultSliceHealthValues)
}

func annotatePodTemplateSpecWithTopology(template *corev1.PodTemplateSpec, parallelism *int32, resourceName string, resourceKind string) error {
Expand Down
12 changes: 8 additions & 4 deletions slice/internal/webhooks/job_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@ import (
"tpu-slice-controller/internal/core"
)

type JobWebhook struct{}
type JobWebhook struct {
DefaultSliceHealthValues []string
}

func SetupJobWebhookWithManager(mgr ctrl.Manager) error {
func SetupJobWebhookWithManager(mgr ctrl.Manager, defaultSliceHealthValues []string) error {
return ctrl.NewWebhookManagedBy(mgr).
For(&batchv1.Job{}).
WithDefaulter(&JobWebhook{}).
WithDefaulter(&JobWebhook{
DefaultSliceHealthValues: defaultSliceHealthValues,
}).
Complete()
}

Expand All @@ -56,7 +60,7 @@ func (r *JobWebhook) Default(ctx context.Context, obj runtime.Object) error {
return nil
}
log.V(5).Info("Annotating Job")
annotatePodTemplateSpecWithSliceHealth(&job.Spec.Template)
annotatePodTemplateSpecWithSliceHealth(&job.Spec.Template, r.DefaultSliceHealthValues)
err := annotatePodTemplateSpecWithTopology(&job.Spec.Template, job.Spec.Parallelism, job.Name, "job")
if err != nil {
return err
Expand Down
4 changes: 3 additions & 1 deletion slice/internal/webhooks/job_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ func TestJobDefault(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
ctx := t.Context()
webhook := &JobWebhook{}
webhook := &JobWebhook{
DefaultSliceHealthValues: []string{core.TPUSliceHealthNodeSelectorHealthy},
}

gotErr := webhook.Default(ctx, tc.job)
if diff := cmp.Diff(tc.wantErr, gotErr, utiltesting.EquateErrors); diff != "" {
Expand Down
12 changes: 8 additions & 4 deletions slice/internal/webhooks/jobset_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,16 @@ import (
)

// JobSetWebhook is the schema for your resource (ensure this matches your resource definition).
type JobSetWebhook struct{}
type JobSetWebhook struct {
DefaultSliceHealthValues []string
}

func SetupWebhookWithManager(mgr ctrl.Manager) error {
func SetupWebhookWithManager(mgr ctrl.Manager, defaultSliceHealthValues []string) error {
return ctrl.NewWebhookManagedBy(mgr).
For(&v1alpha2.JobSet{}).
WithDefaulter(&JobSetWebhook{}).
WithDefaulter(&JobSetWebhook{
DefaultSliceHealthValues: defaultSliceHealthValues,
}).
Complete()
}

Expand All @@ -59,7 +63,7 @@ func (r *JobSetWebhook) Default(ctx context.Context, obj runtime.Object) error {
continue
}
log.V(5).Info("Annotating ReplicatedJob")
annotatePodTemplateSpecWithSliceHealth(&rj.Template.Spec.Template)
annotatePodTemplateSpecWithSliceHealth(&rj.Template.Spec.Template, r.DefaultSliceHealthValues)
err := annotatePodTemplateSpecWithTopology(&rj.Template.Spec.Template, rj.Template.Spec.Parallelism, rj.Name, "replicated job")
if err != nil {
return err
Expand Down
46 changes: 42 additions & 4 deletions slice/internal/webhooks/jobset_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ func TestDefault(t *testing.T) {
)

testCases := map[string]struct {
jobSet *jobset.JobSet
wantJobSet *jobset.JobSet
wantErr error
defaultSliceHealthValues []string
jobSet *jobset.JobSet
wantJobSet *jobset.JobSet
wantErr error
}{
"no queue label": {
jobSet: testingjobjobset.MakeJobSet(baseJobSetName, utils.DefaultNamespace).
Expand Down Expand Up @@ -111,6 +112,7 @@ func TestDefault(t *testing.T) {
Obj(),
},
"should set default values": {
defaultSliceHealthValues: []string{core.TPUSliceHealthNodeSelectorHealthy},
jobSet: testingjobjobset.MakeJobSet(baseJobSetName, utils.DefaultNamespace).
Queue("queue-name").
ReplicatedJobs(testingjobjobset.ReplicatedJobRequirements{
Expand Down Expand Up @@ -143,6 +145,40 @@ func TestDefault(t *testing.T) {
RequestAndLimit("rj1", core.TPUResourceName, "4").
Obj(),
},
"should set default values including DEGRADED cube health": {
defaultSliceHealthValues: []string{core.TPUSliceHealthNodeSelectorHealthy, core.TPUSliceHealthNodeSelectorDegraded},
jobSet: testingjobjobset.MakeJobSet(baseJobSetName, utils.DefaultNamespace).
Queue("queue-name").
ReplicatedJobs(testingjobjobset.ReplicatedJobRequirements{
Name: "rj1",
Parallelism: 48,
PodAnnotations: map[string]string{
core.TPUSliceTopologyAnnotation: "4x4x12",
},
NodeSelector: map[string]string{
"cloud.google.com/gke-tpu-accelerator": string(slice.TypeTpu7x),
},
}).
RequestAndLimit("rj1", core.TPUResourceName, "4").
Obj(),
wantJobSet: testingjobjobset.MakeJobSet(baseJobSetName, utils.DefaultNamespace).
Queue("queue-name").
ReplicatedJobs(testingjobjobset.ReplicatedJobRequirements{
Name: "rj1",
Parallelism: 48,
PodAnnotations: map[string]string{
core.TPUSliceTopologyAnnotation: "4x4x12",
"kueue.x-k8s.io/podset-required-topology": "cloud.google.com/gce-topology-block",
"kueue.x-k8s.io/podset-slice-required-topology": core.TPUSubBlockLabel,
"kueue.x-k8s.io/podset-slice-size": "16",
},
NodeSelector: map[string]string{
"cloud.google.com/gke-tpu-accelerator": string(slice.TypeTpu7x),
},
}).NodeAffinity("rj1", core.TPUSliceHealthNodeSelectorKey, []string{core.TPUSliceHealthNodeSelectorHealthy, core.TPUSliceHealthNodeSelectorDegraded}).
RequestAndLimit("rj1", core.TPUResourceName, "4").
Obj(),
},
"shouldn't set default values because invalid topology annotation": {
jobSet: testingjobjobset.MakeJobSet(baseJobSetName, utils.DefaultNamespace).
Queue("queue-name").
Expand Down Expand Up @@ -308,7 +344,9 @@ func TestDefault(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
ctx := t.Context()
webhook := &JobSetWebhook{}
webhook := &JobSetWebhook{
DefaultSliceHealthValues: tc.defaultSliceHealthValues,
}

gotErr := webhook.Default(ctx, tc.jobSet)
if diff := cmp.Diff(tc.wantErr, gotErr, utiltesting.EquateErrors); diff != "" {
Expand Down
Loading