Skip to content

Commit bc1bdc6

Browse files
committed
Avoid panic on parallel walking on DefinitionOp
Signed-off-by: Kohei Tokunaga <[email protected]>
1 parent 52c2fe5 commit bc1bdc6

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

client/llb/definition.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ type DefinitionOp struct {
2424
platforms map[digest.Digest]*ocispecs.Platform
2525
dgst digest.Digest
2626
index pb.OutputIndex
27-
inputCache map[digest.Digest][]*DefinitionOp
27+
inputCache *sync.Map // shared and written among DefinitionOps so avoid race on this map using sync.Map
2828
}
2929

3030
// NewDefinitionOp returns a new operation from a marshalled definition.
@@ -101,7 +101,7 @@ func NewDefinitionOp(def *pb.Definition) (*DefinitionOp, error) {
101101
platforms: platforms,
102102
dgst: dgst,
103103
index: index,
104-
inputCache: make(map[digest.Digest][]*DefinitionOp),
104+
inputCache: new(sync.Map),
105105
}, nil
106106
}
107107

@@ -180,6 +180,18 @@ func (d *DefinitionOp) Output() Output {
180180
}}
181181
}
182182

183+
func (d *DefinitionOp) loadInputCache(dgst digest.Digest) ([]*DefinitionOp, bool) {
184+
a, ok := d.inputCache.Load(dgst.String())
185+
if ok {
186+
return a.([]*DefinitionOp), true
187+
}
188+
return nil, false
189+
}
190+
191+
func (d *DefinitionOp) storeInputCache(dgst digest.Digest, c []*DefinitionOp) {
192+
d.inputCache.Store(dgst.String(), c)
193+
}
194+
183195
func (d *DefinitionOp) Inputs() []Output {
184196
if d.dgst == "" {
185197
return nil
@@ -195,7 +207,7 @@ func (d *DefinitionOp) Inputs() []Output {
195207
for _, input := range op.Inputs {
196208
var vtx *DefinitionOp
197209
d.mu.Lock()
198-
if existingIndexes, ok := d.inputCache[input.Digest]; ok {
210+
if existingIndexes, ok := d.loadInputCache(input.Digest); ok {
199211
if int(input.Index) < len(existingIndexes) && existingIndexes[input.Index] != nil {
200212
vtx = existingIndexes[input.Index]
201213
}
@@ -211,14 +223,14 @@ func (d *DefinitionOp) Inputs() []Output {
211223
inputCache: d.inputCache,
212224
sources: d.sources,
213225
}
214-
existingIndexes := d.inputCache[input.Digest]
226+
existingIndexes, _ := d.loadInputCache(input.Digest)
215227
indexDiff := int(input.Index) - len(existingIndexes)
216228
if indexDiff >= 0 {
217229
// make room in the slice for the new index being set
218230
existingIndexes = append(existingIndexes, make([]*DefinitionOp, indexDiff+1)...)
219231
}
220232
existingIndexes[input.Index] = vtx
221-
d.inputCache[input.Digest] = existingIndexes
233+
d.storeInputCache(input.Digest, existingIndexes)
222234
}
223235
d.mu.Unlock()
224236

client/llb/definition_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package llb
33
import (
44
"bytes"
55
"context"
6+
"fmt"
67
"testing"
78

89
"github.com/containerd/containerd/platforms"
910
"github.com/moby/buildkit/solver/pb"
1011
digest "github.com/opencontainers/go-digest"
1112
"github.com/stretchr/testify/require"
13+
"golang.org/x/sync/errgroup"
1214
)
1315

1416
func TestDefinitionEquivalence(t *testing.T) {
@@ -117,10 +119,36 @@ func TestDefinitionInputCache(t *testing.T) {
117119
require.NoError(t, err)
118120
// 1 exec + 2x2 mounts from stA and stB + 1 src = 6 vertexes
119121
require.Equal(t, 6, len(vertexCache))
122+
123+
// make sure that walking vertices in parallel doesn't cause panic
124+
var all []RunOption
125+
for i := 0; i < 100; i++ {
126+
var sts []RunOption
127+
for j := 0; j < 100; j++ {
128+
sts = append(sts, AddMount("/mnt", Scratch().Run(Shlex(fmt.Sprintf("%d-%d", i, j))).Root()))
129+
}
130+
all = append(all, AddMount("/mnt", Scratch().Run(append([]RunOption{Shlex("args")}, sts...)...).Root()))
131+
}
132+
def, err = Scratch().Run(append([]RunOption{Shlex("args")}, all...)...).Root().Marshal(context.TODO())
133+
require.NoError(t, err)
134+
op, err = NewDefinitionOp(def.ToPB())
135+
require.NoError(t, err)
136+
require.NoError(t, testParallelWalk(context.Background(), op.Output()))
120137
}
121138

122139
func TestDefinitionNil(t *testing.T) {
123140
// should be an error, not a panic
124141
_, err := NewDefinitionOp(nil)
125142
require.Error(t, err)
126143
}
144+
145+
func testParallelWalk(ctx context.Context, out Output) error {
146+
eg, egCtx := errgroup.WithContext(ctx)
147+
for _, o := range out.Vertex(ctx, nil).Inputs() {
148+
o := o
149+
eg.Go(func() error {
150+
return testParallelWalk(egCtx, o)
151+
})
152+
}
153+
return eg.Wait()
154+
}

0 commit comments

Comments
 (0)