|
9 | 9 | "github.com/ipfs/go-cid"
|
10 | 10 | cbor "github.com/ipfs/go-ipld-cbor"
|
11 | 11 | cbg "github.com/whyrusleeping/cbor-gen"
|
| 12 | + "golang.org/x/sync/errgroup" |
12 | 13 |
|
13 | 14 | "github.com/filecoin-project/go-amt-ipld/v4/internal"
|
14 | 15 | )
|
@@ -348,6 +349,203 @@ func (n *node) forEachAt(ctx context.Context, bs cbor.IpldStore, bitWidth uint,
|
348 | 349 | return nil
|
349 | 350 | }
|
350 | 351 |
|
| 352 | +type descentContext struct { |
| 353 | + height int |
| 354 | + offset uint64 |
| 355 | +} |
| 356 | + |
| 357 | +type child struct { |
| 358 | + link *link |
| 359 | + descentContext |
| 360 | +} |
| 361 | + |
| 362 | +type listChildren struct { |
| 363 | + children []child |
| 364 | +} |
| 365 | + |
| 366 | +func (n *node) forEachAtParallel(ctx context.Context, bs cbor.IpldStore, bitWidth uint, height int, start, offset uint64, cb func(uint64, *cbg.Deferred) error, concurrency int) error { |
| 367 | + // Setup synchronization |
| 368 | + grp, errGrpCtx := errgroup.WithContext(ctx) |
| 369 | + // Input and output queues for workers. |
| 370 | + feed := make(chan *listChildren) |
| 371 | + out := make(chan *listChildren) |
| 372 | + done := make(chan struct{}) |
| 373 | + |
| 374 | + for i := 0; i < concurrency; i++ { |
| 375 | + grp.Go(func() error { |
| 376 | + for childrenList := range feed { |
| 377 | + linksToVisit := make([]cid.Cid, 0, len(childrenList.children)) |
| 378 | + linksToVisitContext := make([]descentContext, 0, len(childrenList.children)) |
| 379 | + cachedNodes := make([]*node, 0, len(childrenList.children)) |
| 380 | + cachedNodesContext := make([]descentContext, 0, len(childrenList.children)) |
| 381 | + for _, child := range childrenList.children { |
| 382 | + if child.link.cached != nil { |
| 383 | + cachedNodes = append(cachedNodes, child.link.cached) |
| 384 | + cachedNodesContext = append(cachedNodesContext, child.descentContext) |
| 385 | + } else if child.link.cid != cid.Undef { |
| 386 | + linksToVisit = append(linksToVisit, child.link.cid) |
| 387 | + linksToVisitContext = append(linksToVisitContext, child.descentContext) |
| 388 | + } else { |
| 389 | + return fmt.Errorf("invalid child") |
| 390 | + } |
| 391 | + } |
| 392 | + |
| 393 | + dserv := bs.(*cbor.BatchOpIpldStore) |
| 394 | + nodes := make([]interface{}, len(linksToVisit)) |
| 395 | + for j := 0; j < len(linksToVisit); j++ { |
| 396 | + nodes[j] = new(internal.Node) |
| 397 | + } |
| 398 | + cursorChan, missingCIDs, err := dserv.GetMany(errGrpCtx, linksToVisit, nodes) |
| 399 | + if err != nil { |
| 400 | + return err |
| 401 | + } |
| 402 | + if len(missingCIDs) != 0 { |
| 403 | + return fmt.Errorf("GetMany returned an incomplete result set. The set is missing these CIDs: %+v", missingCIDs) |
| 404 | + } |
| 405 | + for cursor := range cursorChan { |
| 406 | + if cursor.Err != nil { |
| 407 | + return cursor.Err |
| 408 | + } |
| 409 | + internalNextNode, ok := nodes[cursor.Index].(*internal.Node) |
| 410 | + if !ok { |
| 411 | + return fmt.Errorf("expected node, got %T", nodes[cursor.Index]) |
| 412 | + } |
| 413 | + nextNode, err := newNode(*internalNextNode, bitWidth, false, linksToVisitContext[cursor.Index].height == 0) |
| 414 | + if err != nil { |
| 415 | + return err |
| 416 | + } |
| 417 | + nextChildren, err := nextNode.walkChildren(ctx, bitWidth, linksToVisitContext[cursor.Index].height, start, linksToVisitContext[cursor.Index].offset, cb) |
| 418 | + if err != nil { |
| 419 | + return err |
| 420 | + } |
| 421 | + select { |
| 422 | + case <-errGrpCtx.Done(): |
| 423 | + return nil |
| 424 | + default: |
| 425 | + if nextChildren != nil { |
| 426 | + out <- nextChildren |
| 427 | + } |
| 428 | + } |
| 429 | + } |
| 430 | + for j, cachedNode := range cachedNodes { |
| 431 | + nextChildren, err := cachedNode.walkChildren(ctx, bitWidth, cachedNodesContext[j].height, start, cachedNodesContext[j].offset, cb) |
| 432 | + if err != nil { |
| 433 | + return err |
| 434 | + } |
| 435 | + select { |
| 436 | + case <-errGrpCtx.Done(): |
| 437 | + return nil |
| 438 | + default: |
| 439 | + if nextChildren != nil { |
| 440 | + out <- nextChildren |
| 441 | + } |
| 442 | + } |
| 443 | + } |
| 444 | + |
| 445 | + select { |
| 446 | + case done <- struct{}{}: |
| 447 | + case <-errGrpCtx.Done(): |
| 448 | + } |
| 449 | + } |
| 450 | + return nil |
| 451 | + }) |
| 452 | + } |
| 453 | + |
| 454 | + send := feed |
| 455 | + var todoQueue []*listChildren |
| 456 | + var inProgress int |
| 457 | + |
| 458 | + // start the walk |
| 459 | + children, err := n.walkChildren(ctx, bitWidth, height, start, offset, cb) |
| 460 | + // if we hit an error or there are no children, then we're done |
| 461 | + if err != nil || children == nil { |
| 462 | + close(feed) |
| 463 | + grp.Wait() |
| 464 | + return err |
| 465 | + } |
| 466 | + next := children |
| 467 | + |
| 468 | +dispatcherLoop: |
| 469 | + for { |
| 470 | + select { |
| 471 | + case send <- next: |
| 472 | + inProgress++ |
| 473 | + if len(todoQueue) > 0 { |
| 474 | + next = todoQueue[0] |
| 475 | + todoQueue = todoQueue[1:] |
| 476 | + } else { |
| 477 | + next = nil |
| 478 | + send = nil |
| 479 | + } |
| 480 | + case <-done: |
| 481 | + inProgress-- |
| 482 | + if inProgress == 0 && next == nil { |
| 483 | + break dispatcherLoop |
| 484 | + } |
| 485 | + case nextNodes := <-out: |
| 486 | + if next == nil { |
| 487 | + next = nextNodes |
| 488 | + send = feed |
| 489 | + } else { |
| 490 | + todoQueue = append(todoQueue, nextNodes) |
| 491 | + } |
| 492 | + case <-errGrpCtx.Done(): |
| 493 | + break dispatcherLoop |
| 494 | + } |
| 495 | + } |
| 496 | + close(feed) |
| 497 | + return grp.Wait() |
| 498 | +} |
| 499 | + |
| 500 | +func (n *node) walkChildren(ctx context.Context, bitWidth uint, height int, start, offset uint64, cb func(uint64, *cbg.Deferred) error) (*listChildren, error) { |
| 501 | + if height == 0 { |
| 502 | + // height=0 means we're at leaf nodes and get to use our callback |
| 503 | + for i, v := range n.values { |
| 504 | + if v != nil { |
| 505 | + ix := offset + uint64(i) |
| 506 | + if ix < start { |
| 507 | + // if we're here, 'start' is probably somewhere in the |
| 508 | + // middle of this node's elements |
| 509 | + continue |
| 510 | + } |
| 511 | + |
| 512 | + // use 'offset' to determine the actual index for this element, it |
| 513 | + // tells us how distant we are from the left-most leaf node |
| 514 | + if err := cb(offset+uint64(i), v); err != nil { |
| 515 | + return nil, err |
| 516 | + } |
| 517 | + } |
| 518 | + } |
| 519 | + |
| 520 | + return nil, nil |
| 521 | + } |
| 522 | + children := make([]child, 0, len(n.links)) |
| 523 | + |
| 524 | + subCount := nodesForHeight(bitWidth, height) |
| 525 | + for i, ln := range n.links { |
| 526 | + if ln == nil { |
| 527 | + continue |
| 528 | + } |
| 529 | + |
| 530 | + // 'offs' tells us the index of the left-most element of the subtree defined |
| 531 | + // by 'sub' |
| 532 | + offs := offset + (uint64(i) * subCount) |
| 533 | + nextOffs := offs + subCount |
| 534 | + // nextOffs > offs checks for overflow at MaxIndex (where the next offset wraps back |
| 535 | + // to 0). |
| 536 | + if nextOffs >= offs && start >= nextOffs { |
| 537 | + // if we're here, 'start' lets us skip this entire sub-tree |
| 538 | + continue |
| 539 | + } |
| 540 | + children = append(children, child{ln, descentContext{ |
| 541 | + height: height - 1, |
| 542 | + offset: offs, |
| 543 | + }}) |
| 544 | + } |
| 545 | + |
| 546 | + return &listChildren{children: children}, nil |
| 547 | +} |
| 548 | + |
351 | 549 | var errNoVals = fmt.Errorf("no values")
|
352 | 550 |
|
353 | 551 | // Recursive implementation of FirstSetIndex that's performed on the left-most
|
@@ -494,6 +692,76 @@ func (n *node) flush(ctx context.Context, bs cbor.IpldStore, bitWidth uint, heig
|
494 | 692 | return nd, nil
|
495 | 693 | }
|
496 | 694 |
|
| 695 | +// compact converts a node into its internal.Node representation |
| 696 | +func (n *node) compact(ctx context.Context, bitWidth uint, height int) (*internal.Node, error) { |
| 697 | + nd := new(internal.Node) |
| 698 | + nd.Bmap = make([]byte, bmapBytes(bitWidth)) |
| 699 | + |
| 700 | + if height == 0 { |
| 701 | + // leaf node, we're storing values in this node |
| 702 | + for i, val := range n.values { |
| 703 | + if val == nil { |
| 704 | + continue |
| 705 | + } |
| 706 | + nd.Values = append(nd.Values, val) |
| 707 | + // set the bit in the bitmap for this position to indicate its presence |
| 708 | + nd.Bmap[i/8] |= 1 << (uint(i) % 8) |
| 709 | + } |
| 710 | + return nd, nil |
| 711 | + } |
| 712 | + |
| 713 | + // non-leaf node, we're only storing Links in this node |
| 714 | + for i, ln := range n.links { |
| 715 | + if ln == nil { |
| 716 | + continue |
| 717 | + } |
| 718 | + if ln.dirty { |
| 719 | + if ln.cached == nil { |
| 720 | + return nil, fmt.Errorf("expected dirty node to be cached") |
| 721 | + } |
| 722 | + subn, err := ln.cached.compact(ctx, bitWidth, height-1) |
| 723 | + if err != nil { |
| 724 | + return nil, err |
| 725 | + } |
| 726 | + c, err := calcCID(subn) |
| 727 | + if err != nil { |
| 728 | + return nil, err |
| 729 | + } |
| 730 | + |
| 731 | + ln.cid = c |
| 732 | + ln.dirty = false |
| 733 | + } |
| 734 | + nd.Links = append(nd.Links, ln.cid) |
| 735 | + // set the bit in the bitmap for this position to indicate its presence |
| 736 | + nd.Bmap[i/8] |= 1 << (uint(i) % 8) |
| 737 | + } |
| 738 | + |
| 739 | + return nd, nil |
| 740 | +} |
| 741 | + |
| 742 | +func calcCID(node cbg.CBORMarshaler) (cid.Cid, error) { |
| 743 | + mhType := cbor.DefaultMultihash |
| 744 | + mhLen := -1 |
| 745 | + codec := uint64(cid.DagCBOR) |
| 746 | + |
| 747 | + buf := new(bytes.Buffer) |
| 748 | + if err := node.MarshalCBOR(buf); err != nil { |
| 749 | + return cid.Undef, err |
| 750 | + } |
| 751 | + |
| 752 | + pref := cid.Prefix{ |
| 753 | + Codec: codec, |
| 754 | + MhType: mhType, |
| 755 | + MhLength: mhLen, |
| 756 | + Version: 1, |
| 757 | + } |
| 758 | + c, err := pref.Sum(buf.Bytes()) |
| 759 | + if err != nil { |
| 760 | + return cid.Undef, err |
| 761 | + } |
| 762 | + return c, nil |
| 763 | +} |
| 764 | + |
497 | 765 | func (n *node) setLink(bitWidth uint, i uint64, l *link) {
|
498 | 766 | if n.links == nil {
|
499 | 767 | if l == nil {
|
|
0 commit comments