Skip to content

Commit d5e9d87

Browse files
authored
feat: refactoring the gpu scheduler (#56)
1 parent 785c9b7 commit d5e9d87

18 files changed

+582
-424
lines changed

cmd/operator/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func main() {
154154

155155
ctx := context.Background()
156156

157-
scheduler := scheduler.NewNaiveScheduler()
157+
scheduler := scheduler.NewScheduler(mgr.GetClient())
158158
if err = (&controller.TensorFusionConnectionReconciler{
159159
Client: mgr.GetClient(),
160160
Scheme: mgr.GetScheme(),

go.mod

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ require (
1414
github.com/onsi/gomega v1.36.2
1515
github.com/prometheus/client_golang v1.20.5
1616
github.com/samber/lo v1.47.0
17+
github.com/shirou/gopsutil v3.21.11+incompatible
18+
github.com/stretchr/testify v1.10.0
1719
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67
1820
gomodules.xyz/jsonpatch/v2 v2.4.0
1921
k8s.io/api v0.32.1
@@ -53,6 +55,7 @@ require (
5355
github.com/go-logr/logr v1.4.2 // indirect
5456
github.com/go-logr/stdr v1.2.2 // indirect
5557
github.com/go-logr/zapr v1.3.0 // indirect
58+
github.com/go-ole/go-ole v1.2.6 // indirect
5659
github.com/go-openapi/jsonpointer v0.21.0 // indirect
5760
github.com/go-openapi/jsonreference v0.21.0 // indirect
5861
github.com/go-openapi/swag v0.23.0 // indirect
@@ -85,16 +88,17 @@ require (
8588
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect
8689
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
8790
github.com/pkg/errors v0.9.1 // indirect
91+
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
8892
github.com/prometheus/client_model v0.6.1 // indirect
8993
github.com/prometheus/common v0.61.0 // indirect
9094
github.com/prometheus/procfs v0.15.1 // indirect
91-
github.com/shirou/gopsutil v3.21.11+incompatible // indirect
9295
github.com/spf13/cobra v1.8.1 // indirect
9396
github.com/spf13/pflag v1.0.5 // indirect
9497
github.com/stoewer/go-strcase v1.3.0 // indirect
9598
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
9699
github.com/ugorji/go/codec v1.2.12 // indirect
97100
github.com/x448/float16 v0.8.4 // indirect
101+
github.com/yusufpapurcu/wmi v1.2.4 // indirect
98102
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
99103
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect
100104
go.opentelemetry.io/otel v1.33.0 // indirect

go.sum

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
7878
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
7979
github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ=
8080
github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg=
81+
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
82+
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
8183
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
8284
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
8385
github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ=
@@ -218,6 +220,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
218220
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
219221
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
220222
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
223+
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
224+
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
221225
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
222226
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
223227
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU=
@@ -284,6 +288,7 @@ golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
284288
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
285289
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
286290
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
291+
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
287292
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
288293
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
289294
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=

internal/constants/constants.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ const (
2323

2424
GPULastReportTimeAnnotationKey = Domain + "/last-sync"
2525

26+
GpuPoolKey = Domain + "/gpupool"
27+
2628
// Annotation key constants
27-
GpuPoolAnnotationKey = Domain + "/gpupool"
2829
// %s -> container_name
2930
TFLOPSRequestAnnotationFormat = Domain + "/tflops-request-%s"
3031
VRAMRequestAnnotationFormat = Domain + "/vram-request-%s"

internal/controller/gpu_controller.go

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@ package controller
1818

1919
import (
2020
"context"
21+
"fmt"
22+
"strings"
2123

24+
tfv1 "github.com/NexusGPU/tensor-fusion-operator/api/v1"
25+
"github.com/NexusGPU/tensor-fusion-operator/internal/constants"
26+
scheduler "github.com/NexusGPU/tensor-fusion-operator/internal/scheduler"
27+
"github.com/samber/lo"
28+
"k8s.io/apimachinery/pkg/api/errors"
29+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2230
"k8s.io/apimachinery/pkg/runtime"
2331
ctrl "sigs.k8s.io/controller-runtime"
2432
"sigs.k8s.io/controller-runtime/pkg/client"
25-
"sigs.k8s.io/controller-runtime/pkg/event"
26-
"sigs.k8s.io/controller-runtime/pkg/predicate"
27-
28-
tfv1 "github.com/NexusGPU/tensor-fusion-operator/api/v1"
29-
scheduler "github.com/NexusGPU/tensor-fusion-operator/internal/scheduler"
3033
)
3134

3235
// GPUReconciler reconciles a GPU object
@@ -43,6 +46,59 @@ type GPUReconciler struct {
4346
// Reconcile is part of the main kubernetes reconciliation loop which aims to
4447
// move the current state of the cluster closer to the desired state.
4548
func (r *GPUReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
49+
gpu := &tfv1.GPU{}
50+
if err := r.Get(ctx, req.NamespacedName, gpu); err != nil {
51+
if errors.IsNotFound(err) {
52+
return ctrl.Result{}, nil
53+
}
54+
return ctrl.Result{}, err
55+
}
56+
57+
kgvs, _, err := r.Scheme.ObjectKinds(&tfv1.GPUNode{})
58+
if err != nil {
59+
return ctrl.Result{}, fmt.Errorf("get object kinds for GPUNode: %w", err)
60+
}
61+
62+
owner, ok := lo.Find(gpu.OwnerReferences, func(or metav1.OwnerReference) bool {
63+
for _, kvg := range kgvs {
64+
if kvg.Kind == or.Kind && fmt.Sprintf("%s/%s", kvg.Group, kvg.Version) == or.APIVersion {
65+
return true
66+
}
67+
}
68+
return false
69+
})
70+
71+
if !ok {
72+
return ctrl.Result{}, fmt.Errorf("owner node %s not found", gpu.Name)
73+
}
74+
75+
gpunode := &tfv1.GPUNode{}
76+
if err := r.Get(ctx, client.ObjectKey{Name: owner.Name}, gpunode); err != nil {
77+
return ctrl.Result{}, fmt.Errorf("get node %s: %w", owner.Name, err)
78+
}
79+
80+
var poolName string
81+
for labelKey := range gpunode.Labels {
82+
after, ok := strings.CutPrefix(labelKey, constants.GPUNodePoolIdentifierLabelPrefix)
83+
if ok {
84+
poolName = after
85+
break
86+
}
87+
}
88+
89+
if poolName == "" {
90+
return ctrl.Result{}, fmt.Errorf("node %s is not assigned to any pool", gpunode.Name)
91+
}
92+
93+
if gpu.Labels == nil {
94+
gpu.Labels = make(map[string]string)
95+
}
96+
gpu.Labels[constants.GpuPoolKey] = poolName
97+
98+
// update gpu
99+
if err := r.Update(ctx, gpu); err != nil {
100+
return ctrl.Result{}, fmt.Errorf("update gpu %s: %w", gpu.Name, err)
101+
}
46102
return ctrl.Result{}, nil
47103
}
48104

@@ -51,21 +107,5 @@ func (r *GPUReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager)
51107
return ctrl.NewControllerManagedBy(mgr).
52108
For(&tfv1.GPU{}).
53109
Named("gpu").
54-
WithEventFilter(
55-
predicate.Funcs{
56-
CreateFunc: func(e event.CreateEvent) bool {
57-
r.Scheduler.OnAdd(e.Object.(*tfv1.GPU))
58-
return true
59-
},
60-
UpdateFunc: func(e event.UpdateEvent) bool {
61-
r.Scheduler.OnUpdate(e.ObjectOld.(*tfv1.GPU), e.ObjectNew.(*tfv1.GPU))
62-
return true
63-
},
64-
DeleteFunc: func(e event.DeleteEvent) bool {
65-
r.Scheduler.OnDelete(e.Object.(*tfv1.GPU))
66-
return true
67-
},
68-
},
69-
).
70110
Complete(r)
71111
}

internal/controller/gpu_controller_test.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ package controller
1818

1919
import (
2020
"context"
21+
"fmt"
2122

2223
tfv1 "github.com/NexusGPU/tensor-fusion-operator/api/v1"
24+
"github.com/NexusGPU/tensor-fusion-operator/internal/constants"
2325
. "github.com/onsi/ginkgo/v2"
2426
. "github.com/onsi/gomega"
2527
"k8s.io/apimachinery/pkg/api/errors"
2628
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2729
"k8s.io/apimachinery/pkg/types"
30+
"k8s.io/client-go/kubernetes/scheme"
31+
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
2832
"sigs.k8s.io/controller-runtime/pkg/reconcile"
2933
)
3034

@@ -36,11 +40,23 @@ var _ = Describe("GPU Controller", func() {
3640

3741
typeNamespacedName := types.NamespacedName{
3842
Name: resourceName,
39-
Namespace: "default", // TODO(user):Modify as needed
43+
Namespace: "default",
4044
}
4145
gpu := &tfv1.GPU{}
46+
gpunode := &tfv1.GPUNode{}
4247

4348
BeforeEach(func() {
49+
By("creating the custom resource for the Kind GPUNode")
50+
gpunode = &tfv1.GPUNode{
51+
ObjectMeta: metav1.ObjectMeta{
52+
Name: resourceName + "-node",
53+
Labels: map[string]string{
54+
fmt.Sprintf(constants.GPUNodePoolIdentifierLabelFormat, "mock"): "true",
55+
},
56+
},
57+
}
58+
Expect(k8sClient.Create(ctx, gpunode)).To(Succeed())
59+
4460
By("creating the custom resource for the Kind GPU")
4561
err := k8sClient.Get(ctx, typeNamespacedName, gpu)
4662
if err != nil && errors.IsNotFound(err) {
@@ -49,20 +65,22 @@ var _ = Describe("GPU Controller", func() {
4965
Name: resourceName,
5066
Namespace: "default",
5167
},
52-
// TODO(user): Specify other spec details if needed.
5368
}
69+
Expect(controllerutil.SetControllerReference(gpunode, resource, scheme.Scheme)).To(Succeed())
5470
Expect(k8sClient.Create(ctx, resource)).To(Succeed())
5571
}
5672
})
5773

5874
AfterEach(func() {
59-
// TODO(user): Cleanup logic after each test, like removing the resource instance.
6075
resource := &tfv1.GPU{}
6176
err := k8sClient.Get(ctx, typeNamespacedName, resource)
6277
Expect(err).NotTo(HaveOccurred())
6378

6479
By("Cleanup the specific resource instance GPU")
6580
Expect(k8sClient.Delete(ctx, resource)).To(Succeed())
81+
82+
By("Cleanup the specific resource instance GPUNode")
83+
Expect(k8sClient.Delete(ctx, gpunode)).To(Succeed())
6684
})
6785
It("should successfully reconcile the resource", func() {
6886
By("Reconciling the created resource")
@@ -75,8 +93,6 @@ var _ = Describe("GPU Controller", func() {
7593
NamespacedName: typeNamespacedName,
7694
})
7795
Expect(err).NotTo(HaveOccurred())
78-
// TODO(user): Add more specific assertions depending on your controller's reconciliation logic.
79-
// Example: If you expect a certain status condition after reconciliation, verify it here.
8096
})
8197
})
8298
})

internal/controller/tensorfusionconnection_controller.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func (r *TensorFusionConnectionReconciler) Reconcile(ctx context.Context, req ct
8181
if connection.Status.Phase == "" || connection.Status.Phase == tfv1.TensorFusionConnectionPending {
8282
// Try to get an available gpu from scheduler
8383
var err error
84-
gpu, err = r.Scheduler.Schedule(connection.Spec.Resources.Requests)
84+
gpu, err = r.Scheduler.Schedule(ctx, connection.Spec.PoolName, connection.Spec.Resources.Requests)
8585
if err != nil {
8686
log.Error(err, "Failed to schedule gpu instance")
8787
connection.Status.Phase = tfv1.TensorFusionConnectionPending
@@ -186,7 +186,7 @@ func (r *TensorFusionConnectionReconciler) handleDeletion(ctx context.Context, c
186186
}
187187

188188
// Release the resources
189-
if err := r.Scheduler.Release(connection.Spec.Resources.Requests, gpu); err != nil {
189+
if err := r.Scheduler.Release(ctx, connection.Spec.Resources.Requests, gpu); err != nil {
190190
return false, err
191191
}
192192

internal/controller/tensorfusionconnection_controller_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,12 @@ var _ = Describe("TensorFusionConnection Controller", func() {
4444
Name: resourceName,
4545
Namespace: "default",
4646
}
47-
scheduler := scheduler.NewNaiveScheduler()
4847
gpu := &tfv1.GPU{
4948
ObjectMeta: metav1.ObjectMeta{
5049
Name: "mock-gpu",
50+
Labels: map[string]string{
51+
constants.GpuPoolKey: "mock",
52+
},
5153
},
5254
}
5355
BeforeEach(func() {
@@ -77,7 +79,6 @@ var _ = Describe("TensorFusionConnection Controller", func() {
7779
Expect(k8sClient.Create(ctx, resource)).To(Succeed())
7880
}
7981

80-
scheduler.OnAdd(gpu)
8182
Expect(k8sClient.Create(ctx, gpu)).To(Succeed())
8283
gpu.Status = tfv1.GPUStatus{
8384
Phase: tfv1.TensorFusionGPUPhaseRunning,
@@ -111,7 +112,7 @@ var _ = Describe("TensorFusionConnection Controller", func() {
111112
controllerReconciler := &TensorFusionConnectionReconciler{
112113
Client: k8sClient,
113114
Scheme: k8sClient.Scheme(),
114-
Scheduler: scheduler,
115+
Scheduler: scheduler.NewScheduler(k8sClient),
115116
}
116117
_, err := controllerReconciler.Reconcile(ctx, reconcile.Request{
117118
NamespacedName: typeNamespacedName,

internal/scheduler/filter.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package scheduler
2+
3+
import (
4+
"context"
5+
6+
tfv1 "github.com/NexusGPU/tensor-fusion-operator/api/v1"
7+
)
8+
9+
// GPUFilter defines an interface for filtering GPU candidates
10+
type GPUFilter interface {
11+
// Filter filters the list of GPUs and returns only those that pass the filter criteria
12+
// The implementation should not modify the input slice
13+
Filter(ctx context.Context, gpus []tfv1.GPU) ([]tfv1.GPU, error)
14+
}
15+
16+
// FilterRegistry provides an immutable collection of GPU filters
17+
// with methods to create new instances with additional filters
18+
type FilterRegistry struct {
19+
parent *FilterRegistry // Reference to parent registry
20+
filters []GPUFilter // Only contains filters added at this level
21+
}
22+
23+
// NewFilterRegistry creates a new empty filter registry
24+
func NewFilterRegistry() *FilterRegistry {
25+
return &FilterRegistry{
26+
parent: nil,
27+
filters: []GPUFilter{},
28+
}
29+
}
30+
31+
// With creates a new FilterRegistry with the provided filters added
32+
// The original FilterRegistry is not modified
33+
func (fr *FilterRegistry) With(filters ...GPUFilter) *FilterRegistry {
34+
if len(filters) == 0 {
35+
return fr
36+
}
37+
38+
// Create a new registry with the current one as parent
39+
return &FilterRegistry{
40+
parent: fr,
41+
filters: filters,
42+
}
43+
}
44+
45+
// Apply applies the filters in this registry to the given GPU list
46+
// Filters are applied in the order they were added (parent filters first)
47+
func (fr *FilterRegistry) Apply(ctx context.Context, gpus []tfv1.GPU) ([]tfv1.GPU, error) {
48+
// First apply parent filters (if any)
49+
filteredGPUs := gpus
50+
var err error
51+
52+
if fr.parent != nil {
53+
filteredGPUs, err = fr.parent.Apply(ctx, filteredGPUs)
54+
if err != nil {
55+
return nil, err
56+
}
57+
58+
// If no GPUs left after parent filtering, return early
59+
if len(filteredGPUs) == 0 {
60+
return filteredGPUs, nil
61+
}
62+
}
63+
64+
// Then apply filters at this level
65+
for _, filter := range fr.filters {
66+
filteredGPUs, err = filter.Filter(ctx, filteredGPUs)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
// If no GPUs left after filtering, return early
72+
if len(filteredGPUs) == 0 {
73+
return filteredGPUs, nil
74+
}
75+
}
76+
77+
return filteredGPUs, nil
78+
}

0 commit comments

Comments
 (0)