Skip to content

Commit 2f47ec3

Browse files
committed
fix race condition in promotion
I believe there is a race condition here, and we are closing the channel during promotion before we finish filling it. This is because the WaitGroup is incremented _after_ we add to the channel. Refactor to make the logic (hopefully) clearer.
1 parent d7788b5 commit 2f47ec3

File tree

3 files changed

+340
-245
lines changed

3 files changed

+340
-245
lines changed

internal/legacy/dockerregistry/inventory.go

Lines changed: 127 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,59 @@ func (sc *SyncContext) ExecRequests(
13641364
return err
13651365
}
13661366

1367+
// ForkJoinFunc processes a single request of type K and returns the produced
// value of type V, or an error.
type ForkJoinFunc[K any, V any] func(k K) (V, error)

// ForkJoinResult records the outcome of running a ForkJoinFunc on one input.
type ForkJoinResult[K any, V any] struct {
	// Input is the request that was processed.
	Input K
	// Output is the value the ForkJoinFunc returned for Input.
	Output V
	// Error is the error the ForkJoinFunc returned for Input, if any.
	Error error
}

// ForkJoin runs processRequest on every element of requests using a pool of
// worker goroutines and blocks until all of them have been processed.
//
// maxConcurrentRequests is the upper bound on the number of workers. Values
// below 1 are treated as 1 so that a non-empty request list can always make
// progress, and no more workers are started than there are requests.
//
// The returned slice has one entry per request, in the same order as
// requests. It is nil when requests is empty.
func ForkJoin[K any, V any](
	maxConcurrentRequests int,
	requests []K,
	processRequest ForkJoinFunc[K, V],
) []ForkJoinResult[K, V] {
	if len(requests) == 0 {
		return nil
	}

	// Without at least one worker nothing would drain the channel and the
	// producer loop below would block forever.
	workers := maxConcurrentRequests
	if workers < 1 {
		workers = 1
	}
	if workers > len(requests) {
		workers = len(requests)
	}

	// Workers receive indexes rather than requests: each index is handed to
	// exactly one worker, so every result slot has a single writer and the
	// slice needs no lock.
	results := make([]ForkJoinResult[K, V], len(requests))
	indexes := make(chan int, workers*2)

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()

			for i := range indexes {
				output, err := processRequest(requests[i])
				results[i] = ForkJoinResult[K, V]{
					Input:  requests[i],
					Output: output,
					Error:  err,
				}
			}
		}()
	}

	for i := range requests {
		indexes <- i
	}
	// Closing the channel ends each worker's range loop once it is drained.
	close(indexes)

	// Wait for all workers to finish draining the jobs.
	wg.Wait()

	return results
}
1419+
13671420
func extractRegistryTags(reader io.Reader) (*ggcrV1Google.Tags, error) {
13681421
tags := ggcrV1Google.Tags{}
13691422
decoder := json.NewDecoder(reader)
@@ -1479,18 +1532,17 @@ func (sc *SyncContext) ValidateEdge(edge *PromotionEdge) error {
14791532
return nil
14801533
}
14811534

1482-
// MKPopulateRequestsForPromotionEdges takes in a map of PromotionEdges to promote
1483-
// and a PromotionContext and returns a PopulateRequests which can generate
1484-
// requests to be processed.
1485-
func MKPopulateRequestsForPromotionEdges(
1535+
// BuildPopulateRequestsForPromotionEdges takes in a map of PromotionEdges to promote
1536+
// and returns the promotion requests built from them.
1537+
func (sc *SyncContext) BuildPopulateRequestsForPromotionEdges(
14861538
toPromote map[PromotionEdge]interface{},
1487-
) PopulateRequests {
1488-
return func(sc *SyncContext, reqs chan<- stream.ExternalRequest, wg *sync.WaitGroup) {
1489-
if len(toPromote) == 0 {
1490-
logrus.Info("Nothing to promote.")
1491-
return
1492-
}
1539+
) []stream.ExternalRequest {
14931540

1541+
var requests []stream.ExternalRequest
1542+
1543+
if len(toPromote) == 0 {
1544+
logrus.Info("Nothing to promote.")
1545+
} else {
14941546
if sc.Confirm {
14951547
logrus.Info("---------- BEGIN PROMOTION ----------")
14961548
} else {
@@ -1533,10 +1585,11 @@ func MKPopulateRequestsForPromotionEdges(
15331585
promoteMe.DstImageTag.Tag,
15341586
}
15351587

1536-
wg.Add(1)
1537-
reqs <- req
1588+
requests = append(requests, req)
15381589
}
15391590
}
1591+
1592+
return requests
15401593
}
15411594

15421595
// RunChecks runs defined PreChecks in order to check the promotion.
@@ -1665,7 +1718,7 @@ func getRegistriesToRead(edges map[PromotionEdge]interface{}) []registry.Context
16651718
// Manifest.
16661719
func (sc *SyncContext) Promote(
16671720
edges map[PromotionEdge]interface{},
1668-
customProcessRequest *ProcessRequest,
1721+
customProcessRequest ProcessRequestFunc,
16691722
) error {
16701723
if len(edges) == 0 {
16711724
logrus.Info("Nothing to promote.")
@@ -1693,17 +1746,11 @@ func (sc *SyncContext) Promote(
16931746
}
16941747

16951748
var (
1696-
populateRequests = MKPopulateRequestsForPromotionEdges(edges)
1697-
1698-
processRequest ProcessRequest
1699-
processRequestReal ProcessRequest = func(
1700-
_ *SyncContext,
1701-
reqs chan stream.ExternalRequest,
1702-
requestResults chan<- RequestResult,
1703-
_ *sync.WaitGroup,
1704-
_ *sync.Mutex,
1705-
) {
1706-
for req := range reqs {
1749+
promotionRequests = sc.BuildPopulateRequestsForPromotionEdges(edges)
1750+
1751+
processRequest func(req stream.ExternalRequest) (RequestResult, error)
1752+
processRequestReal ProcessRequestFunc = func(req stream.ExternalRequest) (RequestResult, error) {
1753+
{
17071754
reqRes := RequestResult{Context: req}
17081755
errors := make(Errors, 0)
17091756
// If we're adding or moving (i.e., creating a new image or
@@ -1756,8 +1803,20 @@ func (sc *SyncContext) Promote(
17561803
logrus.Infof("deletions are no longer supported")
17571804
}
17581805

1806+
if len(errors) > 0 {
1807+
logrus.Errorf(
1808+
// TODO(log): Consider logging with fields
1809+
"request %v: error(s) encountered: %v",
1810+
reqRes.Context,
1811+
reqRes.Errors)
1812+
} else {
1813+
// TODO(log): Consider logging with fields
1814+
logrus.Infof("request %v: OK", reqRes.Context.RequestParams)
1815+
}
1816+
// Log the HTTP request to GCR.
1817+
reqcounter.Increment()
17591818
reqRes.Errors = errors
1760-
requestResults <- reqRes
1819+
return reqRes, nil
17611820
}
17621821
}
17631822
)
@@ -1767,16 +1826,34 @@ func (sc *SyncContext) Promote(
17671826
if sc.Confirm {
17681827
processRequest = processRequestReal
17691828
} else {
1770-
processRequestDryRun := MkRequestCapturer(&captured)
1829+
processRequestDryRun := MkRequestCapturerFunc(&captured)
17711830
processRequest = processRequestDryRun
17721831
}
17731832

17741833
if customProcessRequest != nil {
1775-
processRequest = *customProcessRequest
1834+
processRequest = customProcessRequest
17761835
}
17771836

17781837
sc.PrintCapturedRequests(&captured)
1779-
return sc.ExecRequests(populateRequests, processRequest)
1838+
1839+
// Run concurrent requests.
1840+
maxConcurrentRequests := 10
1841+
if sc.Threads > 0 {
1842+
maxConcurrentRequests = sc.Threads
1843+
}
1844+
1845+
results := ForkJoin(maxConcurrentRequests, promotionRequests, processRequest)
1846+
1847+
var errs []error
1848+
for _, result := range results {
1849+
if len(result.Output.Errors) > 0 {
1850+
sc.Logs.Errors = append(sc.Logs.Errors, result.Output.Errors...)
1851+
errs = append(errs, errors.New("encountered an error while executing requests"))
1852+
}
1853+
errs = append(errs, result.Error)
1854+
}
1855+
1856+
return errors.Join(errs...)
17801857
}
17811858

17821859
// PrintCapturedRequests pretty-prints all given PromotionRequests.
@@ -1845,6 +1922,7 @@ func (pr *PromotionRequest) PrettyValue() string {
18451922

18461923
// MkRequestCapturer returns a function that simply records requests as they are
18471924
// captured (slurped out from the reqs channel).
1925+
//
// Deprecated: use MkRequestCapturerFunc instead.
18481926
func MkRequestCapturer(captured *CapturedRequests) ProcessRequest {
18491927
return func(
18501928
_ *SyncContext,
@@ -1870,6 +1948,27 @@ func MkRequestCapturer(captured *CapturedRequests) ProcessRequest {
18701948
}
18711949
}
18721950

1951+
// MkRequestCapturer returns a function that simply records requests as they are
1952+
// captured (slurped out from the reqs channel).
1953+
func MkRequestCapturerFunc(captured *CapturedRequests) ProcessRequestFunc {
1954+
var mutex sync.Mutex
1955+
1956+
return func(req stream.ExternalRequest) (RequestResult, error) {
1957+
{
1958+
pr := req.RequestParams.(PromotionRequest)
1959+
1960+
mutex.Lock()
1961+
(*captured)[pr]++
1962+
mutex.Unlock()
1963+
1964+
// Add a request result to signal the processing of this "request".
1965+
// This is necessary because ExecRequests() is the sole function in
1966+
// the codebase that decrements the WaitGroup semaphore.
1967+
return RequestResult{}, nil
1968+
}
1969+
}
1970+
}
1971+
18731972
func supportedMediaType(v string) (ggcrV1Types.MediaType, error) {
18741973
switch ggcrV1Types.MediaType(v) {
18751974
case ggcrV1Types.DockerManifestList:

0 commit comments

Comments
 (0)