Skip to content

Commit aa52ddf

Browse files
ayushr2gvisor-bot
authored andcommitted
state: Maintain a separate list with leaf nodes for objectDecodeState.
Profiling shows that we spend a LOT of time in decodeState.Load() in firing all callbacks in the right order. objectDecodeState is organized as a DAG. Callbacks need to be fired on leaf nodes (where blockedBy==0) and then fired on internal nodes only when all their dependencies have executed their callbacks. The earlier approach was to continuously scan a HUGE list of pending nodes searching for a leaf node, fire callbacks on that, check if any of its dependents became leaves and then scan the full list again. We were wasting a lot of time re-evaluating internal nodes repeatedly. For instance, the kernel object created by checkpointing SGLang server running "gemma-3-27b-pt" had 368,245 objects. The whole pending list was scanned 238,102 times. We spent 2 minutes 10 seconds in just executing callbacks. The total kernel load time was 2 minutes 30 seconds. In this change, we track leaf nodes in a separate linked list. We keep appending to it as and when new nodes become leaves. This way we don't need to search for leaves. This improved kernel load time to 12 seconds (-92%). PiperOrigin-RevId: 775402687
1 parent b2e4ea0 commit aa52ddf

File tree

2 files changed

+60
-54
lines changed

2 files changed

+60
-54
lines changed

pkg/state/BUILD

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ go_template_instance(
1919
)
2020

2121
go_template_instance(
22-
name = "complete_list",
23-
out = "complete_list.go",
22+
name = "ods_list",
23+
out = "ods_list.go",
2424
package = "state",
25-
prefix = "complete",
25+
prefix = "ods",
2626
template = "//pkg/ilist:generic_list",
2727
types = {
28-
"Element": "*objectDecodeState",
29-
"Linker": "*objectDecodeState",
28+
"Element": "*odsListElem",
29+
"Linker": "*odsListElem",
3030
},
3131
)
3232

