Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 127 additions & 28 deletions internal/legacy/dockerregistry/inventory.go
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,59 @@ func (sc *SyncContext) ExecRequests(
return err
}

type ForkJoinFunc[K any, V any] func(k K) (V, error)

type ForkJoinResult[K any, V any] struct {
Input K
Output V
Error error
}

// ExecRequests uses the Worker Pool pattern, where maxConcurrentRequests
// determines the number of workers to spawn.
func ForkJoin[K any, V any](
maxConcurrentRequests int,
requests []K,
processRequest ForkJoinFunc[K, V],
) []ForkJoinResult[K, V] {

reqChan := make(chan K, maxConcurrentRequests*2)

var wg sync.WaitGroup
var resultsMutex sync.Mutex
var results []ForkJoinResult[K, V]

// Log any errors encountered.
for range maxConcurrentRequests {
wg.Add(1)
go func() {
defer wg.Done()

for req := range reqChan {
result, err := processRequest(req)

resultsMutex.Lock()
results = append(results, ForkJoinResult[K, V]{
Input: req,
Output: result,
Error: err,
})
resultsMutex.Unlock()
}
}()
}

for _, req := range requests {
reqChan <- req
}
close(reqChan)

// Wait for all workers to finish draining the jobs.
wg.Wait()

return results
}

func extractRegistryTags(reader io.Reader) (*ggcrV1Google.Tags, error) {
tags := ggcrV1Google.Tags{}
decoder := json.NewDecoder(reader)
Expand Down Expand Up @@ -1479,18 +1532,17 @@ func (sc *SyncContext) ValidateEdge(edge *PromotionEdge) error {
return nil
}

// MKPopulateRequestsForPromotionEdges takes in a map of PromotionEdges to promote
// and a PromotionContext and returns a PopulateRequests which can generate
// requests to be processed.
func MKPopulateRequestsForPromotionEdges(
// BuildPopulateRequestsForPromotionEdges takes in a map of PromotionEdges to promote
// and a PromotionContext and returns the promotion requests.
func (sc *SyncContext) BuildPopulateRequestsForPromotionEdges(
toPromote map[PromotionEdge]interface{},
) PopulateRequests {
return func(sc *SyncContext, reqs chan<- stream.ExternalRequest, wg *sync.WaitGroup) {
if len(toPromote) == 0 {
logrus.Info("Nothing to promote.")
return
}
) []stream.ExternalRequest {

var requests []stream.ExternalRequest

if len(toPromote) == 0 {
logrus.Info("Nothing to promote.")
} else {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the most idiomatic, but preserves indentation for easier review.

if sc.Confirm {
logrus.Info("---------- BEGIN PROMOTION ----------")
} else {
Expand Down Expand Up @@ -1533,10 +1585,11 @@ func MKPopulateRequestsForPromotionEdges(
promoteMe.DstImageTag.Tag,
}

wg.Add(1)
reqs <- req
requests = append(requests, req)
}
}

return requests
}

// RunChecks runs defined PreChecks in order to check the promotion.
Expand Down Expand Up @@ -1665,7 +1718,7 @@ func getRegistriesToRead(edges map[PromotionEdge]interface{}) []registry.Context
// Manifest.
func (sc *SyncContext) Promote(
edges map[PromotionEdge]interface{},
customProcessRequest *ProcessRequest,
customProcessRequest ProcessRequestFunc,
) error {
if len(edges) == 0 {
logrus.Info("Nothing to promote.")
Expand Down Expand Up @@ -1693,17 +1746,11 @@ func (sc *SyncContext) Promote(
}

var (
populateRequests = MKPopulateRequestsForPromotionEdges(edges)

processRequest ProcessRequest
processRequestReal ProcessRequest = func(
_ *SyncContext,
reqs chan stream.ExternalRequest,
requestResults chan<- RequestResult,
_ *sync.WaitGroup,
_ *sync.Mutex,
) {
for req := range reqs {
promotionRequests = sc.BuildPopulateRequestsForPromotionEdges(edges)

processRequest func(req stream.ExternalRequest) (RequestResult, error)
processRequestReal ProcessRequestFunc = func(req stream.ExternalRequest) (RequestResult, error) {
{
reqRes := RequestResult{Context: req}
errors := make(Errors, 0)
// If we're adding or moving (i.e., creating a new image or
Expand Down Expand Up @@ -1756,8 +1803,20 @@ func (sc *SyncContext) Promote(
logrus.Infof("deletions are no longer supported")
}

if len(errors) > 0 {
logrus.Errorf(
// TODO(log): Consider logging with fields
"request %v: error(s) encountered: %v",
reqRes.Context,
reqRes.Errors)
} else {
// TODO(log): Consider logging with fields
logrus.Infof("request %v: OK", reqRes.Context.RequestParams)
}
// Log the HTTP request to GCR.
reqcounter.Increment()
reqRes.Errors = errors
requestResults <- reqRes
return reqRes, nil
}
}
)
Expand All @@ -1767,16 +1826,34 @@ func (sc *SyncContext) Promote(
if sc.Confirm {
processRequest = processRequestReal
} else {
processRequestDryRun := MkRequestCapturer(&captured)
processRequestDryRun := MkRequestCapturerFunc(&captured)
processRequest = processRequestDryRun
}

if customProcessRequest != nil {
processRequest = *customProcessRequest
processRequest = customProcessRequest
}

sc.PrintCapturedRequests(&captured)
return sc.ExecRequests(populateRequests, processRequest)

// Run concurrent requests.
maxConcurrentRequests := 10
if sc.Threads > 0 {
maxConcurrentRequests = sc.Threads
}

results := ForkJoin(maxConcurrentRequests, promotionRequests, processRequest)

var errs []error
for _, result := range results {
if len(result.Output.Errors) > 0 {
sc.Logs.Errors = append(sc.Logs.Errors, result.Output.Errors...)
errs = append(errs, errors.New("encountered an error while executing requests"))
}
errs = append(errs, result.Error)
}

return errors.Join(errs...)
}

// PrintCapturedRequests pretty-prints all given PromotionRequests.
Expand Down Expand Up @@ -1845,6 +1922,7 @@ func (pr *PromotionRequest) PrettyValue() string {

// MkRequestCapturer returns a function that simply records requests as they are
// captured (slurped out from the reqs channel).
// Deprecated: Prefer MkRequestCapturerFunc
func MkRequestCapturer(captured *CapturedRequests) ProcessRequest {
return func(
_ *SyncContext,
Expand All @@ -1870,6 +1948,27 @@ func MkRequestCapturer(captured *CapturedRequests) ProcessRequest {
}
}

// MkRequestCapturer returns a function that simply records requests as they are
// captured (slurped out from the reqs channel).
func MkRequestCapturerFunc(captured *CapturedRequests) ProcessRequestFunc {
var mutex sync.Mutex

return func(req stream.ExternalRequest) (RequestResult, error) {
{
pr := req.RequestParams.(PromotionRequest)

mutex.Lock()
(*captured)[pr]++
mutex.Unlock()

// Add a request result to signal the processing of this "request".
// This is necessary because ExecRequests() is the sole function in
// the codebase that decrements the WaitGroup semaphore.
return RequestResult{}, nil
}
}
}

func supportedMediaType(v string) (ggcrV1Types.MediaType, error) {
switch ggcrV1Types.MediaType(v) {
case ggcrV1Types.DockerManifestList:
Expand Down
Loading