Skip to content

Commit 9cd2dbf

Browse files
committed
fix (probable) race condition in promotion
I believe there is a race condition here, and we are closing the channel during promotion before we finish filling it in ExecRequests. This manifests as images not being promoted, seemingly at random, without errors. Refactor to make the logic (hopefully) clearer.
1 parent d7788b5 commit 9cd2dbf

File tree

3 files changed

+340
-245
lines changed

3 files changed

+340
-245
lines changed

internal/legacy/dockerregistry/inventory.go

Lines changed: 127 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,59 @@ func (sc *SyncContext) ExecRequests(
13641364
return err
13651365
}
13661366

1367+
type ForkJoinFunc[K any, V any] func(k K) (V, error)
1368+
1369+
type ForkJoinResult[K any, V any] struct {
1370+
Input K
1371+
Output V
1372+
Error error
1373+
}
1374+
1375+
// ExecRequests uses the Worker Pool pattern, where maxConcurrentRequests
1376+
// determines the number of workers to spawn.
1377+
func ForkJoin[K any, V any](
1378+
maxConcurrentRequests int,
1379+
requests []K,
1380+
processRequest ForkJoinFunc[K, V],
1381+
) []ForkJoinResult[K, V] {
1382+
1383+
reqChan := make(chan K, maxConcurrentRequests*2)
1384+
1385+
var wg sync.WaitGroup
1386+
var resultsMutex sync.Mutex
1387+
var results []ForkJoinResult[K, V]
1388+
1389+
// Log any errors encountered.
1390+
for range maxConcurrentRequests {
1391+
wg.Add(1)
1392+
go func() {
1393+
defer wg.Done()
1394+
1395+
for req := range reqChan {
1396+
result, err := processRequest(req)
1397+
1398+
resultsMutex.Lock()
1399+
results = append(results, ForkJoinResult[K, V]{
1400+
Input: req,
1401+
Output: result,
1402+
Error: err,
1403+
})
1404+
resultsMutex.Unlock()
1405+
}
1406+
}()
1407+
}
1408+
1409+
for _, req := range requests {
1410+
reqChan <- req
1411+
}
1412+
close(reqChan)
1413+
1414+
// Wait for all workers to finish draining the jobs.
1415+
wg.Wait()
1416+
1417+
return results
1418+
}
1419+
13671420
func extractRegistryTags(reader io.Reader) (*ggcrV1Google.Tags, error) {
13681421
tags := ggcrV1Google.Tags{}
13691422
decoder := json.NewDecoder(reader)
@@ -1479,18 +1532,17 @@ func (sc *SyncContext) ValidateEdge(edge *PromotionEdge) error {
14791532
return nil
14801533
}
14811534

1482-
// MKPopulateRequestsForPromotionEdges takes in a map of PromotionEdges to promote
1483-
// and a PromotionContext and returns a PopulateRequests which can generate
1484-
// requests to be processed.
1485-
func MKPopulateRequestsForPromotionEdges(
1535+
// BuildPopulateRequestsForPromotionEdges takes in a map of PromotionEdges to promote
1536+
// and a PromotionContext and returns the promotion requests.
1537+
func (sc *SyncContext) BuildPopulateRequestsForPromotionEdges(
14861538
toPromote map[PromotionEdge]interface{},
1487-
) PopulateRequests {
1488-
return func(sc *SyncContext, reqs chan<- stream.ExternalRequest, wg *sync.WaitGroup) {
1489-
if len(toPromote) == 0 {
1490-
logrus.Info("Nothing to promote.")
1491-
return
1492-
}
1539+
) []stream.ExternalRequest {
14931540

1541+
var requests []stream.ExternalRequest
1542+
1543+
if len(toPromote) == 0 {
1544+
logrus.Info("Nothing to promote.")
1545+
} else {
14941546
if sc.Confirm {
14951547
logrus.Info("---------- BEGIN PROMOTION ----------")
14961548
} else {
@@ -1533,10 +1585,11 @@ func MKPopulateRequestsForPromotionEdges(
15331585
promoteMe.DstImageTag.Tag,
15341586
}
15351587

1536-
wg.Add(1)
1537-
reqs <- req
1588+
requests = append(requests, req)
15381589
}
15391590
}
1591+
1592+
return requests
15401593
}
15411594

15421595
// RunChecks runs defined PreChecks in order to check the promotion.
@@ -1665,7 +1718,7 @@ func getRegistriesToRead(edges map[PromotionEdge]interface{}) []registry.Context
16651718
// Manifest.
16661719
func (sc *SyncContext) Promote(
16671720
edges map[PromotionEdge]interface{},
1668-
customProcessRequest *ProcessRequest,
1721+
customProcessRequest ProcessRequestFunc,
16691722
) error {
16701723
if len(edges) == 0 {
16711724
logrus.Info("Nothing to promote.")
@@ -1693,17 +1746,11 @@ func (sc *SyncContext) Promote(
16931746
}
16941747

16951748
var (
1696-
populateRequests = MKPopulateRequestsForPromotionEdges(edges)
1697-
1698-
processRequest ProcessRequest
1699-
processRequestReal ProcessRequest = func(
1700-
_ *SyncContext,
1701-
reqs chan stream.ExternalRequest,
1702-
requestResults chan<- RequestResult,
1703-
_ *sync.WaitGroup,
1704-
_ *sync.Mutex,
1705-
) {
1706-
for req := range reqs {
1749+
promotionRequests = sc.BuildPopulateRequestsForPromotionEdges(edges)
1750+
1751+
processRequest func(req stream.ExternalRequest) (RequestResult, error)
1752+
processRequestReal ProcessRequestFunc = func(req stream.ExternalRequest) (RequestResult, error) {
1753+
{
17071754
reqRes := RequestResult{Context: req}
17081755
errors := make(Errors, 0)
17091756
// If we're adding or moving (i.e., creating a new image or
@@ -1756,8 +1803,20 @@ func (sc *SyncContext) Promote(
17561803
logrus.Infof("deletions are no longer supported")
17571804
}
17581805

1806+
if len(errors) > 0 {
1807+
logrus.Errorf(
1808+
// TODO(log): Consider logging with fields
1809+
"request %v: error(s) encountered: %v",
1810+
reqRes.Context,
1811+
reqRes.Errors)
1812+
} else {
1813+
// TODO(log): Consider logging with fields
1814+
logrus.Infof("request %v: OK", reqRes.Context.RequestParams)
1815+
}
1816+
// Log the HTTP request to GCR.
1817+
reqcounter.Increment()
17591818
reqRes.Errors = errors
1760-
requestResults <- reqRes
1819+
return reqRes, nil
17611820
}
17621821
}
17631822
)
@@ -1767,16 +1826,34 @@ func (sc *SyncContext) Promote(
17671826
if sc.Confirm {
17681827
processRequest = processRequestReal
17691828
} else {
1770-
processRequestDryRun := MkRequestCapturer(&captured)
1829+
processRequestDryRun := MkRequestCapturerFunc(&captured)
17711830
processRequest = processRequestDryRun
17721831
}
17731832

17741833
if customProcessRequest != nil {
1775-
processRequest = *customProcessRequest
1834+
processRequest = customProcessRequest
17761835
}
17771836

17781837
sc.PrintCapturedRequests(&captured)
1779-
return sc.ExecRequests(populateRequests, processRequest)
1838+
1839+
// Run concurrent requests.
1840+
maxConcurrentRequests := 10
1841+
if sc.Threads > 0 {
1842+
maxConcurrentRequests = sc.Threads
1843+
}
1844+
1845+
results := ForkJoin(maxConcurrentRequests, promotionRequests, processRequest)
1846+
1847+
var errs []error
1848+
for _, result := range results {
1849+
if len(result.Output.Errors) > 0 {
1850+
sc.Logs.Errors = append(sc.Logs.Errors, result.Output.Errors...)
1851+
errs = append(errs, errors.New("encountered an error while executing requests"))
1852+
}
1853+
errs = append(errs, result.Error)
1854+
}
1855+
1856+
return errors.Join(errs...)
17801857
}
17811858

17821859
// PrintCapturedRequests pretty-prints all given PromotionRequests.
@@ -1845,6 +1922,7 @@ func (pr *PromotionRequest) PrettyValue() string {
18451922

18461923
// MkRequestCapturer returns a function that simply records requests as they are
18471924
// captured (slurped out from the reqs channel).
1925+
// Deprecated: Prefer MkRequestCapturerFunc
18481926
func MkRequestCapturer(captured *CapturedRequests) ProcessRequest {
18491927
return func(
18501928
_ *SyncContext,
@@ -1870,6 +1948,27 @@ func MkRequestCapturer(captured *CapturedRequests) ProcessRequest {
18701948
}
18711949
}
18721950

1951+
// MkRequestCapturer returns a function that simply records requests as they are
1952+
// captured (slurped out from the reqs channel).
1953+
func MkRequestCapturerFunc(captured *CapturedRequests) ProcessRequestFunc {
1954+
var mutex sync.Mutex
1955+
1956+
return func(req stream.ExternalRequest) (RequestResult, error) {
1957+
{
1958+
pr := req.RequestParams.(PromotionRequest)
1959+
1960+
mutex.Lock()
1961+
(*captured)[pr]++
1962+
mutex.Unlock()
1963+
1964+
// Add a request result to signal the processing of this "request".
1965+
// This is necessary because ExecRequests() is the sole function in
1966+
// the codebase that decrements the WaitGroup semaphore.
1967+
return RequestResult{}, nil
1968+
}
1969+
}
1970+
}
1971+
18731972
func supportedMediaType(v string) (ggcrV1Types.MediaType, error) {
18741973
switch ggcrV1Types.MediaType(v) {
18751974
case ggcrV1Types.DockerManifestList:

0 commit comments

Comments
 (0)