Skip to content

Commit 15c1451

Browse files
committed
parallel ForEach methods
1 parent 4bdc7bc commit 15c1451

File tree

2 files changed

+276
-0
lines changed

2 files changed

+276
-0
lines changed

amt.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,14 @@ func (r *Root) ForEachAt(ctx context.Context, start uint64, cb func(uint64, *cbg
321321
return r.node.forEachAt(ctx, r.store, r.bitWidth, r.height, start, 0, cb)
322322
}
323323

324+
func (r *Root) ForEachParallel(ctx context.Context, concurrency int, cb func(uint64, *cbg.Deferred) error) error {
325+
return r.node.forEachAtParallel(ctx, r.store, r.bitWidth, r.height, 0, 0, cb, concurrency)
326+
}
327+
328+
func (r *Root) ForEachAtParallel(ctx context.Context, concurrency int, start uint64, cb func(uint64, *cbg.Deferred) error) error {
329+
return r.node.forEachAtParallel(ctx, r.store, r.bitWidth, r.height, start, 0, cb, concurrency)
330+
}
331+
324332
// FirstSetIndex finds the lowest index in this AMT that has a value set for
325333
// it. If this operation is called on an empty AMT, an ErrNoValues will be
326334
// returned.

node.go

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/ipfs/go-cid"
1010
cbor "github.com/ipfs/go-ipld-cbor"
1111
cbg "github.com/whyrusleeping/cbor-gen"
12+
"golang.org/x/sync/errgroup"
1213

1314
"github.com/filecoin-project/go-amt-ipld/v4/internal"
1415
)
@@ -348,6 +349,203 @@ func (n *node) forEachAt(ctx context.Context, bs cbor.IpldStore, bitWidth uint,
348349
return nil
349350
}
350351

352+
type descentContext struct {
353+
height int
354+
offset uint64
355+
}
356+
357+
type child struct {
358+
link *link
359+
descentContext
360+
}
361+
362+
type listChildren struct {
363+
children []child
364+
}
365+
366+
func (n *node) forEachAtParallel(ctx context.Context, bs cbor.IpldStore, bitWidth uint, height int, start, offset uint64, cb func(uint64, *cbg.Deferred) error, concurrency int) error {
367+
// Setup synchronization
368+
grp, errGrpCtx := errgroup.WithContext(ctx)
369+
// Input and output queues for workers.
370+
feed := make(chan *listChildren)
371+
out := make(chan *listChildren)
372+
done := make(chan struct{})
373+
374+
for i := 0; i < concurrency; i++ {
375+
grp.Go(func() error {
376+
for childrenList := range feed {
377+
linksToVisit := make([]cid.Cid, 0, len(childrenList.children))
378+
linksToVisitContext := make([]descentContext, 0, len(childrenList.children))
379+
cachedNodes := make([]*node, 0, len(childrenList.children))
380+
cachedNodesContext := make([]descentContext, 0, len(childrenList.children))
381+
for _, child := range childrenList.children {
382+
if child.link.cached != nil {
383+
cachedNodes = append(cachedNodes, child.link.cached)
384+
cachedNodesContext = append(cachedNodesContext, child.descentContext)
385+
} else if child.link.cid != cid.Undef {
386+
linksToVisit = append(linksToVisit, child.link.cid)
387+
linksToVisitContext = append(linksToVisitContext, child.descentContext)
388+
} else {
389+
return fmt.Errorf("invalid child")
390+
}
391+
}
392+
393+
dserv := bs.(*cbor.BatchOpIpldStore)
394+
nodes := make([]interface{}, len(linksToVisit))
395+
for j := 0; j < len(linksToVisit); j++ {
396+
nodes[j] = new(internal.Node)
397+
}
398+
cursorChan, missingCIDs, err := dserv.GetMany(errGrpCtx, linksToVisit, nodes)
399+
if err != nil {
400+
return err
401+
}
402+
if len(missingCIDs) != 0 {
403+
return fmt.Errorf("GetMany returned an incomplete result set. The set is missing these CIDs: %+v", missingCIDs)
404+
}
405+
for cursor := range cursorChan {
406+
if cursor.Err != nil {
407+
return cursor.Err
408+
}
409+
internalNextNode, ok := nodes[cursor.Index].(*internal.Node)
410+
if !ok {
411+
return fmt.Errorf("expected node, got %T", nodes[cursor.Index])
412+
}
413+
nextNode, err := newNode(*internalNextNode, bitWidth, false, linksToVisitContext[cursor.Index].height == 0)
414+
if err != nil {
415+
return err
416+
}
417+
nextChildren, err := nextNode.walkChildren(ctx, bitWidth, linksToVisitContext[cursor.Index].height, start, linksToVisitContext[cursor.Index].offset, cb)
418+
if err != nil {
419+
return err
420+
}
421+
select {
422+
case <-errGrpCtx.Done():
423+
return nil
424+
default:
425+
if nextChildren != nil {
426+
out <- nextChildren
427+
}
428+
}
429+
}
430+
for j, cachedNode := range cachedNodes {
431+
nextChildren, err := cachedNode.walkChildren(ctx, bitWidth, cachedNodesContext[j].height, start, cachedNodesContext[j].offset, cb)
432+
if err != nil {
433+
return err
434+
}
435+
select {
436+
case <-errGrpCtx.Done():
437+
return nil
438+
default:
439+
if nextChildren != nil {
440+
out <- nextChildren
441+
}
442+
}
443+
}
444+
445+
select {
446+
case done <- struct{}{}:
447+
case <-errGrpCtx.Done():
448+
}
449+
}
450+
return nil
451+
})
452+
}
453+
454+
send := feed
455+
var todoQueue []*listChildren
456+
var inProgress int
457+
458+
// start the walk
459+
children, err := n.walkChildren(ctx, bitWidth, height, start, offset, cb)
460+
// if we hit an error or there are no children, then we're done
461+
if err != nil || children == nil {
462+
close(feed)
463+
grp.Wait()
464+
return err
465+
}
466+
next := children
467+
468+
dispatcherLoop:
469+
for {
470+
select {
471+
case send <- next:
472+
inProgress++
473+
if len(todoQueue) > 0 {
474+
next = todoQueue[0]
475+
todoQueue = todoQueue[1:]
476+
} else {
477+
next = nil
478+
send = nil
479+
}
480+
case <-done:
481+
inProgress--
482+
if inProgress == 0 && next == nil {
483+
break dispatcherLoop
484+
}
485+
case nextNodes := <-out:
486+
if next == nil {
487+
next = nextNodes
488+
send = feed
489+
} else {
490+
todoQueue = append(todoQueue, nextNodes)
491+
}
492+
case <-errGrpCtx.Done():
493+
break dispatcherLoop
494+
}
495+
}
496+
close(feed)
497+
return grp.Wait()
498+
}
499+
500+
func (n *node) walkChildren(ctx context.Context, bitWidth uint, height int, start, offset uint64, cb func(uint64, *cbg.Deferred) error) (*listChildren, error) {
501+
if height == 0 {
502+
// height=0 means we're at leaf nodes and get to use our callback
503+
for i, v := range n.values {
504+
if v != nil {
505+
ix := offset + uint64(i)
506+
if ix < start {
507+
// if we're here, 'start' is probably somewhere in the
508+
// middle of this node's elements
509+
continue
510+
}
511+
512+
// use 'offset' to determine the actual index for this element, it
513+
// tells us how distant we are from the left-most leaf node
514+
if err := cb(offset+uint64(i), v); err != nil {
515+
return nil, err
516+
}
517+
}
518+
}
519+
520+
return nil, nil
521+
}
522+
children := make([]child, 0, len(n.links))
523+
524+
subCount := nodesForHeight(bitWidth, height)
525+
for i, ln := range n.links {
526+
if ln == nil {
527+
continue
528+
}
529+
530+
// 'offs' tells us the index of the left-most element of the subtree defined
531+
// by 'sub'
532+
offs := offset + (uint64(i) * subCount)
533+
nextOffs := offs + subCount
534+
// nextOffs > offs checks for overflow at MaxIndex (where the next offset wraps back
535+
// to 0).
536+
if nextOffs >= offs && start >= nextOffs {
537+
// if we're here, 'start' lets us skip this entire sub-tree
538+
continue
539+
}
540+
children = append(children, child{ln, descentContext{
541+
height: height - 1,
542+
offset: offs,
543+
}})
544+
}
545+
546+
return &listChildren{children: children}, nil
547+
}
548+
351549
var errNoVals = fmt.Errorf("no values")
352550