@@ -66,12 +66,12 @@ go_library(
6666
srcs = [
6767
"addr_range.go",
6868
"addr_set.go",
69-
"complete_list.go",
7069
"decode.go",
7170
"decode_unsafe.go",
7271
"deferred_list.go",
7372
"encode.go",
7473
"encode_unsafe.go",
74+
"ods_list.go",
7575
"state.go",
7676
"state_norace.go",
7777
"state_race.go",

pkg/state/decode.go

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type internalCallback interface {
3232
source() *objectDecodeState
3333

3434
// callbackRun executes the callback.
35-
callbackRun()
35+
callbackRun(ds *decodeState)
3636
}
3737

3838
// userCallback is an implementation of internalCallback.
@@ -44,7 +44,7 @@ func (userCallback) source() *objectDecodeState {
4444
}
4545

4646
// callbackRun implements internalCallback.callbackRun.
47-
func (uc userCallback) callbackRun() {
47+
func (uc userCallback) callbackRun(*decodeState) {
4848
uc()
4949
}
5050

@@ -84,7 +84,13 @@ type objectDecodeState struct {
8484
// callbacks is a set of callbacks to execute on load.
8585
callbacks []internalCallback
8686

87-
completeEntry
87+
pendingEntry odsListElem
88+
leafEntry odsListElem
89+
}
90+
91+
type odsListElem struct {
92+
ods *objectDecodeState
93+
odsEntry
8894
}
8995

9096
// addCallback adds a callback to the objectDecodeState.
@@ -122,8 +128,13 @@ func (ods *objectDecodeState) source() *objectDecodeState {
122128
}
123129

124130
// callbackRun implements internalCallback.callbackRun.
125-
func (ods *objectDecodeState) callbackRun() {
131+
func (ods *objectDecodeState) callbackRun(ds *decodeState) {
126132
ods.blockedBy--
133+
if ods.blockedBy == 0 {
134+
ds.leaves.PushBack(&ods.leafEntry)
135+
} else if ods.blockedBy < 0 {
136+
Failf("object %d has negative blockedBy: %d", ods.id, ods.blockedBy)
137+
}
127138
}
128139

129140
// decodeState is a graph of objects in the process of being decoded.
@@ -155,7 +166,11 @@ type decodeState struct {
155166
deferred map[objectID]wire.Object
156167

157168
// pending is the set of objects that are not yet complete.
158-
pending completeList
169+
pending odsList
170+
171+
// leaves is the set of objects that have no dependencies (blockedBy == 0).
172+
// leaves are consumed from the front and appended to the back.
173+
leaves odsList
159174

160175
// stats tracks time data.
161176
stats Stats
@@ -185,20 +200,12 @@ func (ds *decodeState) checkComplete(ods *objectDecodeState) bool {
185200

186201
// Fire all callbacks.
187202
for _, ic := range ods.callbacks {
188-
ic.callbackRun()
203+
ic.callbackRun(ds)
189204
}
190205

191206
// Mark completed.
192-
cbs := ods.callbacks
193207
ods.callbacks = nil
194-
ds.pending.Remove(ods)
195-
196-
// Recursively check others.
197-
for _, ic := range cbs {
198-
if other := ic.source(); other != nil && other.blockedBy == 0 {
199-
ds.checkComplete(other)
200-
}
201-
}
208+
ds.pending.Remove(&ods.pendingEntry)
202209

203210
return true // All set.
204211
}
@@ -223,6 +230,9 @@ func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback fun
223230

224231
// Mark as blocked.
225232
waiter.blockedBy++
233+
if waiter.blockedBy == 1 {
234+
ds.leaves.Remove(&waiter.leafEntry)
235+
}
226236

227237
// No nil can be returned here.
228238
other := ds.lookup(id)
@@ -280,6 +290,26 @@ func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
280290
return obj
281291
}
282292

293+
func (ds *decodeState) growObjectsByID(id objectID) {
294+
if len(ds.objectsByID) < int(id) {
295+
ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...)
296+
}
297+
}
298+
299+
func (ds *decodeState) addObject(id objectID, obj reflect.Value) *objectDecodeState {
300+
ods := &objectDecodeState{
301+
id: id,
302+
obj: obj,
303+
}
304+
ods.pendingEntry.ods = ods
305+
ods.leafEntry.ods = ods
306+
ds.growObjectsByID(id)
307+
ds.objectsByID[id-1] = ods
308+
ds.pending.PushBack(&ods.pendingEntry)
309+
ds.leaves.PushBack(&ods.leafEntry)
310+
return ods
311+
}
312+
283313
// register registers a decode with a type.
284314
//
285315
// This type is only used to instantiate a new object if it has not been
@@ -288,11 +318,9 @@ func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
288318
func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value {
289319
// Grow the objectsByID slice.
290320
id := objectID(r.Root)
291-
if len(ds.objectsByID) < int(id) {
292-
ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...)
293-
}
294321

295322
// Does this object already exist?
323+
ds.growObjectsByID(id)
296324
ods := ds.objectsByID[id-1]
297325
if ods != nil {
298326
return walkChild(r.Dots, ods.obj)
@@ -303,12 +331,7 @@ func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value {
303331
typ = ds.findType(r.Type)
304332
}
305333
v := reflect.New(typ)
306-
ods = &objectDecodeState{
307-
id: id,
308-
obj: v.Elem(),
309-
}
310-
ds.objectsByID[id-1] = ods
311-
ds.pending.PushBack(ods)
334+
ods = ds.addObject(id, v.Elem())
312335

313336
// Process any deferred objects & callbacks.
314337
if encoded, ok := ds.deferred[id]; ok {
@@ -581,13 +604,8 @@ func (ds *decodeState) Load(obj reflect.Value) {
581604
return ds.types.LookupName(id)
582605
})
583606

584-
// Create the root object.
585-
rootOds := &objectDecodeState{
586-
id: 1,
587-
obj: obj,
588-
}
589-
ds.objectsByID = append(ds.objectsByID, rootOds)
590-
ds.pending.PushBack(rootOds)
607+
// Add the root object with ID 1.
608+
_ = ds.addObject(1, obj)
591609

592610
// Read the number of objects.
593611
numObjects, object, err := ReadHeader(&ds.r)
@@ -603,7 +621,6 @@ func (ds *decodeState) Load(obj reflect.Value) {
603621
encoded wire.Object
604622
ods *objectDecodeState
605623
id objectID
606-
tid = typeID(1)
607624
)
608625
if err := safely(func() {
609626
// Decode all objects in the stream.
@@ -616,7 +633,6 @@ func (ds *decodeState) Load(obj reflect.Value) {
616633
switch we := encoded.(type) {
617634
case *wire.Type:
618635
ds.types.Register(we)
619-
tid++
620636
encoded = nil
621637
continue
622638
case wire.Uint:
@@ -673,32 +689,22 @@ func (ds *decodeState) Load(obj reflect.Value) {
673689
// objects become complete (there is a dependency cycle).
674690
//
675691
// Note that we iterate backwards here, because there will be a strong
676-
// tendendcy for blocking relationships to go from earlier objects to
692+
// tendency for blocking relationships to go from earlier objects to
677693
// later (deeper) objects in the graph. This will reduce the number of
678694
// iterations required to finish all objects.
679695
if err := safely(func() {
680-
for ds.pending.Back() != nil {
681-
thisCycle := false
682-
for ods = ds.pending.Back(); ods != nil; {
683-
if ds.checkComplete(ods) {
684-
thisCycle = true
685-
break
686-
}
687-
ods = ods.Prev()
688-
}
689-
if !thisCycle {
690-
break
691-
}
696+
for elem := ds.leaves.Front(); elem != nil; elem = elem.Next() {
697+
ds.checkComplete(elem.ods)
692698
}
693699
}); err != nil {
694700
Failf("error executing callbacks: %w\nfor object %#v", err, ods.obj.Interface())
695701
}
696702

697703
// Check if we have any remaining dependency cycles. If there are any
698704
// objects left in the pending list, then it must be due to a cycle.
699-
if ods := ds.pending.Front(); ods != nil {
705+
if elem := ds.pending.Front(); elem != nil {
700706
// This must be the result of a dependency cycle.
701-
cycle := ods.findCycle()
707+
cycle := elem.ods.findCycle()
702708
var buf bytes.Buffer
703709
buf.WriteString("dependency cycle: {")
704710
for i, cycleOS := range cycle {

0 commit comments

Comments
 (0)