From ba0cc479371b49b6559f2a5284bec0fc3aca21cd Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Tue, 13 May 2025 12:00:35 -0400 Subject: [PATCH] feat: rework fullrt to use a caching routing table --- cachert/rt.go | 145 +++++++++ cachert/rt_test.go | 91 ++++++ fullrt/dht.go | 505 +++++++++++++++++++++++--------- go.mod | 3 + go.sum | 6 + internal/net/message_manager.go | 4 + lookup.go | 8 +- lookup_optim.go | 6 +- query.go | 164 ++++++++--- routing.go | 14 +- 10 files changed, 748 insertions(+), 198 deletions(-) create mode 100644 cachert/rt.go create mode 100644 cachert/rt_test.go diff --git a/cachert/rt.go b/cachert/rt.go new file mode 100644 index 000000000..138717be8 --- /dev/null +++ b/cachert/rt.go @@ -0,0 +1,145 @@ +package cachert + +import ( + "time" + + "github.com/emirpasic/gods/v2/trees/avltree" +) + +type Key = string + +type KeySearchRange struct { + KeyLower Key + KeyUpper Key + Time time.Time +} + +type RT struct { + at *avltree.Tree[Key, *KeySearchRange] + ranges map[*KeySearchRange]struct{} +} + +func NewRT() *RT { + return &RT{ + at: avltree.New[Key, *KeySearchRange](), + ranges: make(map[*KeySearchRange]struct{}), + } +} + +func (t *RT) GetRanges() []KeySearchRange { + ranges := make([]KeySearchRange, 0, len(t.ranges)) + for k := range t.ranges { + if k.KeyLower > k.KeyUpper { + panic("lower key must be less than upper key") + } + ranges = append(ranges, KeySearchRange{ + KeyLower: k.KeyLower, + KeyUpper: k.KeyUpper, + Time: k.Time, + }) + } + return ranges +} + +func (t *RT) InsertRange(lower, upper Key, expiration time.Time) { + if lower > upper { + panic("lower key must be less than upper key") + } + if len([]byte(lower)) != len([]byte(upper)) && len([]byte(lower)) != 32 { + panic("lower and upper keys must be 32 bytes") + } + + r := &KeySearchRange{lower, upper, expiration} + + // Find the range that starts the latest, but before the start of this range + f, ok := t.at.Floor(r.KeyLower) + // If there are no nodes that start before this one + type rr struct { + k Key + r *KeySearchRange + } + + var rangesToReplace []rr + + if !ok { + f = t.at.Left() + } else { + if f.Value.KeyUpper > r.KeyLower { + // Truncate the previous range to stop where this one starts + f.Value.KeyUpper = r.KeyLower + rangesToReplace = append(rangesToReplace, rr{k: f.Key, r: f.Value}) + } + } + // Insert this one + t.at.Put(r.KeyLower, r) + t.ranges[r] = struct{}{} + + // Continue through subsequent ranges seeing if any get clobbered by this one + for f := f.Next(); f != nil; f = f.Next() { + if f.Value == r { + // This is the same range, so we can skip it + continue + } + if f.Value.KeyLower < r.KeyUpper { + if f.Value.KeyUpper <= r.KeyUpper { + rangesToReplace = append(rangesToReplace, rr{k: f.Key, r: nil}) + delete(t.ranges, f.Value) + } else { + f.Value.KeyUpper = r.KeyUpper + rangesToReplace = append(rangesToReplace, rr{k: f.Key, r: f.Value}) + } + } else { + break + } + } + + for _, rng := range rangesToReplace { + t.at.Remove(rng.k) + if rng.r == nil { + continue + } + if rng.r.KeyLower == rng.r.KeyUpper { + // If the range is empty, we don't want to keep it + delete(t.ranges, rng.r) + } else { + t.at.Put(rng.r.KeyLower, rng.r) + } + } +} + +func (t *RT) RangeIsCovered(lower, upper Key) bool { + if lower > upper { + panic("lower key must be less than upper key") + } + if len([]byte(lower)) != len([]byte(upper)) && len([]byte(lower)) != 32 { + panic("lower and upper keys must be 32 bytes") + } + + lowest := lower + f, ok := t.at.Floor(lowest) + if !ok { + return false + } + + for ; f 
!= nil; f = f.Next() { + if f.Value.KeyLower > lowest { + return false + } + + if f.Value.KeyUpper >= upper { + return true + } + + lowest = f.Value.KeyUpper + } + return false +} + +func (t *RT) CollectGarbage(ti time.Time) { + for r, _ := range t.ranges { + if r.Time.Before(ti) { + t.at.Remove(r.KeyLower) + delete(t.ranges, r) + } + } +} diff --git a/cachert/rt_test.go b/cachert/rt_test.go new file mode 100644 index 000000000..25c791823 --- /dev/null +++ b/cachert/rt_test.go @@ -0,0 +1,91 @@ +package cachert + +import ( + "bytes" + "testing" + "time" +) + +func TestRT_InsertRangeAndGetRanges(t *testing.T) { + rt := NewRT() + + // Valid 32-byte keys + key1 := bytes.Repeat([]byte("a"), 32) + key2 := bytes.Repeat([]byte("b"), 32) + key3 := bytes.Repeat([]byte("c"), 32) + key4 := bytes.Repeat([]byte("d"), 32) + + // Insert non-overlapping range + rt.InsertRange(string(key1), string(key2), time.Now().Add(1*time.Hour)) + ranges := rt.GetRanges() + if len(ranges) != 1 { + t.Fatalf("expected 1 range, got %d", len(ranges)) + } + + // Insert overlapping range + rt.InsertRange(string(key2), string(key3), time.Now().Add(2*time.Hour)) + ranges = rt.GetRanges() + if len(ranges) != 2 { + t.Fatalf("expected 2 ranges, got %d", len(ranges)) + } + + // Insert range that merges with existing ranges + rt.InsertRange(string(key1), string(key4), time.Now().Add(3*time.Hour)) + ranges = rt.GetRanges() + if len(ranges) != 1 { + t.Fatalf("expected 1 merged range, got %d", len(ranges)) + } + if ranges[0].KeyLower != string(key1) || ranges[0].KeyUpper != string(key4) { + t.Fatalf("merged range has incorrect bounds: %+v", ranges[0]) + } +} + +func TestRT_RangeIsCovered(t *testing.T) { + rt := NewRT() + + // Valid 32-byte keys + key1 := bytes.Repeat([]byte("a"), 32) + key2 := bytes.Repeat([]byte("b"), 32) + key3 := bytes.Repeat([]byte("c"), 32) + key4 := bytes.Repeat([]byte("d"), 32) + + // Insert ranges + rt.InsertRange(string(key1), string(key2), time.Now().Add(1*time.Hour)) + rt.InsertRange(string(key2), string(key3), time.Now().Add(2*time.Hour)) + + // Check covered range + if !rt.RangeIsCovered(string(key1), string(key3)) { + t.Fatalf("expected range [%s, %s] to be covered", key1, key3) + } + + // Check partially covered range + if rt.RangeIsCovered(string(key1), string(key4)) { + t.Fatalf("expected range [%s, %s] to not be fully covered", key1, key4) + } + + // Check uncovered range + if rt.RangeIsCovered(string(key3), string(key4)) { + t.Fatalf("expected range [%s, %s] to not be covered", key3, key4) + } +} + +func TestRT_CollectGarbage(t *testing.T) { + rt := NewRT() + + // Valid 32-byte keys + key1 := bytes.Repeat([]byte("a"), 32) + key2 := bytes.Repeat([]byte("b"), 32) + + // Insert range with expiration + expiredTime := time.Now().Add(-1 * time.Hour) + rt.InsertRange(string(key1), string(key2), expiredTime) + + // Collect garbage + rt.CollectGarbage(time.Now()) + + // Validate range is removed + ranges := rt.GetRanges() + if len(ranges) != 0 { + t.Fatalf("expected 0 ranges after garbage collection, got %d", len(ranges)) + } +} diff --git a/fullrt/dht.go b/fullrt/dht.go index a630d995e..149899c57 100644 --- a/fullrt/dht.go +++ b/fullrt/dht.go @@ -3,9 +3,12 @@ package fullrt import ( "bytes" "context" - "errors" "fmt" + "github.com/libp2p/go-libp2p-kad-dht/cachert" + "github.com/libp2p/go-libp2p-kad-dht/qpeerset" + "math/big" "math/rand" + "slices" "sync" "sync/atomic" "time" @@ -21,7 +24,6 @@ import ( "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/protocol" 
"github.com/libp2p/go-libp2p/core/routing" - swarm "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/gogo/protobuf/proto" u "github.com/ipfs/boxo/util" @@ -85,8 +87,11 @@ type FullRT struct { peerAddrsLk sync.RWMutex peerAddrs map[peer.ID][]multiaddr.Multiaddr + peerAddrsLastModifier map[peer.ID]time.Time bootstrapPeers []*peer.AddrInfo + crt *cachert.RT + crtLk sync.RWMutex bucketSize int @@ -196,12 +201,60 @@ func NewFullRT(h host.Host, protocolPrefix protocol.ID, options ...Option) (*Ful crawlerInterval: fullrtcfg.crawlInterval, bulkSendParallelism: fullrtcfg.bulkSendParallelism, + crt: cachert.NewRT(), self: self, } rt.wg.Add(1) - go rt.runCrawler(ctx) + rt.addBootstrapPeers() + + go func() { + t := time.NewTicker(fullrtcfg.crawlInterval) + for { + select { + case <-ctx.Done(): + return + case <-t.C: + rt.crtLk.Lock() + rt.crt.CollectGarbage(time.Now().Add(-time.Hour)) + rt.crtLk.Unlock() + + var peersToRemove []peer.ID + var keysToRemove []string + + rt.peerAddrsLk.RLock() + for k := range rt.peerAddrs { + if time.Since(rt.peerAddrsLastModifier[k]) > time.Hour { + peersToRemove = append(peersToRemove, k) + } + } + rt.peerAddrsLk.RUnlock() + + for _, k := range peersToRemove { + if _, ok := rt.keyToPeerMap[string(kb.ConvertPeerID(k))]; ok { + keysToRemove = append(keysToRemove, string(kb.ConvertPeerID(k))) + } + } + + rt.peerAddrsLk.Lock() + rt.kMapLk.Lock() + rt.rtLk.Lock() + for _, k := range keysToRemove { + delete(rt.keyToPeerMap, k) + rt.rt.Remove(kadkey.KbucketIDToKey(kb.ID(k))) + } + + for _, k := range peersToRemove { + delete(rt.peerAddrs, k) + delete(rt.peerAddrsLastModifier, k) + } + rt.rtLk.Unlock() + rt.kMapLk.Unlock() + rt.peerAddrsLk.Unlock() + } + } + }() return rt, nil } @@ -211,6 +264,12 @@ type crawlVal struct { key kadkey.Key } +func (dht *FullRT) GetRanges() []cachert.KeySearchRange { + dht.crtLk.RLock() + defer dht.crtLk.RUnlock() + return dht.crt.GetRanges() +} + func (dht *FullRT) TriggerRefresh(ctx context.Context) error { select { case <-ctx.Done(): @@ -234,125 +293,17 @@ func (dht *FullRT) Stat() map[string]peer.ID { } func (dht *FullRT) Ready() bool { - dht.rtLk.RLock() - lastCrawlTime := dht.lastCrawlTime - dht.rtLk.RUnlock() - - if time.Since(lastCrawlTime) > dht.crawlerInterval { - return false - } - - // TODO: This function needs to be better defined. Perhaps based on going through the peer map and seeing when the - // last time we were connected to any of them was. - dht.peerAddrsLk.RLock() - rtSize := len(dht.keyToPeerMap) - dht.peerAddrsLk.RUnlock() - - return rtSize > len(dht.bootstrapPeers)+1 + return true } func (dht *FullRT) Host() host.Host { return dht.h } -func (dht *FullRT) runCrawler(ctx context.Context) { - defer dht.wg.Done() - t := time.NewTicker(dht.crawlerInterval) - - m := make(map[peer.ID]*crawlVal) - mxLk := sync.Mutex{} - - initialTrigger := make(chan struct{}, 1) - initialTrigger <- struct{}{} - - for { - select { - case <-t.C: - case <-initialTrigger: - case <-dht.triggerRefresh: - case <-ctx.Done(): - return - } - - var addrs []*peer.AddrInfo - dht.peerAddrsLk.Lock() - for k := range m { - addrs = append(addrs, &peer.AddrInfo{ID: k}) // Addrs: v.addrs - } - - addrs = append(addrs, dht.bootstrapPeers...) 
- dht.peerAddrsLk.Unlock() - - for k := range m { - delete(m, k) - } - - start := time.Now() - limitErrOnce := sync.Once{} - dht.crawler.Run(ctx, addrs, - func(p peer.ID, rtPeers []*peer.AddrInfo) { - conns := dht.h.Network().ConnsToPeer(p) - var addrs []multiaddr.Multiaddr - for _, conn := range conns { - addr := conn.RemoteMultiaddr() - addrs = append(addrs, addr) - } - - if len(addrs) == 0 { - logger.Debugf("no connections to %v after successful query. keeping addresses from the peerstore", p) - addrs = dht.h.Peerstore().Addrs(p) - } - - keep := kaddht.PublicRoutingTableFilter(dht, p) - if !keep { - return - } - - mxLk.Lock() - defer mxLk.Unlock() - m[p] = &crawlVal{ - addrs: addrs, - } - }, - func(p peer.ID, err error) { - dialErr, ok := err.(*swarm.DialError) - if ok { - for _, transportErr := range dialErr.DialErrors { - if errors.Is(transportErr.Cause, network.ErrResourceLimitExceeded) { - limitErrOnce.Do(func() { logger.Errorf(rtRefreshLimitsMsg) }) - } - } - } - // note that DialError implements Unwrap() which returns the Cause, so this covers that case - if errors.Is(err, network.ErrResourceLimitExceeded) { - limitErrOnce.Do(func() { logger.Errorf(rtRefreshLimitsMsg) }) - } - }) - dur := time.Since(start) - logger.Infof("crawl took %v", dur) - - peerAddrs := make(map[peer.ID][]multiaddr.Multiaddr) - kPeerMap := make(map[string]peer.ID) - newRt := trie.New() - for k, v := range m { - v.key = kadkey.KbucketIDToKey(kb.ConvertPeerID(k)) - peerAddrs[k] = v.addrs - kPeerMap[string(v.key)] = k - newRt.Add(v.key) - } - - dht.peerAddrsLk.Lock() - dht.peerAddrs = peerAddrs - dht.peerAddrsLk.Unlock() - - dht.kMapLk.Lock() - dht.keyToPeerMap = kPeerMap - dht.kMapLk.Unlock() - - dht.rtLk.Lock() - dht.rt = newRt - dht.lastCrawlTime = time.Now() - dht.rtLk.Unlock() +func (dht *FullRT) addBootstrapPeers() { + for _, ai := range dht.bootstrapPeers { + dht.h.Peerstore().AddAddrs(ai.ID, ai.Addrs, peerstore.PermanentAddrTTL) + dht.ValidPeerFound(ai.ID) } } @@ -432,6 +383,77 @@ func workers(numWorkers int, fn func(interface{}), inputs <-chan interface{}) { } } +var _ kaddht.DhtQueryIface = (*FullRT)(nil) + +func (dht *FullRT) Self() peer.ID { + return dht.self +} + +func (dht *FullRT) BucketSize() int { + return dht.bucketSize +} + +func (dht *FullRT) Beta() int { + return 3 +} + +func (dht *FullRT) Alpha() int { + return 10 +} + +func (dht *FullRT) DialPeer(ctx context.Context, id peer.ID) error { + // TODO: logging improvements + return dht.h.Connect(ctx, peer.AddrInfo{ID: id}) +} + +func (dht *FullRT) PeerStoppedDHT(id peer.ID) { + k := kb.ConvertPeerID(id) + dht.rtLk.Lock() + dht.rt.Remove(kadkey.KbucketIDToKey(k)) + dht.rtLk.Unlock() + + dht.kMapLk.Lock() + delete(dht.keyToPeerMap,string(k)) + dht.kMapLk.Unlock() + + dht.peerAddrsLk.Lock() + delete(dht.peerAddrs, id) + delete(dht.peerAddrsLastModifier, id) + dht.peerAddrsLk.Unlock() +} + +func (dht *FullRT) UpdateLastUsefulAt(id peer.ID, t time.Time) bool { + return false +} + +func (dht *FullRT) ValidPeerFound(id peer.ID) { + k := kb.ConvertPeerID(id) + dht.rtLk.Lock() + dht.rt.Add(kadkey.KbucketIDToKey(k)) + dht.rtLk.Unlock() + + dht.kMapLk.Lock() + dht.keyToPeerMap[string(k)] = id + dht.kMapLk.Unlock() + + dht.peerAddrsLk.Lock() + dht.peerAddrs[id] = dht.Peerstore().Addrs(id) + dht.peerAddrsLastModifier[id] = time.Now() + dht.peerAddrsLk.Unlock() +} + +func (dht *FullRT) QueryPeerFilter(i interface{}, info peer.AddrInfo) bool { + return kaddht.PublicQueryFilter(i, info) +} + +func (dht *FullRT) MaybeAddAddrs(p peer.ID, addrs 
[]multiaddr.Multiaddr, ttl time.Duration) { + dht.maybeAddAddrs(p, addrs, ttl) +} + +func (dht *FullRT) Peerstore() peerstore.Peerstore { + return dht.h.Peerstore() +} + func (dht *FullRT) GetClosestPeers(ctx context.Context, key string) ([]peer.ID, error) { _, span := internal.StartSpan(ctx, "FullRT.GetClosestPeers", trace.WithAttributes(internal.KeyAsAttribute("Key", key))) defer span.End() @@ -442,6 +464,68 @@ func (dht *FullRT) GetClosestPeers(ctx context.Context, key string) ([]peer.ID, closestKeys := kademlia.ClosestN(kadKey, dht.rt, dht.bucketSize) dht.rtLk.RUnlock() + lowest, highest := getRange(closestKeys, kadKey) + dht.crtLk.RLock() + rangeCovered := dht.crt.RangeIsCovered(cachert.Key(lowest), cachert.Key(highest)) + dht.crtLk.RUnlock() + if rangeCovered { + // Nothing to do + } else { + // Run a query and then put the results into the caching rt + peers := make([]peer.ID, 0, len(closestKeys)) + for _, k := range closestKeys { + dht.kMapLk.RLock() + p, ok := dht.keyToPeerMap[string(k)] + if !ok { + logger.Errorf("key not found in map") + } + dht.kMapLk.RUnlock() + dht.peerAddrsLk.RLock() + peerAddrs := dht.peerAddrs[p] + dht.peerAddrsLk.RUnlock() + + dht.h.Peerstore().AddAddrs(p, peerAddrs, peerstore.TempAddrTTL) + peers = append(peers, p) + } + + r, err := kaddht.RunLookupWithFollowup(ctx, key, func(ctx context.Context, id peer.ID) ([]*peer.AddrInfo, error) { + return dht.protoMessenger.GetClosestPeers(ctx, id, peer.ID(key)) + }, func(*qpeerset.QueryPeerset) bool { return false }, peers, dht) + if err != nil { + return nil, err + } + + var kadKeysOfClosestPeers []kadkey.Key + queryPeers := make([]peer.ID, 0, len(r.Peers)) + for i, p := range r.Peers { + kadKeysOfClosestPeers = append(kadKeysOfClosestPeers, kadkey.KbucketIDToKey(kb.ConvertPeerID(p))) + if state := r.State[i]; state != qpeerset.PeerUnreachable { + queryPeers = append(queryPeers, p) + } + } + lowest, highest := getRange(kadKeysOfClosestPeers, kadKey) + dht.crtLk.Lock() + dht.crt.InsertRange(cachert.Key(lowest), cachert.Key(highest), time.Now()) + dht.crtLk.Unlock() + + // TODO: add all the peers we've heard of and think may be valid to the routing table + dht.rtLk.Lock() + dht.kMapLk.Lock() + dht.peerAddrsLk.Lock() + for _, p := range queryPeers { + pKbID := kb.ConvertKey(string(p)) + pKadKey := kadkey.KbucketIDToKey(pKbID) + dht.rt.Add(pKadKey) + dht.keyToPeerMap[string(pKbID)] = p + dht.peerAddrs[p] = dht.Peerstore().Addrs(p) + dht.peerAddrsLastModifier[p] = time.Now() + } + dht.peerAddrsLk.Unlock() + dht.kMapLk.Unlock() + closestKeys = kademlia.ClosestN(kadKey, dht.rt, dht.bucketSize) + dht.rtLk.Unlock() + } + peers := make([]peer.ID, 0, len(closestKeys)) for _, k := range closestKeys { dht.kMapLk.RLock() @@ -460,6 +544,77 @@ func (dht *FullRT) GetClosestPeers(ctx context.Context, key string) ([]peer.ID, return peers, nil } +var max32Bytes = big.NewInt(0).SetBytes(bytes.Repeat([]byte{0xFF}, 32)) + +// getRange takes the closest keys to kadKey and returns two keys representative of the lower and upper bounds +func getRange(closestKeys []kadkey.Key, kadKey kadkey.Key) (kadkey.Key, kadkey.Key) { + slices.SortFunc(closestKeys, func(a, b kadkey.Key) int { + return bytes.Compare(a, b) + }) + var k, lowerKeyScratch, higherKeyScratch big.Int + k.SetBytes(kadKey) + lowerKeyScratch.SetBytes(closestKeys[0]) + lowerKeyScratch.Sub(&k, &lowerKeyScratch) // lowerKeyScratch = key - lowestKey + higherKeyScratch.SetBytes(closestKeys[len(closestKeys)-1]) + higherKeyScratch.Sub(&k, &higherKeyScratch) // higherKeyScratch = key - 
highestKey + + var lowest, highest kadkey.Key + if lowerKeyScratch.Sign() == -1 { + // lowestKey is still higher than the key + // So higher = highestKey, lower = key - (highestKey - key) + lowerKeyScratch.Add(&k, &higherKeyScratch) + lowestBytes := make([]byte, 32) + if lowerKeyScratch.Sign() == 1 { + lowest = lowerKeyScratch.FillBytes(lowestBytes) + } else { + lowest = lowestBytes + } + highest = closestKeys[len(closestKeys)-1] + } else if higherKeyScratch.Sign() == 1 { + // highestKey is still lower than the key + // So lower = lowestKey, higher = key + (key - lowestKey) + higherKeyScratch.Add(&k, &lowerKeyScratch) + highestBytes := make([]byte, 32) + if higherKeyScratch.Cmp(max32Bytes) >= 0 { + highest = max32Bytes.FillBytes(highestBytes) + } else { + highest = higherKeyScratch.FillBytes(highestBytes) + } + lowest = closestKeys[0] + } else { // Handle 0? + switch higherKeyScratch.CmpAbs(&lowerKeyScratch) { + case -1: + // highest - key < key - lowest + // So lower = lowestKey, higher = key + (key - lowestKey) + lowest = closestKeys[0] + higherKeyScratch.Add(&k, &lowerKeyScratch) + highestBytes := make([]byte, 32) + if higherKeyScratch.Cmp(max32Bytes) >= 0 { + highest = max32Bytes.FillBytes(highestBytes) + } else { + highest = higherKeyScratch.FillBytes(highestBytes) + } + case 1: + // highest - key > key - lowest + // So higher = highestKey, lower = key - (highestKey - key) + highest = closestKeys[len(closestKeys)-1] + lowerKeyScratch.Add(&k, &higherKeyScratch) + lowestBytes := make([]byte, 32) + if lowerKeyScratch.Sign() == 1 { + lowest = lowerKeyScratch.FillBytes(lowestBytes) + } else { + lowest = lowestBytes + } + case 0: + // TODO: edge case with one peer + lowest = closestKeys[0] + highest = closestKeys[len(closestKeys)-1] + } + } + + return lowest, highest +} + // PutValue adds value corresponding to given Key. 
// This is the top level "Store" operation of the DHT func (dht *FullRT) PutValue(ctx context.Context, key string, value []byte, opts ...routing.Option) (err error) { @@ -1043,6 +1198,7 @@ func (dht *FullRT) bulkMessageSend(ctx context.Context, keys []peer.ID, fn func( numPeers := len(dht.keyToPeerMap) dht.kMapLk.RUnlock() + numPeers = 10000 chunkSize := (len(sortedKeys) * dht.bucketSize * 2) / numPeers if chunkSize == 0 { chunkSize = 1 @@ -1116,35 +1272,112 @@ func (dht *FullRT) bulkMessageSend(ctx context.Context, keys []peer.ID, fn func( keyGroups := divideByChunkSize(sortedKeys, chunkSize) sendsSoFar := 0 - for _, g := range keyGroups { - if ctx.Err() != nil { - break - } - keysPerPeer := make(map[peer.ID][]peer.ID) - for _, k := range g { - peers, err := dht.GetClosestPeers(ctx, string(k)) - if err == nil { - for _, p := range peers { - keysPerPeer[p] = append(keysPerPeer[p], k) + kgParallelism := 10 + kgWg := sync.WaitGroup{} + kgWg.Add(kgParallelism) + parallelKeyGroups := make([][][]peer.ID, kgParallelism) + for i, g := range keyGroups { + gNum := i % kgParallelism + parallelKeyGroups[gNum] = append(parallelKeyGroups[gNum], g) + } + + sendCh := make(chan struct { + KeysPerPeer map[peer.ID][]peer.ID + GroupSize int + }) + for i := 0; i < kgParallelism; i++ { + // split into kgParallelism groups by alternating + go func(parallelGroups [][]peer.ID) { + defer kgWg.Done() + + for _, g := range parallelGroups { + if ctx.Err() != nil { + return + } + + keysPerPeer := make(map[peer.ID][]peer.ID) + for _, k := range g { + peers, err := dht.GetClosestPeers(ctx, string(k)) + if err == nil { + for _, p := range peers { + keysPerPeer[p] = append(keysPerPeer[p], k) + } + } + } + + select { + case sendCh <- struct { + KeysPerPeer map[peer.ID][]peer.ID + GroupSize int + }{KeysPerPeer: keysPerPeer, GroupSize: len(g)}: + case <-ctx.Done(): + return } } - } + }(parallelKeyGroups[i]) + } - logger.Debugf("bulk send: %d peers for group size %d", len(keysPerPeer), len(g)) + go func() { + kgWg.Wait() + close(sendCh) + }() - keyloop: - for p, workKeys := range keysPerPeer { - select { - case workCh <- workMessage{p: p, keys: workKeys}: - case <-ctx.Done(): - break keyloop +sendKgLoop: + for { + select { + case o, ok := <-sendCh: + if !ok { + break sendKgLoop } + logger.Debugf("bulk send: %d peers for group size %d", len(o.KeysPerPeer), o.GroupSize) + + keyloop: + for p, workKeys := range o.KeysPerPeer { + select { + case workCh <- workMessage{p: p, keys: workKeys}: + case <-ctx.Done(): + break keyloop + } + } + sendsSoFar += o.GroupSize + logger.Infof("bulk sending: %.1f%% done - %d/%d done", 100*float64(sendsSoFar)/float64(len(keySuccesses)), sendsSoFar, len(keySuccesses)) + case <-ctx.Done(): + break sendKgLoop } - sendsSoFar += len(g) - logger.Infof("bulk sending: %.1f%% done - %d/%d done", 100*float64(sendsSoFar)/float64(len(keySuccesses)), sendsSoFar, len(keySuccesses)) } + /* + for _, g := range keyGroups { + if ctx.Err() != nil { + break + } + + keysPerPeer := make(map[peer.ID][]peer.ID) + for _, k := range g { + peers, err := dht.GetClosestPeers(ctx, string(k)) + if err == nil { + for _, p := range peers { + keysPerPeer[p] = append(keysPerPeer[p], k) + } + } + } + + logger.Debugf("bulk send: %d peers for group size %d", len(keysPerPeer), len(g)) + + keyloop: + for p, workKeys := range keysPerPeer { + select { + case workCh <- workMessage{p: p, keys: workKeys}: + case <-ctx.Done(): + break keyloop + } + } + sendsSoFar += len(g) + logger.Infof("bulk sending: %.1f%% done - %d/%d done", 
100*float64(sendsSoFar)/float64(len(keySuccesses)), sendsSoFar, len(keySuccesses)) + } + */ + close(workCh) logger.Debugf("bulk send complete, waiting on goroutines to close") diff --git a/go.mod b/go.mod index 3204f8842..3880739cd 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,9 @@ toolchain go1.22.1 retract v0.24.3 // this includes a breaking change and should have been released as v0.25.0 require ( + github.com/emirpasic/gods/v2 v2.0.0-alpha github.com/gogo/protobuf v1.3.2 + github.com/google/btree v1.1.3 github.com/google/gopacket v1.1.19 github.com/google/uuid v1.6.0 github.com/hashicorp/go-multierror v1.1.1 @@ -31,6 +33,7 @@ require ( github.com/multiformats/go-multibase v0.2.0 github.com/multiformats/go-multihash v0.2.3 github.com/multiformats/go-multistream v0.5.0 + github.com/probe-lab/go-kademlia v0.0.0-20240823125516-aed94cdc2c2f github.com/stretchr/testify v1.9.0 github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 go.opencensus.io v0.24.0 diff --git a/go.sum b/go.sum index d04192bc1..6d6b6b59f 100644 --- a/go.sum +++ b/go.sum @@ -71,6 +71,8 @@ github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25Kn github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs= github.com/elastic/gosigar v0.14.3 h1:xwkKwPia+hSfg9GqrCUKYdId102m9qTJIIr7egmK/uo= github.com/elastic/gosigar v0.14.3/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs= +github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU= +github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -125,6 +127,8 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= +github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -422,6 +426,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/polydawn/refmt v0.89.0 h1:ADJTApkvkeBZsN0tBTx8QjpD9JkmxbKp0cxfr9qszm4= github.com/polydawn/refmt v0.89.0/go.mod h1:/zvteZs/GwLtCgZ4BL6CBsk9IKIlexP43ObX9AxTqTw= +github.com/probe-lab/go-kademlia v0.0.0-20240823125516-aed94cdc2c2f h1:RRo6TuvMKIiTiKjQHEYrakdCExb/lPG9rNelr4tZadk= +github.com/probe-lab/go-kademlia v0.0.0-20240823125516-aed94cdc2c2f/go.mod h1:FHBJfMug9mTo5BDBC6i9PbfYPImNCjeGC+oBJxNGsJQ= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.20.5 
h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= diff --git a/internal/net/message_manager.go b/internal/net/message_manager.go index 294425f76..5090b4be9 100644 --- a/internal/net/message_manager.go +++ b/internal/net/message_manager.go @@ -74,6 +74,8 @@ func (m *messageSenderImpl) OnDisconnect(ctx context.Context, p peer.ID) { // measure the RTT for latency measurements. func (m *messageSenderImpl) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) + ctx, cancel := context.WithTimeout(ctx, dhtReadMessageTimeout) + defer cancel() ms, err := m.messageSenderForPeer(ctx, p) if err != nil { @@ -109,6 +111,8 @@ func (m *messageSenderImpl) SendRequest(ctx context.Context, p peer.ID, pmes *pb // SendMessage sends out a message func (m *messageSenderImpl) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) + ctx, cancel := context.WithTimeout(ctx, dhtReadMessageTimeout) + defer cancel() ms, err := m.messageSenderForPeer(ctx, p) if err != nil { diff --git a/lookup.go b/lookup.go index 03801325c..cc9e66952 100644 --- a/lookup.go +++ b/lookup.go @@ -34,12 +34,12 @@ func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) ([]peer.ID, return nil, err } - if err := ctx.Err(); err != nil || !lookupRes.completed { - return lookupRes.peers, err + if err := ctx.Err(); err != nil || !lookupRes.Completed { + return lookupRes.Peers, err } // tracking lookup results for network size estimator - if err = dht.nsEstimator.Track(key, lookupRes.closest); err != nil { + if err = dht.nsEstimator.Track(key, lookupRes.Closest); err != nil { logger.Warnf("network size estimator track peers: %s", err) } @@ -50,7 +50,7 @@ func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) ([]peer.ID, // refresh the cpl for this key as the query was successful dht.routingTable.ResetCplRefreshedAtForID(kb.ConvertKey(key), time.Now()) - return lookupRes.peers, nil + return lookupRes.Peers, nil } // pmGetClosestPeers is the protocol messenger version of the GetClosestPeer queryFn. diff --git a/lookup_optim.go b/lookup_optim.go index 428e86f24..ae9213465 100644 --- a/lookup_optim.go +++ b/lookup_optim.go @@ -150,7 +150,7 @@ func (dht *IpfsDHT) optimisticProvide(outerCtx context.Context, keyMH multihash. // Store the provider records with all the closest peers we haven't already contacted/scheduled interaction with. es.peerStatesLk.Lock() - for _, p := range lookupRes.peers { + for _, p := range lookupRes.Peers { if _, found := es.peerStates[p]; found { continue } @@ -163,12 +163,12 @@ func (dht *IpfsDHT) optimisticProvide(outerCtx context.Context, keyMH multihash. 
// wait until a threshold number of RPCs have completed es.waitForRPCs() - if err := outerCtx.Err(); err != nil || !lookupRes.completed { // likely the "completed" field is false but that's not a given + if err := outerCtx.Err(); err != nil || !lookupRes.Completed { // likely the "completed" field is false but that's not a given return err } // tracking lookup results for network size estimator as "completed" is true - if err = dht.nsEstimator.Track(key, lookupRes.closest); err != nil { + if err = dht.nsEstimator.Track(key, lookupRes.Closest); err != nil { logger.Warnf("network size estimator track peers: %s", err) } diff --git a/query.go b/query.go index 7c01a2af2..694291b03 100644 --- a/query.go +++ b/query.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + ma "github.com/multiformats/go-multiaddr" "math" "sync" "time" @@ -27,6 +28,22 @@ var ErrNoPeersQueried = errors.New("failed to query any peers") type queryFn func(context.Context, peer.ID) ([]*peer.AddrInfo, error) type stopFn func(*qpeerset.QueryPeerset) bool +type DhtQueryIface interface { + Self() peer.ID + BucketSize() int + Beta() int + Alpha() int + DialPeer(context.Context, peer.ID) error + + PeerStoppedDHT(id peer.ID) + UpdateLastUsefulAt(id peer.ID, t time.Time) bool + ValidPeerFound(id peer.ID) + QueryPeerFilter(i interface{}, info peer.AddrInfo) bool + MaybeAddAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) + + Peerstore() pstore.Peerstore +} + // query represents a single DHT query. type query struct { // unique identifier for the lookup instance @@ -38,7 +55,7 @@ type query struct { // the query context. ctx context.Context - dht *IpfsDHT + dht DhtQueryIface // seedPeers is the set of peers that seed the query seedPeers []peer.ID @@ -63,29 +80,29 @@ type query struct { stopFn stopFn } -type lookupWithFollowupResult struct { - peers []peer.ID // the top K not unreachable peers at the end of the query - state []qpeerset.PeerState // the peer states at the end of the query of the peers slice (not closest) - closest []peer.ID // the top K peers at the end of the query +type LookupWithFollowupResult struct { + Peers []peer.ID // the top K not unreachable peers at the end of the query + State []qpeerset.PeerState // the peer states at the end of the query of the peers slice (not closest) + Closest []peer.ID // the top K peers at the end of the query // indicates that neither the lookup nor the followup has been prematurely terminated by an external condition such // as context cancellation or the stop function being called. - completed bool + Completed bool } -// runLookupWithFollowup executes the lookup on the target using the given query function and stopping when either the +// RunLookupWithFollowup executes the lookup on the target using the given query function and stopping when either the // context is cancelled or the stop function returns true. Note: if the stop function is not sticky, i.e. it does not // return true every time after the first time it returns true, it is not guaranteed to cause a stop to occur just // because it momentarily returns true. // // After the lookup is complete the query function is run (unless stopped) against all of the top K peers from the // lookup that have not already been successfully queried. 
-func (dht *IpfsDHT) runLookupWithFollowup(ctx context.Context, target string, queryFn queryFn, stopFn stopFn) (*lookupWithFollowupResult, error) { +func RunLookupWithFollowup(ctx context.Context, target string, queryFn queryFn, stopFn stopFn, seedPeers []peer.ID, dht DhtQueryIface) (*LookupWithFollowupResult, error) { ctx, span := internal.StartSpan(ctx, "IpfsDHT.RunLookupWithFollowup", trace.WithAttributes(internal.KeyAsAttribute("Target", target))) defer span.End() // run the query - lookupRes, qps, err := dht.runQuery(ctx, target, queryFn, stopFn) + lookupRes, qps, err := RunQuery(ctx, target, queryFn, stopFn, seedPeers, dht) if err != nil { return nil, err } @@ -94,9 +111,9 @@ func (dht *IpfsDHT) runLookupWithFollowup(ctx context.Context, target string, qu // This ensures that all of the top K results have been queried which adds to resiliency against churn for query // functions that carry state (e.g. FindProviders and GetValue) as well as establish connections that are needed // by stateless query functions (e.g. GetClosestPeers and therefore Provide and PutValue) - queryPeers := make([]peer.ID, 0, len(lookupRes.peers)) - for i, p := range lookupRes.peers { - if state := lookupRes.state[i]; state == qpeerset.PeerHeard || state == qpeerset.PeerWaiting { + queryPeers := make([]peer.ID, 0, len(lookupRes.Peers)) + for i, p := range lookupRes.Peers { + if state := lookupRes.State[i]; state == qpeerset.PeerHeard || state == qpeerset.PeerWaiting { queryPeers = append(queryPeers, p) } } @@ -107,7 +124,7 @@ func (dht *IpfsDHT) runLookupWithFollowup(ctx context.Context, target string, qu // return if the lookup has been externally stopped if ctx.Err() != nil || stopFn(qps) { - lookupRes.completed = false + lookupRes.Completed = false return lookupRes, nil } @@ -125,25 +142,25 @@ func (dht *IpfsDHT) runLookupWithFollowup(ctx context.Context, target string, qu // wait for all queries to complete before returning, aborting ongoing queries if we've been externally stopped followupsCompleted := 0 processFollowUp: - for i := 0; i < len(queryPeers); i++ { + for i := 0; i < len(queryPeers)/3; i++ { select { case <-doneCh: followupsCompleted++ if stopFn(qps) { cancelFollowUp() if i < len(queryPeers)-1 { - lookupRes.completed = false + lookupRes.Completed = false } break processFollowUp } case <-ctx.Done(): - lookupRes.completed = false + lookupRes.Completed = false cancelFollowUp() break processFollowUp } } - if !lookupRes.completed { + if !lookupRes.Completed { for i := followupsCompleted; i < len(queryPeers); i++ { <-doneCh } @@ -152,13 +169,64 @@ processFollowUp: return lookupRes, nil } -func (dht *IpfsDHT) runQuery(ctx context.Context, target string, queryFn queryFn, stopFn stopFn) (*lookupWithFollowupResult, *qpeerset.QueryPeerset, error) { +func (dht *IpfsDHT) runLookupWithFollowup(ctx context.Context, target string, queryFn queryFn, stopFn stopFn) (*LookupWithFollowupResult, error) { + targetKadID := kb.ConvertKey(target) + seedPeers := dht.routingTable.NearestPeers(targetKadID, dht.bucketSize) + return RunLookupWithFollowup(ctx, target, queryFn, stopFn, seedPeers, dht) +} + +var _ DhtQueryIface = (*IpfsDHT)(nil) + +func (dht *IpfsDHT) Self() peer.ID { + return dht.self +} + +func (dht *IpfsDHT) BucketSize() int { + return dht.bucketSize +} + +func (dht *IpfsDHT) Beta() int { + return dht.beta +} + +func (dht *IpfsDHT) Alpha() int { + return dht.alpha +} + +func (dht *IpfsDHT) DialPeer(ctx context.Context, id peer.ID) error { + return dht.dialPeer(ctx, id) +} + +func (dht *IpfsDHT) 
PeerStoppedDHT(id peer.ID) { + dht.peerStoppedDHT(id) +} + +func (dht *IpfsDHT) UpdateLastUsefulAt(id peer.ID, t time.Time) bool { + return dht.routingTable.UpdateLastUsefulAt(id, t) +} + +func (dht *IpfsDHT) ValidPeerFound(id peer.ID) { + dht.validPeerFound(id) +} + +func (dht *IpfsDHT) QueryPeerFilter(i interface{}, info peer.AddrInfo) bool { + return dht.queryPeerFilter(i, info) +} + +func (dht *IpfsDHT) MaybeAddAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) { + dht.maybeAddAddrs(p, addrs, ttl) +} + +func (dht *IpfsDHT) Peerstore() pstore.Peerstore { + return dht.peerstore +} + +func RunQuery(ctx context.Context, target string, queryFn queryFn, stopFn stopFn, seedPeers []peer.ID, dht DhtQueryIface) (*LookupWithFollowupResult, *qpeerset.QueryPeerset, error) { ctx, span := internal.StartSpan(ctx, "IpfsDHT.RunQuery") defer span.End() // pick the K closest peers to the key in our Routing table. targetKadID := kb.ConvertKey(target) - seedPeers := dht.routingTable.NearestPeers(targetKadID, dht.bucketSize) if len(seedPeers) == 0 { routing.PublishQueryEvent(ctx, &routing.QueryEvent{ Type: routing.QueryError, @@ -192,7 +260,7 @@ func (dht *IpfsDHT) runQuery(ctx context.Context, target string, queryFn queryFn } func (q *query) recordPeerIsValuable(p peer.ID) { - if !q.dht.routingTable.UpdateLastUsefulAt(p, time.Now()) { + if !q.dht.UpdateLastUsefulAt(p, time.Now()) { // not in routing table return } @@ -218,7 +286,7 @@ func (q *query) recordValuablePeers() { } // constructLookupResult takes the query information and uses it to construct the lookup result -func (q *query) constructLookupResult(target kb.ID) *lookupWithFollowupResult { +func (q *query) constructLookupResult(target kb.ID) *LookupWithFollowupResult { // determine if the query terminated early completed := true @@ -232,7 +300,7 @@ func (q *query) constructLookupResult(target kb.ID) *lookupWithFollowupResult { // extract the top K not unreachable peers var peers []peer.ID peerState := make(map[peer.ID]qpeerset.PeerState) - qp := q.queryPeers.GetClosestNInStates(q.dht.bucketSize, qpeerset.PeerHeard, qpeerset.PeerWaiting, qpeerset.PeerQueried) + qp := q.queryPeers.GetClosestNInStates(q.dht.BucketSize(), qpeerset.PeerHeard, qpeerset.PeerWaiting, qpeerset.PeerQueried) for _, p := range qp { state := q.queryPeers.GetState(p) peerState[p] = state @@ -241,22 +309,22 @@ func (q *query) constructLookupResult(target kb.ID) *lookupWithFollowupResult { // get the top K overall peers sortedPeers := kb.SortClosestPeers(peers, target) - if len(sortedPeers) > q.dht.bucketSize { - sortedPeers = sortedPeers[:q.dht.bucketSize] + if len(sortedPeers) > q.dht.BucketSize() { + sortedPeers = sortedPeers[:q.dht.BucketSize()] } - closest := q.queryPeers.GetClosestNInStates(q.dht.bucketSize, qpeerset.PeerHeard, qpeerset.PeerWaiting, qpeerset.PeerQueried, qpeerset.PeerUnreachable) + closest := q.queryPeers.GetClosestNInStates(q.dht.BucketSize(), qpeerset.PeerHeard, qpeerset.PeerWaiting, qpeerset.PeerQueried, qpeerset.PeerUnreachable) // return the top K not unreachable peers as well as their states at the end of the query - res := &lookupWithFollowupResult{ - peers: sortedPeers, - state: make([]qpeerset.PeerState, len(sortedPeers)), - completed: completed, - closest: closest, + res := &LookupWithFollowupResult{ + Peers: sortedPeers, + State: make([]qpeerset.PeerState, len(sortedPeers)), + Completed: completed, + Closest: closest, } for i, p := range sortedPeers { - res.state[i] = peerState[p] + res.State[i] = peerState[p] } return res @@ -278,10 
+346,10 @@ func (q *query) run() { pathCtx, cancelPath := context.WithCancel(ctx) defer cancelPath() - alpha := q.dht.alpha + alpha := q.dht.Alpha() ch := make(chan *queryUpdate, alpha) - ch <- &queryUpdate{cause: q.dht.self, heard: q.seedPeers} + ch <- &queryUpdate{cause: q.dht.Self(), heard: q.seedPeers} // return only once all outstanding queries have completed. defer q.waitGroup.Wait() @@ -327,7 +395,7 @@ func (q *query) spawnQuery(ctx context.Context, cause peer.ID, queryPeer peer.ID PublishLookupEvent(ctx, NewLookupEvent( - q.dht.self, + q.dht.Self(), q.id, q.key, NewLookupUpdateEvent( @@ -377,7 +445,7 @@ func (q *query) isReadyToTerminate(ctx context.Context, nPeersToQuery int) (bool // From the set of all nodes that are not unreachable, // if the closest beta nodes are all queried, the lookup can terminate. func (q *query) isLookupTermination() bool { - peers := q.queryPeers.GetClosestNInStates(q.dht.beta, qpeerset.PeerHeard, qpeerset.PeerWaiting, qpeerset.PeerQueried) + peers := q.queryPeers.GetClosestNInStates(q.dht.Beta(), qpeerset.PeerHeard, qpeerset.PeerWaiting, qpeerset.PeerQueried) for _, p := range peers { if q.queryPeers.GetState(p) != qpeerset.PeerQueried { return false @@ -400,7 +468,7 @@ func (q *query) terminate(ctx context.Context, cancel context.CancelFunc, reason PublishLookupEvent(ctx, NewLookupEvent( - q.dht.self, + q.dht.Self(), q.id, q.key, nil, @@ -423,10 +491,10 @@ func (q *query) queryPeer(ctx context.Context, ch chan<- *queryUpdate, p peer.ID dialCtx, queryCtx := ctx, ctx // dial the peer - if err := q.dht.dialPeer(dialCtx, p); err != nil { + if err := q.dht.DialPeer(dialCtx, p); err != nil { // remove the peer if there was a dial failure..but not because of a context cancellation if dialCtx.Err() == nil { - q.dht.peerStoppedDHT(p) + q.dht.PeerStoppedDHT(p) } ch <- &queryUpdate{cause: p, unreachable: []peer.ID{p}} return @@ -437,7 +505,7 @@ func (q *query) queryPeer(ctx context.Context, ch chan<- *queryUpdate, p peer.ID newPeers, err := q.queryFn(queryCtx, p) if err != nil { if queryCtx.Err() == nil { - q.dht.peerStoppedDHT(p) + q.dht.PeerStoppedDHT(p) } ch <- &queryUpdate{cause: p, unreachable: []peer.ID{p}} return @@ -446,18 +514,18 @@ func (q *query) queryPeer(ctx context.Context, ch chan<- *queryUpdate, p peer.ID queryDuration := time.Since(startQuery) // query successful, try to add to RT - q.dht.validPeerFound(p) + q.dht.ValidPeerFound(p) // process new peers saw := []peer.ID{} for _, next := range newPeers { - if next.ID == q.dht.self { // don't add self. + if next.ID == q.dht.Self() { // don't add self. logger.Debugf("PEERS CLOSER -- worker for: %v found self", p) continue } // add any other know addresses for the candidate peer. - curInfo := q.dht.peerstore.PeerInfo(next.ID) + curInfo := q.dht.Peerstore().PeerInfo(next.ID) next.Addrs = append(next.Addrs, curInfo.Addrs...) 
// add their addresses to the dialer's peerstore @@ -465,8 +533,8 @@ func (q *query) queryPeer(ctx context.Context, ch chan<- *queryUpdate, p peer.ID // add the next peer to the query if matches the query target even if it would otherwise fail the query filter // TODO: this behavior is really specific to how FindPeer works and not GetClosestPeers or any other function isTarget := string(next.ID) == q.key - if isTarget || q.dht.queryPeerFilter(q.dht, *next) { - q.dht.maybeAddAddrs(next.ID, next.Addrs, pstore.TempAddrTTL) + if isTarget || q.dht.QueryPeerFilter(q.dht, *next) { + q.dht.MaybeAddAddrs(next.ID, next.Addrs, pstore.TempAddrTTL) saw = append(saw, next.ID) } } @@ -480,7 +548,7 @@ func (q *query) updateState(ctx context.Context, up *queryUpdate) { } PublishLookupEvent(ctx, NewLookupEvent( - q.dht.self, + q.dht.Self(), q.id, q.key, nil, @@ -496,13 +564,13 @@ func (q *query) updateState(ctx context.Context, up *queryUpdate) { ), ) for _, p := range up.heard { - if p == q.dht.self { // don't add self. + if p == q.dht.Self() { // don't add self. continue } q.queryPeers.TryAdd(p, up.cause) } for _, p := range up.queried { - if p == q.dht.self { // don't add self. + if p == q.dht.Self() { // don't add self. continue } if st := q.queryPeers.GetState(p); st == qpeerset.PeerWaiting { @@ -513,7 +581,7 @@ func (q *query) updateState(ctx context.Context, up *queryUpdate) { } } for _, p := range up.unreachable { - if p == q.dht.self { // don't add self. + if p == q.dht.Self() { // don't add self. continue } diff --git a/routing.go b/routing.go index 1df05e1b1..9b020ad19 100644 --- a/routing.go +++ b/routing.go @@ -179,7 +179,7 @@ func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...routing return } - for _, p := range l.peers { + for _, p := range l.Peers { if _, ok := peersWithBest[p]; !ok { updatePeers = append(updatePeers, p) } @@ -281,9 +281,9 @@ func (dht *IpfsDHT) updatePeerValues(ctx context.Context, key string, val []byte } } -func (dht *IpfsDHT) getValues(ctx context.Context, key string, stopQuery chan struct{}) (<-chan recvdVal, <-chan *lookupWithFollowupResult) { +func (dht *IpfsDHT) getValues(ctx context.Context, key string, stopQuery chan struct{}) (<-chan recvdVal, <-chan *LookupWithFollowupResult) { valCh := make(chan recvdVal, 1) - lookupResCh := make(chan *lookupWithFollowupResult, 1) + lookupResCh := make(chan *LookupWithFollowupResult, 1) logger.Debugw("finding value", "key", internal.LoggableRecordKeyString(key)) @@ -371,8 +371,8 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, stopQuery chan st return valCh, lookupResCh } -func (dht *IpfsDHT) refreshRTIfNoShortcut(key kb.ID, lookupRes *lookupWithFollowupResult) { - if lookupRes.completed { +func (dht *IpfsDHT) refreshRTIfNoShortcut(key kb.ID, lookupRes *LookupWithFollowupResult) { + if lookupRes.Completed { // refresh the cpl for this key as the query was successful dht.routingTable.ResetCplRefreshedAtForID(key, time.Now()) } @@ -672,12 +672,12 @@ func (dht *IpfsDHT) FindPeer(ctx context.Context, id peer.ID) (pi peer.AddrInfo, } dialedPeerDuringQuery := false - for i, p := range lookupRes.peers { + for i, p := range lookupRes.Peers { if p == id { // Note: we consider PeerUnreachable to be a valid state because the peer may not support the DHT protocol // and therefore the peer would fail the query. The fact that a peer that is returned can be a non-DHT // server peer and is not identified as such is a bug. 
- dialedPeerDuringQuery = (lookupRes.state[i] == qpeerset.PeerQueried || lookupRes.state[i] == qpeerset.PeerUnreachable || lookupRes.state[i] == qpeerset.PeerWaiting) + dialedPeerDuringQuery = (lookupRes.State[i] == qpeerset.PeerQueried || lookupRes.State[i] == qpeerset.PeerUnreachable || lookupRes.State[i] == qpeerset.PeerWaiting) break } }
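
Usage note for the new cachert package introduced above: the sketch below exercises the RT API exactly as exported by the patch (NewRT, InsertRange, RangeIsCovered, GetRanges, CollectGarbage). The package main wrapper and the literal 32-byte keys are illustrative assumptions, not part of the patch.

package main

import (
	"bytes"
	"fmt"
	"time"

	"github.com/libp2p/go-libp2p-kad-dht/cachert"
)

func main() {
	rt := cachert.NewRT()

	// cachert.Key is a string; InsertRange expects 32-byte keys
	// (the patch feeds it 32-byte Kademlia identifiers derived via kb.ConvertPeerID).
	a := string(bytes.Repeat([]byte{0x10}, 32))
	b := string(bytes.Repeat([]byte{0x20}, 32))
	c := string(bytes.Repeat([]byte{0x30}, 32))

	// Record that the keyspace segment [a, b] was just searched. The third
	// argument is the timestamp that CollectGarbage later compares against.
	rt.InsertRange(a, b, time.Now())

	// A later lookup whose closest-peers span falls inside a recorded range
	// can be answered from the cached routing table without a network query.
	fmt.Println(rt.RangeIsCovered(a, b)) // true
	fmt.Println(rt.RangeIsCovered(a, c)) // false: the segment (b, c] was never searched

	// Drop ranges whose timestamp is older than the cutoff. The ticker
	// goroutine in NewFullRT does this with time.Now().Add(-time.Hour).
	rt.CollectGarbage(time.Now().Add(-time.Hour))
	fmt.Println(len(rt.GetRanges())) // 1: the range above is newer than the cutoff
}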
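
The unexported getRange helper in fullrt/dht.go derives the cached range symmetrically around the target key: it takes the larger of the distances to the nearest and farthest of the returned closest keys and mirrors it to the other side, clamping to the 32-byte keyspace. The toy snippet below only illustrates that arithmetic for the common case where the target lies between the closest keys; getRange's reflection branches (target outside the span) and clamping are omitted, and the small integers stand in for 256-bit keys.

package main

import (
	"fmt"
	"math/big"
)

// symmetricBounds mirrors the longer of the two distances around the target,
// assuming lowest <= target <= highest (the middle branch of getRange).
func symmetricBounds(target, lowest, highest *big.Int) (*big.Int, *big.Int) {
	dLow := new(big.Int).Sub(target, lowest)   // target - lowest key
	dHigh := new(big.Int).Sub(highest, target) // highest key - target
	radius := dLow
	if dHigh.Cmp(dLow) > 0 {
		radius = dHigh
	}
	lo := new(big.Int).Sub(target, radius)
	hi := new(big.Int).Add(target, radius)
	return lo, hi
}

func main() {
	// Toy 1-byte "keys": target 100, closest peers span [90, 120].
	lo, hi := symmetricBounds(big.NewInt(100), big.NewInt(90), big.NewInt(120))
	fmt.Println(lo, hi) // 80 120: the longer side (20) is mirrored below the target
}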
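
The query.go changes expose RunQuery and RunLookupWithFollowup behind the new DhtQueryIface, which both IpfsDHT and FullRT now satisfy. The sketch below shows what a third, minimal implementation wrapping a bare libp2p host could look like; it assumes the patched module, and the queryHost type, its no-op table callbacks, and the bucket-size constant are assumptions (Alpha and Beta mirror FullRT's hard-coded values). Such an adapter could then be handed to dht.RunLookupWithFollowup together with seed peers, a query function, and a stop function.

package sketch

import (
	"context"
	"time"

	dht "github.com/libp2p/go-libp2p-kad-dht"
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/peerstore"
	ma "github.com/multiformats/go-multiaddr"
)

// queryHost is a hypothetical adapter that keeps no routing table of its own,
// so the table-maintenance callbacks are deliberate no-ops.
type queryHost struct {
	h host.Host
}

var _ dht.DhtQueryIface = (*queryHost)(nil)

func (q *queryHost) Self() peer.ID   { return q.h.ID() }
func (q *queryHost) BucketSize() int { return 20 } // assumed; FullRT takes this from its config
func (q *queryHost) Beta() int       { return 3 }  // same constants FullRT hard-codes
func (q *queryHost) Alpha() int      { return 10 }

func (q *queryHost) DialPeer(ctx context.Context, id peer.ID) error {
	return q.h.Connect(ctx, peer.AddrInfo{ID: id})
}

func (q *queryHost) PeerStoppedDHT(peer.ID)                     {}
func (q *queryHost) UpdateLastUsefulAt(peer.ID, time.Time) bool { return false }
func (q *queryHost) ValidPeerFound(peer.ID)                     {}

func (q *queryHost) QueryPeerFilter(i interface{}, info peer.AddrInfo) bool {
	return dht.PublicQueryFilter(i, info)
}

func (q *queryHost) MaybeAddAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) {
	q.h.Peerstore().AddAddrs(p, addrs, ttl)
}

func (q *queryHost) Peerstore() peerstore.Peerstore { return q.h.Peerstore() }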