Skip to content

Commit 45eec75

Browse files
committed
fix race condition in promotion
I believe there is a race condition here, and we are closing the channel during promotion before we finish filling it. This is because the WaitGroup is incremented _after_ we add to the channel. Refactor to make the logic (hopefully) clearer.
1 parent d7788b5 commit 45eec75

File tree

3 files changed

+380
-284
lines changed

3 files changed

+380
-284
lines changed

internal/legacy/dockerregistry/inventory.go

Lines changed: 167 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,59 @@ func (sc *SyncContext) ExecRequests(
13641364
return err
13651365
}
13661366

1367+
// ForkJoinFunc processes a single request of type K and returns a result of
// type V, along with any error encountered while processing it.
type ForkJoinFunc[K any, V any] func(k K) (V, error)

// ForkJoinResult pairs one input request with the output (and error, if any)
// produced for it by a ForkJoinFunc.
type ForkJoinResult[K any, V any] struct {
	Input  K
	Output V
	Error  error
}

// ForkJoin runs processRequest over every element of requests using the
// Worker Pool pattern, where maxConcurrentRequests determines the number of
// workers to spawn. It blocks until every request has been processed and
// returns exactly one ForkJoinResult per request. The order of the results
// is unspecified (it depends on goroutine scheduling, not on the order of
// requests).
func ForkJoin[K any, V any](
	maxConcurrentRequests int,
	requests []K,
	processRequest ForkJoinFunc[K, V],
) []ForkJoinResult[K, V] {
	// Guard against a non-positive worker count, which would otherwise
	// deadlock below: with zero workers nothing drains reqChan, so the
	// sends would block forever once the buffer fills.
	if maxConcurrentRequests < 1 {
		maxConcurrentRequests = 1
	}

	reqChan := make(chan K, maxConcurrentRequests*2)

	var wg sync.WaitGroup
	var resultsMutex sync.Mutex
	// Pre-size: exactly one result is appended per request.
	results := make([]ForkJoinResult[K, V], 0, len(requests))

	// Spawn the workers. Note that wg.Add is called before any request is
	// sent, so there is no race between filling the channel and waiting on
	// the workers (the race the old ExecRequests-based flow had).
	for i := 0; i < maxConcurrentRequests; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			for req := range reqChan {
				result, err := processRequest(req)

				resultsMutex.Lock()
				results = append(results, ForkJoinResult[K, V]{
					Input:  req,
					Output: result,
					Error:  err,
				})
				resultsMutex.Unlock()
			}
		}()
	}

	// Feed every request, then close the channel so the workers' range
	// loops terminate once the channel is drained. Only this (sending)
	// side closes the channel.
	for _, req := range requests {
		reqChan <- req
	}
	close(reqChan)

	// Wait for all workers to finish draining the jobs.
	wg.Wait()

	return results
}
1419+
13671420
func extractRegistryTags(reader io.Reader) (*ggcrV1Google.Tags, error) {
13681421
tags := ggcrV1Google.Tags{}
13691422
decoder := json.NewDecoder(reader)
@@ -1479,64 +1532,65 @@ func (sc *SyncContext) ValidateEdge(edge *PromotionEdge) error {
14791532
return nil
14801533
}
14811534

1482-
// MKPopulateRequestsForPromotionEdges takes in a map of PromotionEdges to promote
1483-
// and a PromotionContext and returns a PopulateRequests which can generate
1484-
// requests to be processed.
1485-
func MKPopulateRequestsForPromotionEdges(
1535+
// BuildPopulateRequestsForPromotionEdges takes in a map of PromotionEdges to promote
1536+
// and a PromotionContext and returns the promotion requests.
1537+
func (sc *SyncContext) BuildPopulateRequestsForPromotionEdges(
14861538
toPromote map[PromotionEdge]interface{},
1487-
) PopulateRequests {
1488-
return func(sc *SyncContext, reqs chan<- stream.ExternalRequest, wg *sync.WaitGroup) {
1489-
if len(toPromote) == 0 {
1490-
logrus.Info("Nothing to promote.")
1491-
return
1492-
}
1539+
) []stream.ExternalRequest {
14931540

1494-
if sc.Confirm {
1495-
logrus.Info("---------- BEGIN PROMOTION ----------")
1496-
} else {
1497-
logrus.Info("---------- BEGIN PROMOTION (DRY RUN) ----------")
1498-
}
1541+
var requests []stream.ExternalRequest
14991542

1500-
for promoteMe := range toPromote {
1501-
var req stream.ExternalRequest
1502-
oldDigest := image.Digest("")
1503-
1504-
// Technically speaking none of the edges at this point should be
1505-
// invalid (such as trying to do tag moves), because we run
1506-
// ValidateEdges() in Promote() in the early stages, before passing
1507-
// on this closure to ExecRequests(). However the check here is so
1508-
// cheap that we do it anyway, just in case.
1509-
if err := sc.ValidateEdge(&promoteMe); err != nil {
1510-
logrus.Error(err)
1511-
continue
1512-
}
1543+
if len(toPromote) == 0 {
1544+
logrus.Info("Nothing to promote.")
1545+
return requests
1546+
}
15131547

1514-
// Save some information about this request. It's a bit like
1515-
// HTTP "headers".
1516-
req.RequestParams = PromotionRequest{
1517-
// Only support adding new tags during a promotion run. Tag
1518-
// moves and deletions are not supported.
1519-
//
1520-
// Although disallowing tag moves sounds a bit draconian, it
1521-
// does make protect production from a malformed set of promoter
1522-
// manifests with incorrect tag information.
1523-
Add,
1524-
// TODO: Clean up types to avoid having to split up promoteMe
1525-
// prematurely like this.
1526-
promoteMe.SrcRegistry.Name,
1527-
promoteMe.DstRegistry.Name,
1528-
promoteMe.DstRegistry.ServiceAccount,
1529-
promoteMe.SrcImageTag.Name,
1530-
promoteMe.DstImageTag.Name,
1531-
promoteMe.Digest,
1532-
oldDigest,
1533-
promoteMe.DstImageTag.Tag,
1534-
}
1548+
if sc.Confirm {
1549+
logrus.Info("---------- BEGIN PROMOTION ----------")
1550+
} else {
1551+
logrus.Info("---------- BEGIN PROMOTION (DRY RUN) ----------")
1552+
}
15351553

1536-
wg.Add(1)
1537-
reqs <- req
1554+
for promoteMe := range toPromote {
1555+
var req stream.ExternalRequest
1556+
oldDigest := image.Digest("")
1557+
1558+
// Technically speaking none of the edges at this point should be
1559+
// invalid (such as trying to do tag moves), because we run
1560+
// ValidateEdges() in Promote() in the early stages, before passing
1561+
// on this closure to ExecRequests(). However the check here is so
1562+
// cheap that we do it anyway, just in case.
1563+
if err := sc.ValidateEdge(&promoteMe); err != nil {
1564+
logrus.Error(err)
1565+
continue
1566+
}
1567+
1568+
// Save some information about this request. It's a bit like
1569+
// HTTP "headers".
1570+
req.RequestParams = PromotionRequest{
1571+
// Only support adding new tags during a promotion run. Tag
1572+
// moves and deletions are not supported.
1573+
//
1574+
// Although disallowing tag moves sounds a bit draconian, it
1575+
// does make protect production from a malformed set of promoter
1576+
// manifests with incorrect tag information.
1577+
Add,
1578+
// TODO: Clean up types to avoid having to split up promoteMe
1579+
// prematurely like this.
1580+
promoteMe.SrcRegistry.Name,
1581+
promoteMe.DstRegistry.Name,
1582+
promoteMe.DstRegistry.ServiceAccount,
1583+
promoteMe.SrcImageTag.Name,
1584+
promoteMe.DstImageTag.Name,
1585+
promoteMe.Digest,
1586+
oldDigest,
1587+
promoteMe.DstImageTag.Tag,
15381588
}
1589+
1590+
requests = append(requests, req)
15391591
}
1592+
1593+
return requests
15401594
}
15411595

15421596
// RunChecks runs defined PreChecks in order to check the promotion.
@@ -1665,7 +1719,7 @@ func getRegistriesToRead(edges map[PromotionEdge]interface{}) []registry.Context
16651719
// Manifest.
16661720
func (sc *SyncContext) Promote(
16671721
edges map[PromotionEdge]interface{},
1668-
customProcessRequest *ProcessRequest,
1722+
customProcessRequest ProcessRequestFunc,
16691723
) error {
16701724
if len(edges) == 0 {
16711725
logrus.Info("Nothing to promote.")
@@ -1693,17 +1747,11 @@ func (sc *SyncContext) Promote(
16931747
}
16941748

16951749
var (
1696-
populateRequests = MKPopulateRequestsForPromotionEdges(edges)
1697-
1698-
processRequest ProcessRequest
1699-
processRequestReal ProcessRequest = func(
1700-
_ *SyncContext,
1701-
reqs chan stream.ExternalRequest,
1702-
requestResults chan<- RequestResult,
1703-
_ *sync.WaitGroup,
1704-
_ *sync.Mutex,
1705-
) {
1706-
for req := range reqs {
1750+
promotionRequests = sc.BuildPopulateRequestsForPromotionEdges(edges)
1751+
1752+
processRequest func(req stream.ExternalRequest) (RequestResult, error)
1753+
processRequestReal ProcessRequestFunc = func(req stream.ExternalRequest) (RequestResult, error) {
1754+
{
17071755
reqRes := RequestResult{Context: req}
17081756
errors := make(Errors, 0)
17091757
// If we're adding or moving (i.e., creating a new image or
@@ -1756,8 +1804,20 @@ func (sc *SyncContext) Promote(
17561804
logrus.Infof("deletions are no longer supported")
17571805
}
17581806

1807+
if len(errors) > 0 {
1808+
logrus.Errorf(
1809+
// TODO(log): Consider logging with fields
1810+
"request %v: error(s) encountered: %v",
1811+
reqRes.Context,
1812+
reqRes.Errors)
1813+
} else {
1814+
// TODO(log): Consider logging with fields
1815+
logrus.Infof("request %v: OK", reqRes.Context.RequestParams)
1816+
}
1817+
// Log the HTTP request to GCR.
1818+
reqcounter.Increment()
17591819
reqRes.Errors = errors
1760-
requestResults <- reqRes
1820+
return reqRes, nil
17611821
}
17621822
}
17631823
)
@@ -1767,16 +1827,34 @@ func (sc *SyncContext) Promote(
17671827
if sc.Confirm {
17681828
processRequest = processRequestReal
17691829
} else {
1770-
processRequestDryRun := MkRequestCapturer(&captured)
1830+
processRequestDryRun := MkRequestCapturerFunc(&captured)
17711831
processRequest = processRequestDryRun
17721832
}
17731833

17741834
if customProcessRequest != nil {
1775-
processRequest = *customProcessRequest
1835+
processRequest = customProcessRequest
17761836
}
17771837

17781838
sc.PrintCapturedRequests(&captured)
1779-
return sc.ExecRequests(populateRequests, processRequest)
1839+
1840+
// Run requests.
1841+
maxConcurrentRequests := 10
1842+
1843+
if sc.Threads > 0 {
1844+
maxConcurrentRequests = sc.Threads
1845+
}
1846+
results := ForkJoin(maxConcurrentRequests, promotionRequests, processRequest)
1847+
1848+
var errs []error
1849+
for _, result := range results {
1850+
if len(result.Output.Errors) > 0 {
1851+
sc.Logs.Errors = append(sc.Logs.Errors, result.Output.Errors...)
1852+
errs = append(errs, errors.New("encountered an error while executing requests"))
1853+
}
1854+
errs = append(errs, result.Error)
1855+
}
1856+
1857+
return errors.Join(errs...)
17801858
}
17811859

17821860
// PrintCapturedRequests pretty-prints all given PromotionRequests.
@@ -1845,6 +1923,7 @@ func (pr *PromotionRequest) PrettyValue() string {
18451923

18461924
// MkRequestCapturer returns a function that simply records requests as they are
18471925
// captured (slurped out from the reqs channel).
1926+
// Deprecated: Prefer MkRequestCapturerFunc
18481927
func MkRequestCapturer(captured *CapturedRequests) ProcessRequest {
18491928
return func(
18501929
_ *SyncContext,
@@ -1870,6 +1949,27 @@ func MkRequestCapturer(captured *CapturedRequests) ProcessRequest {
18701949
}
18711950
}
18721951

1952+
// MkRequestCapturer returns a function that simply records requests as they are
1953+
// captured (slurped out from the reqs channel).
1954+
func MkRequestCapturerFunc(captured *CapturedRequests) ProcessRequestFunc {
1955+
var mutex sync.Mutex
1956+
1957+
return func(req stream.ExternalRequest) (RequestResult, error) {
1958+
{
1959+
pr := req.RequestParams.(PromotionRequest)
1960+
1961+
mutex.Lock()
1962+
(*captured)[pr]++
1963+
mutex.Unlock()
1964+
1965+
// Add a request result to signal the processing of this "request".
1966+
// This is necessary because ExecRequests() is the sole function in
1967+
// the codebase that decrements the WaitGroup semaphore.
1968+
return RequestResult{}, nil
1969+
}
1970+
}
1971+
}
1972+
18731973
func supportedMediaType(v string) (ggcrV1Types.MediaType, error) {
18741974
switch ggcrV1Types.MediaType(v) {
18751975
case ggcrV1Types.DockerManifestList:

0 commit comments

Comments
 (0)