353551
// Recursive implementation of FirstSetIndex that's performed on the left-most
@@ -494,6 +692,76 @@ func (n *node) flush(ctx context.Context, bs cbor.IpldStore, bitWidth uint, heig
494692
return nd, nil
495693
}
496694

695+
// compact converts a node into its internal.Node representation
696+
func (n *node) compact(ctx context.Context, bitWidth uint, height int) (*internal.Node, error) {
697+
nd := new(internal.Node)
698+
nd.Bmap = make([]byte, bmapBytes(bitWidth))
699+
700+
if height == 0 {
701+
// leaf node, we're storing values in this node
702+
for i, val := range n.values {
703+
if val == nil {
704+
continue
705+
}
706+
nd.Values = append(nd.Values, val)
707+
// set the bit in the bitmap for this position to indicate its presence
708+
nd.Bmap[i/8] |= 1 << (uint(i) % 8)
709+
}
710+
return nd, nil
711+
}
712+
713+
// non-leaf node, we're only storing Links in this node
714+
for i, ln := range n.links {
715+
if ln == nil {
716+
continue
717+
}
718+
if ln.dirty {
719+
if ln.cached == nil {
720+
return nil, fmt.Errorf("expected dirty node to be cached")
721+
}
722+
subn, err := ln.cached.compact(ctx, bitWidth, height-1)
723+
if err != nil {
724+
return nil, err
725+
}
726+
c, err := calcCID(subn)
727+
if err != nil {
728+
return nil, err
729+
}
730+
731+
ln.cid = c
732+
ln.dirty = false
733+
}
734+
nd.Links = append(nd.Links, ln.cid)
735+
// set the bit in the bitmap for this position to indicate its presence
736+
nd.Bmap[i/8] |= 1 << (uint(i) % 8)
737+
}
738+
739+
return nd, nil
740+
}
741+
742+
func calcCID(node cbg.CBORMarshaler) (cid.Cid, error) {
743+
mhType := cbor.DefaultMultihash
744+
mhLen := -1
745+
codec := uint64(cid.DagCBOR)
746+
747+
buf := new(bytes.Buffer)
748+
if err := node.MarshalCBOR(buf); err != nil {
749+
return cid.Undef, err
750+
}
751+
752+
pref := cid.Prefix{
753+
Codec: codec,
754+
MhType: mhType,
755+
MhLength: mhLen,
756+
Version: 1,
757+
}
758+
c, err := pref.Sum(buf.Bytes())
759+
if err != nil {
760+
return cid.Undef, err
761+
}
762+
return c, nil
763+
}
764+
497765
func (n *node) setLink(bitWidth uint, i uint64, l *link) {
498766
if n.links == nil {
499767
if l == nil {

0 commit comments

Comments
 (0)