diff --git a/dht_test.go b/dht_test.go
index 3f07b911f..0f9058232 100644
--- a/dht_test.go
+++ b/dht_test.go
@@ -1343,7 +1343,8 @@ func TestClientModeConnect(t *testing.T) {
 
 	c := testCaseCids[0]
 	p := peer.ID("TestPeer")
-	a.ProviderManager.AddProvider(ctx, c.Hash(), p)
+	err := a.ProviderManager.AddProviderNonBlocking(ctx, c.Hash(), p)
+	require.NoError(t, err)
 	time.Sleep(time.Millisecond * 5) // just in case...
 
 	provs, err := b.FindProviders(ctx, c)
diff --git a/fullrt/dht.go b/fullrt/dht.go
index dfee22807..cbf3065f7 100644
--- a/fullrt/dht.go
+++ b/fullrt/dht.go
@@ -773,9 +773,15 @@ func (dht *FullRT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err e
 	logger.Debugw("providing", "cid", key, "mh", internal.LoggableProviderRecordBytes(keyMH))
 
 	// add self locally
-	dht.ProviderManager.AddProvider(ctx, keyMH, dht.h.ID())
+	err = dht.ProviderManager.AddProvider(ctx, keyMH, dht.h.ID())
 	if !brdcst {
-		return nil
+		// If we're not broadcasting, return immediately. But also return the error because,
+		// if something went wrong, we basically failed to do anything.
+		return err
+	}
+	if err != nil {
+		// Otherwise, "local" provides are "best effort".
+		logger.Debugw("local provide failed", "error", err)
 	}
 
 	closerCtx := ctx
diff --git a/handlers.go b/handlers.go
index 5160232c0..1de82f0d3 100644
--- a/handlers.go
+++ b/handlers.go
@@ -9,6 +9,7 @@ import (
 
 	"github.com/libp2p/go-libp2p-core/peer"
 	"github.com/libp2p/go-libp2p-core/peerstore"
+	kb "github.com/libp2p/go-libp2p-kbucket"
 	pstore "github.com/libp2p/go-libp2p-peerstore"
 
 	"github.com/gogo/protobuf/proto"
@@ -317,8 +318,39 @@ func (dht *IpfsDHT) handleGetProviders(ctx context.Context, p peer.ID, pmes *pb.
 
 	resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())
 
+	// Find closer peers.
+	closer := dht.betterPeersToQuery(pmes, p, dht.bucketSize)
+	myBucket := true
+	if len(closer) > 0 {
+		// Fill out peer infos.
+		// TODO: pstore.PeerInfos should move to core (=> peerstore.AddrInfos).
+		infos := pstore.PeerInfos(dht.peerstore, closer)
+		resp.CloserPeers = pb.PeerInfosToPBPeers(dht.host.Network(), infos)
+
+		// If we have a full bucket of closer peers, check to see if _we're_ in the closest
+		// set.
+		if len(closer) >= dht.bucketSize {
+			// Check to see if _we're_ in the "close" bucket.
+			// If not, we _may_ not be in the closest set, so avoid blocking on the provider store below.
+			peers := append(closer, dht.self)
+			peers = kb.SortClosestPeers(peers, kb.ConvertKey(string(pmes.GetKey())))
+			myBucket = peers[len(peers)-1] != dht.self
+		}
+	}
+
 	// setup providers
-	providers := dht.ProviderManager.GetProviders(ctx, key)
+	var providers []peer.ID
+	if myBucket {
+		// If we're in the closest set, block getting providers.
+		providers = dht.ProviderManager.GetProviders(ctx, key)
+	} else {
+		// Otherwise, don't block. The peer will find a closer peer.
+		var err error
+		providers, err = dht.ProviderManager.GetProvidersNonBlocking(ctx, key)
+		if err != nil {
+			logger.Debugw("dropping get providers request", "error", err)
+		}
+	}
 
 	if len(providers) > 0 {
 		// TODO: pstore.PeerInfos should move to core (=> peerstore.AddrInfos).
@@ -326,14 +358,6 @@ func (dht *IpfsDHT) handleGetProviders(ctx context.Context, p peer.ID, pmes *pb.
 		resp.ProviderPeers = pb.PeerInfosToPBPeers(dht.host.Network(), infos)
 	}
 
-	// Also send closer peers.
-	closer := dht.betterPeersToQuery(pmes, p, dht.bucketSize)
-	if closer != nil {
-		// TODO: pstore.PeerInfos should move to core (=> peerstore.AddrInfos).
-		infos := pstore.PeerInfos(dht.peerstore, closer)
-		resp.CloserPeers = pb.PeerInfosToPBPeers(dht.host.Network(), infos)
-	}
-
 	return resp, nil
 }
 
@@ -366,7 +390,10 @@ func (dht *IpfsDHT) handleAddProvider(ctx context.Context, p peer.ID, pmes *pb.M
 			// add the received addresses to our peerstore.
 			dht.peerstore.AddAddrs(pi.ID, pi.Addrs, peerstore.ProviderAddrTTL)
 		}
-		dht.ProviderManager.AddProvider(ctx, key, p)
+		err := dht.ProviderManager.AddProviderNonBlocking(ctx, key, p)
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	return nil, nil
diff --git a/providers/providers_manager.go b/providers/providers_manager.go
index 20927d2e8..3e31c79e5 100644
--- a/providers/providers_manager.go
+++ b/providers/providers_manager.go
@@ -3,6 +3,7 @@ package providers
 import (
 	"context"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"strings"
 	"time"
@@ -19,6 +20,11 @@ import (
 	base32 "github.com/multiformats/go-base32"
 )
 
+var (
+	ErrWouldBlock = errors.New("provide would block")
+	ErrClosing    = errors.New("provider manager is closing")
+)
+
 // ProvidersKeyPrefix is the prefix/namespace for ALL provider record
 // keys stored in the data store.
 const ProvidersKeyPrefix = "/providers/"
@@ -29,6 +35,9 @@ var defaultCleanupInterval = time.Hour
 var lruCacheSize = 256
 var batchBufferSize = 256
 var log = logging.Logger("providers")
+var defaultProvideBufferSize = 256
+var defaultGetProvidersBufferSize = 16
+var defaultGetProvidersNonBlockingBufferSize = defaultGetProvidersBufferSize / 4
 
 // ProviderManager adds and pulls providers out of the datastore,
 // caching them in between
@@ -38,9 +47,10 @@ type ProviderManager struct {
 	cache  lru.LRUCache
 	dstore *autobatch.Datastore
 
-	newprovs chan *addProv
-	getprovs chan *getProv
-	proc     goprocess.Process
+	nonBlocking bool
+	newprovs    chan *addProv
+	getprovs    chan *getProv
+	proc        goprocess.Process
 
 	cleanupInterval time.Duration
 }
@@ -75,9 +85,19 @@ func Cache(c lru.LRUCache) Option {
 	}
 }
 
+// NonBlockingProvide causes the provider manager to drop inbound provides when the queue is full
+// instead of blocking.
+func NonBlockingProvide(nonBlocking bool) Option {
+	return func(pm *ProviderManager) error {
+		pm.nonBlocking = nonBlocking
+		return nil
+	}
+}
+
 type addProv struct {
-	key []byte
-	val peer.ID
+	key  []byte
+	val  peer.ID
+	resp chan error
 }
 
 type getProv struct {
@@ -88,8 +108,9 @@ type getProv struct {
 // NewProviderManager constructor
 func NewProviderManager(ctx context.Context, local peer.ID, dstore ds.Batching, opts ...Option) (*ProviderManager, error) {
 	pm := new(ProviderManager)
-	pm.getprovs = make(chan *getProv)
-	pm.newprovs = make(chan *addProv)
+	pm.nonBlocking = true
+	pm.getprovs = make(chan *getProv, defaultGetProvidersBufferSize)
+	pm.newprovs = make(chan *addProv, defaultProvideBufferSize)
 	pm.dstore = autobatch.NewAutoBatching(dstore, batchBufferSize)
 	cache, err := lru.NewLRU(lruCacheSize, nil)
 	if err != nil {
@@ -134,6 +155,9 @@ func (pm *ProviderManager) run(proc goprocess.Process) {
 		select {
 		case np := <-pm.newprovs:
 			err := pm.addProv(np.key, np.val)
+			if np.resp != nil {
+				np.resp <- err
+			}
 			if err != nil {
 				log.Error("error adding new providers: ", err)
 				continue
@@ -213,15 +237,50 @@
 	}
 }
 
-// AddProvider adds a provider
-func (pm *ProviderManager) AddProvider(ctx context.Context, k []byte, val peer.ID) {
+// AddProviderNonBlocking adds a provider without waiting for it to be processed; when the manager is non-blocking and the provide queue is full, it returns ErrWouldBlock.
+func (pm *ProviderManager) AddProviderNonBlocking(ctx context.Context, k []byte, val peer.ID) error {
 	prov := &addProv{
 		key: k,
 		val: val,
 	}
+	if pm.nonBlocking {
+		select {
+		case pm.newprovs <- prov:
+		default:
+			return ErrWouldBlock
+		}
+	} else {
+		select {
+		case pm.newprovs <- prov:
+		case <-pm.proc.Closing():
+			return ErrClosing
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
+	return nil
+}
+
+func (pm *ProviderManager) AddProvider(ctx context.Context, k []byte, val peer.ID) error {
+	prov := &addProv{
+		key:  k,
+		val:  val,
+		resp: make(chan error, 1),
+	}
 	select {
 	case pm.newprovs <- prov:
+	case <-pm.proc.Closing():
+		return ErrClosing
 	case <-ctx.Done():
+		return ctx.Err()
+	}
+	select {
+	case err := <-prov.resp:
+		return err
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-pm.proc.Closing():
+		return ErrClosing
 	}
 }
 
@@ -264,8 +323,12 @@ func (pm *ProviderManager) GetProviders(ctx context.Context, k []byte) []peer.ID
 	case <-ctx.Done():
 		return nil
 	case pm.getprovs <- gp:
+	case <-pm.proc.Closing():
+		return nil
 	}
 	select {
+	case <-pm.proc.Closing():
+		return nil
 	case <-ctx.Done():
 		return nil
 	case peers := <-gp.resp:
@@ -273,6 +336,38 @@ func (pm *ProviderManager) GetProviders(ctx context.Context, k []byte) []peer.ID
 	}
 }
 
+// GetProvidersNonBlocking returns the set of providers for the given key. If the "get providers"
+// queue is full, it returns immediately.
+//
+// This method _does not_ copy the set. Do not modify it.
+func (pm *ProviderManager) GetProvidersNonBlocking(ctx context.Context, k []byte) ([]peer.ID, error) {
+	// If we're "busy", don't even try. This is clearly racy, but it's mostly an "optimistic"
+	// check anyway and it should stabilize pretty quickly when we're under load.
+	//
+	// This helps leave some space for peers that actually need responses.
+	if len(pm.getprovs) > defaultGetProvidersNonBlockingBufferSize {
+		return nil, ErrWouldBlock
+	}
+
+	gp := &getProv{
+		key:  k,
+		resp: make(chan []peer.ID, 1), // buffered to prevent sender from blocking
+	}
+	select {
+	case pm.getprovs <- gp:
+	default:
+		return nil, ErrWouldBlock
+	}
+	select {
+	case <-pm.proc.Closing():
+		return nil, ErrClosing
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case peers := <-gp.resp:
+		return peers, nil
+	}
+}
+
 func (pm *ProviderManager) getProvidersForKey(k []byte) ([]peer.ID, error) {
 	pset, err := pm.getProviderSetForKey(k)
 	if err != nil {
diff --git a/providers/providers_manager_test.go b/providers/providers_manager_test.go
index 11a9a0e5b..0f1d116cd 100644
--- a/providers/providers_manager_test.go
+++ b/providers/providers_manager_test.go
@@ -9,6 +9,7 @@ import (
 	"time"
 
 	"github.com/libp2p/go-libp2p-core/peer"
+	"github.com/stretchr/testify/require"
 
 	mh "github.com/multiformats/go-multihash"
 
@@ -31,7 +32,7 @@ func TestProviderManager(t *testing.T) {
 		t.Fatal(err)
 	}
 	a := u.Hash([]byte("test"))
-	p.AddProvider(ctx, a, peer.ID("testingprovider"))
+	require.NoError(t, p.AddProvider(ctx, a, peer.ID("testingprovider")))
 
 	// Not cached
 	// TODO verify that cache is empty
@@ -47,8 +48,8 @@ func TestProviderManager(t *testing.T) {
 		t.Fatal("Could not retrieve provider.")
 	}
 
-	p.AddProvider(ctx, a, peer.ID("testingprovider2"))
-	p.AddProvider(ctx, a, peer.ID("testingprovider3"))
+	require.NoError(t, p.AddProvider(ctx, a, peer.ID("testingprovider2")))
+	require.NoError(t, p.AddProvider(ctx, a, peer.ID("testingprovider3")))
 	// TODO verify that cache is already up to date
 	resp = p.GetProviders(ctx, a)
 	if len(resp) != 3 {
@@ -78,7 +79,7 @@ func TestProvidersDatastore(t *testing.T) {
 	for i := 0; i < 100; i++ {
 		h := u.Hash([]byte(fmt.Sprint(i)))
 		mhs = append(mhs, h)
-		p.AddProvider(ctx, h, friend)
+		require.NoError(t, p.AddProvider(ctx, h, friend))
 	}
 
 	for _, c := range mhs {
@@ -165,15 +166,15 @@ func TestProvidesExpire(t *testing.T) {
 	}
 
 	for _, h := range mhs[:5] {
-		p.AddProvider(ctx, h, peers[0])
-		p.AddProvider(ctx, h, peers[1])
+		require.NoError(t, p.AddProvider(ctx, h, peers[0]))
+		require.NoError(t, p.AddProvider(ctx, h, peers[1]))
 	}
 
 	time.Sleep(time.Second / 4)
 
 	for _, h := range mhs[5:] {
-		p.AddProvider(ctx, h, peers[0])
-		p.AddProvider(ctx, h, peers[1])
+		require.NoError(t, p.AddProvider(ctx, h, peers[0]))
+		require.NoError(t, p.AddProvider(ctx, h, peers[1]))
 	}
 
 	for _, h := range mhs {
@@ -271,7 +272,7 @@ func TestLargeProvidersSet(t *testing.T) {
 		h := u.Hash([]byte(fmt.Sprint(i)))
 		mhs = append(mhs, h)
 		for _, pid := range peers {
-			p.AddProvider(ctx, h, pid)
+			require.NoError(t, p.AddProvider(ctx, h, pid))
 		}
 	}
 
@@ -296,16 +297,14 @@ func TestUponCacheMissProvidersAreReadFromDatastore(t *testing.T) {
 	h1 := u.Hash([]byte("1"))
 	h2 := u.Hash([]byte("2"))
 	pm, err := NewProviderManager(ctx, p1, dssync.MutexWrap(ds.NewMapDatastore()))
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	// add provider
-	pm.AddProvider(ctx, h1, p1)
+	require.NoError(t, pm.AddProvider(ctx, h1, p1))
 	// make the cached provider for h1 go to datastore
-	pm.AddProvider(ctx, h2, p1)
+	require.NoError(t, pm.AddProvider(ctx, h2, p1))
 	// now just offloaded record should be brought back and joined with p2
-	pm.AddProvider(ctx, h1, p2)
+	require.NoError(t, pm.AddProvider(ctx, h1, p2))
 
 	h1Provs := pm.GetProviders(ctx, h1)
 	if len(h1Provs) != 2 {
@@ -325,11 +324,11 @@ func TestWriteUpdatesCache(t *testing.T) {
 	}
 
 	// add provider
-	pm.AddProvider(ctx, h1, p1)
+	require.NoError(t, pm.AddProvider(ctx, h1, p1))
 	// force into the cache
-	pm.GetProviders(ctx, h1)
+	_ = pm.GetProviders(ctx, h1)
 	// add a second provider
-	pm.AddProvider(ctx, h1, p2)
+	require.NoError(t, pm.AddProvider(ctx, h1, p2))
 
 	c1Provs := pm.GetProviders(ctx, h1)
 	if len(c1Provs) != 2 {
diff --git a/routing.go b/routing.go
index 7793bebb4..9bfc192de 100644
--- a/routing.go
+++ b/routing.go
@@ -403,9 +403,15 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err
 	logger.Debugw("providing", "cid", key, "mh", internal.LoggableProviderRecordBytes(keyMH))
 
 	// add self locally
-	dht.ProviderManager.AddProvider(ctx, keyMH, dht.self)
+	err = dht.ProviderManager.AddProvider(ctx, keyMH, dht.self)
 	if !brdcst {
-		return nil
+		// If we're not broadcasting, return immediately. But also return the error because,
+		// if something went wrong, we basically failed to do anything.
+		return err
+	}
+	if err != nil {
+		// Otherwise, "local" provides are "best effort".
+		logger.Debugw("local provide failed", "error", err)
 	}
 
 	closerCtx := ctx
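Taken together, the providers changes in this patch give callers three paths: a blocking AddProvider that reports the outcome, an AddProviderNonBlocking that may drop the record under load, and a GetProvidersNonBlocking read that bails out with ErrWouldBlock instead of queueing. Below is a minimal caller-side sketch; only the providers-package API comes from the diff above, while the in-memory datastore, hard-coded peer IDs, and key bytes are illustrative placeholders.

package main

import (
	"context"
	"errors"
	"fmt"

	ds "github.com/ipfs/go-datastore"
	dssync "github.com/ipfs/go-datastore/sync"
	"github.com/libp2p/go-libp2p-core/peer"
	"github.com/libp2p/go-libp2p-kad-dht/providers"
)

func main() {
	ctx := context.Background()

	// Illustrative setup: an in-memory datastore and placeholder peer IDs,
	// mirroring what the package's own tests do.
	self := peer.ID("local-peer")
	dstore := dssync.MutexWrap(ds.NewMapDatastore())

	pm, err := providers.NewProviderManager(ctx, self, dstore,
		// Redundant with the default set in NewProviderManager, but explicit:
		// drop inbound provides instead of blocking when the queue is full.
		providers.NonBlockingProvide(true))
	if err != nil {
		panic(err)
	}

	key := []byte("example-provider-record-key") // normally a multihash

	// Blocking add: waits until the manager has processed the record,
	// the context is done, or the manager is closing.
	if err := pm.AddProvider(ctx, key, self); err != nil {
		fmt.Println("provide failed:", err)
	}

	// Non-blocking add: may be dropped when the provide queue is full.
	if err := pm.AddProviderNonBlocking(ctx, key, peer.ID("remote-peer")); err != nil {
		if errors.Is(err, providers.ErrWouldBlock) {
			fmt.Println("provide queue full, record dropped")
		}
	}

	// Non-blocking read: returns ErrWouldBlock rather than queueing behind
	// other requests; the returned slice must not be modified.
	if provs, err := pm.GetProvidersNonBlocking(ctx, key); err == nil {
		fmt.Println("providers:", provs)
	}
}

The sketch assumes the constructor signature and option shown in the diff; in the DHT itself, the blocking AddProvider is used for "local" provides (routing.go, fullrt/dht.go) while the non-blocking variants back the network handlers, which prefer dropping work over stalling under load.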