diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index eb4746b1..a235287a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -59,7 +59,7 @@ import ( "github.com/filecoin-project/go-statemachine" - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) ``` diff --git a/README.md b/README.md index ff5738ac..3ab3f91f 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ This module encapsulates protocols for exchanging piece data between storage cli **Requires go 1.13** -Install the module in your package or app with `go get "github.com/filecoin-project/go-data-transfer/datatransfer"` +Install the module in your package or app with `go get "github.com/filecoin-project/go-data-transfer/v2/datatransfer"` ### Initialize a data transfer module @@ -31,8 +31,8 @@ Install the module in your package or app with `go get "github.com/filecoin-proj import ( gsimpl "github.com/ipfs/go-graphsync/impl" - datatransfer "github.com/filecoin-project/go-data-transfer/impl" - gstransport "github.com/filecoin-project/go-data-transfer/transport/graphsync" + datatransfer "github.com/filecoin-project/go-data-transfer/v2/impl" + gstransport "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync" "github.com/libp2p/go-libp2p-core/host" ) @@ -85,7 +85,7 @@ func (vl *myValidator) ValidatePush( sender peer.ID, voucher datatransfer.Voucher, baseCid cid.Cid, - selector ipld.Node) error { + selector datamodel.Node) error { v := voucher.(*myVoucher) if v.data == "" || v.data != "validpush" { @@ -99,7 +99,7 @@ func (vl *myValidator) ValidatePull( receiver peer.ID, voucher datatransfer.Voucher, baseCid cid.Cid, - selector ipld.Node) error { + selector datamodel.Node) error { v := voucher.(*myVoucher) if v.data == "" || v.data != "validpull" { @@ -135,7 +135,7 @@ must be sent with the request. Using the trivial examples above: For more detail, please see the [unit tests](https://github.com/filecoin-project/go-data-transfer/blob/master/impl/impl_test.go). ### Open a Push or Pull Request -For a push or pull request, provide a context, a `datatransfer.Voucher`, a host recipient `peer.ID`, a baseCID `cid.CID` and a selector `ipld.Node`. These +For a push or pull request, provide a context, a `datatransfer.Voucher`, a host recipient `peer.ID`, a baseCID `cid.CID` and a selector `datamodel.Node`. 
These calls return a `datatransfer.ChannelID` and any error: ```go channelID, err := dtm.OpenPullDataChannel(ctx, recipient, voucher, baseCid, selector) diff --git a/benchmarks/benchmark_test.go b/benchmarks/benchmark_test.go index 78d53f00..5216f534 100644 --- a/benchmarks/benchmark_test.go +++ b/benchmarks/benchmark_test.go @@ -24,16 +24,14 @@ import ( "github.com/ipfs/go-merkledag" "github.com/ipfs/go-unixfs/importer/balanced" ihelper "github.com/ipfs/go-unixfs/importer/helpers" - basicnode "github.com/ipld/go-ipld-prime/node/basic" - ipldselector "github.com/ipld/go-ipld-prime/traversal/selector" - "github.com/ipld/go-ipld-prime/traversal/selector/builder" + selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/benchmarks/testinstance" - tn "github.com/filecoin-project/go-data-transfer/benchmarks/testnet" - "github.com/filecoin-project/go-data-transfer/testutil" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/benchmarks/testinstance" + tn "github.com/filecoin-project/go-data-transfer/v2/benchmarks/testnet" + "github.com/filecoin-project/go-data-transfer/v2/testutil" ) const stdBlockSize = 8000 @@ -77,10 +75,6 @@ func p2pStrestTest(ctx context.Context, b *testing.B, numfiles int, df distFunc, thisCids := df(ctx, b, instances[:1]) allCids = append(allCids, thisCids...) } - ssb := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any) - - allSelector := ssb.ExploreRecursive(ipldselector.RecursionLimitNone(), - ssb.ExploreAll(ssb.ExploreRecursiveEdge())).Node() runtime.GC() b.ResetTimer() @@ -105,7 +99,7 @@ func p2pStrestTest(ctx context.Context, b *testing.B, numfiles int, df distFunc, timer := time.NewTimer(30 * time.Second) start := time.Now() for j := 0; j < numfiles; j++ { - _, err := pusher.Manager.OpenPushDataChannel(ctx, receiver.Peer, testutil.NewFakeDTType(), allCids[j], allSelector) + _, err := pusher.Manager.OpenPushDataChannel(ctx, receiver.Peer, testutil.NewTestTypedVoucher(), allCids[j], selectorparse.CommonSelector_ExploreAllRecursively) if err != nil { b.Fatalf("received error on request: %s", err.Error()) } diff --git a/benchmarks/testinstance/testinstance.go b/benchmarks/testinstance/testinstance.go index 6732610e..685cd264 100644 --- a/benchmarks/testinstance/testinstance.go +++ b/benchmarks/testinstance/testinstance.go @@ -18,12 +18,12 @@ import ( "github.com/ipld/go-ipld-prime" peer "github.com/libp2p/go-libp2p-core/peer" - datatransfer "github.com/filecoin-project/go-data-transfer" - tn "github.com/filecoin-project/go-data-transfer/benchmarks/testnet" - dtimpl "github.com/filecoin-project/go-data-transfer/impl" - dtnet "github.com/filecoin-project/go-data-transfer/network" - "github.com/filecoin-project/go-data-transfer/testutil" - gstransport "github.com/filecoin-project/go-data-transfer/transport/graphsync" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + tn "github.com/filecoin-project/go-data-transfer/v2/benchmarks/testnet" + dtimpl "github.com/filecoin-project/go-data-transfer/v2/impl" + "github.com/filecoin-project/go-data-transfer/v2/testutil" + gstransport "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync" + dtnet "github.com/filecoin-project/go-data-transfer/v2/transport/helpers/network" ) // TempDirGenerator is any interface that can 
generate temporary directories @@ -164,8 +164,8 @@ func NewInstance(ctx context.Context, net tn.Network, tempDir string, diskBasedD linkSystem := storeutil.LinkSystemForBlockstore(bstore) gs := gsimpl.New(ctx, gsNet, linkSystem, gsimpl.RejectAllRequestsByDefault()) - transport := gstransport.NewTransport(p, gs) - dt, err := dtimpl.NewDataTransfer(namespace.Wrap(dstore, datastore.NewKey("/data-transfers/transfers")), dtNet, transport) + transport := gstransport.NewTransport(gs, dtNet) + dt, err := dtimpl.NewDataTransfer(namespace.Wrap(dstore, datastore.NewKey("/data-transfers/transfers")), p, transport) if err != nil { return Instance{}, err } @@ -188,8 +188,7 @@ func NewInstance(ctx context.Context, net tn.Network, tempDir string, diskBasedD sv := testutil.NewStubbedValidator() sv.StubSuccessPull() sv.StubSuccessPush() - dt.RegisterVoucherType(testutil.NewFakeDTType(), sv) - dt.RegisterVoucherResultType(testutil.NewFakeDTType()) + dt.RegisterVoucherType(testutil.TestVoucherType, sv) return Instance{ Adapter: dtNet, Peer: p, diff --git a/benchmarks/testnet/interface.go b/benchmarks/testnet/interface.go index 8a53d040..72858184 100644 --- a/benchmarks/testnet/interface.go +++ b/benchmarks/testnet/interface.go @@ -4,7 +4,7 @@ import ( gsnet "github.com/ipfs/go-graphsync/network" "github.com/libp2p/go-libp2p-core/peer" - dtnet "github.com/filecoin-project/go-data-transfer/network" + dtnet "github.com/filecoin-project/go-data-transfer/v2/transport/helpers/network" ) // Network is an interface for generating graphsync network interfaces diff --git a/benchmarks/testnet/peernet.go b/benchmarks/testnet/peernet.go index 94e469eb..ba83ff09 100644 --- a/benchmarks/testnet/peernet.go +++ b/benchmarks/testnet/peernet.go @@ -7,7 +7,7 @@ import ( "github.com/libp2p/go-libp2p-core/peer" mockpeernet "github.com/libp2p/go-libp2p/p2p/net/mock" - dtnet "github.com/filecoin-project/go-data-transfer/network" + dtnet "github.com/filecoin-project/go-data-transfer/v2/transport/helpers/network" ) type peernet struct { diff --git a/channelmonitor/channelmonitor.go b/channelmonitor/channelmonitor.go index ed3b0a3c..2eb0b5be 100644 --- a/channelmonitor/channelmonitor.go +++ b/channelmonitor/channelmonitor.go @@ -11,8 +11,8 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channels" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/channels" ) var log = logging.Logger("dt-chanmon") @@ -21,7 +21,6 @@ type monitorAPI interface { SubscribeToEvents(subscriber datatransfer.Subscriber) datatransfer.Unsubscribe RestartDataTransferChannel(ctx context.Context, chid datatransfer.ChannelID) error CloseDataTransferChannelWithError(ctx context.Context, chid datatransfer.ChannelID, cherr error) error - ConnectTo(context.Context, peer.ID) error PeerID() peer.ID } @@ -84,18 +83,8 @@ func checkConfig(cfg *Config) { } } -// AddPushChannel adds a push channel to the channel monitor -func (m *Monitor) AddPushChannel(chid datatransfer.ChannelID) *monitoredChannel { - return m.addChannel(chid, true) -} - -// AddPullChannel adds a pull channel to the channel monitor -func (m *Monitor) AddPullChannel(chid datatransfer.ChannelID) *monitoredChannel { - return m.addChannel(chid, false) -} - -// addChannel adds a channel to the channel monitor -func (m *Monitor) addChannel(chid datatransfer.ChannelID, isPush bool) *monitoredChannel { +// AddChannel 
adds a channel to the channel monitor +func (m *Monitor) AddChannel(chid datatransfer.ChannelID, isPull bool) *monitoredChannel { if !m.enabled() { return nil } @@ -106,7 +95,7 @@ func (m *Monitor) addChannel(chid datatransfer.ChannelID, isPush bool) *monitore // Check if there is already a monitor for this channel if _, ok := m.channels[chid]; ok { tp := "push" - if !isPush { + if isPull { tp = "pull" } log.Warnf("ignoring add %s channel %s: %s channel with that id already exists", @@ -454,22 +443,11 @@ func (mc *monitoredChannel) doRestartChannel() error { } func (mc *monitoredChannel) sendRestartMessage(restartCount int) error { - // Establish a connection to the peer, in case the connection went down. - // Note that at the networking layer there is logic to retry if a network - // connection cannot be established, so this may take some time. p := mc.chid.OtherParty(mc.mgr.PeerID()) - log.Debugf("%s: re-establishing connection to %s", mc.chid, p) - start := time.Now() - err := mc.mgr.ConnectTo(mc.ctx, p) - if err != nil { - return xerrors.Errorf("%s: failed to reconnect to peer %s after %s: %w", - mc.chid, p, time.Since(start), err) - } - log.Debugf("%s: re-established connection to %s in %s", mc.chid, p, time.Since(start)) // Send a restart message for the channel log.Debugf("%s: sending restart message to %s (%d consecutive restarts)", mc.chid, p, restartCount) - err = mc.mgr.RestartDataTransferChannel(mc.ctx, mc.chid) + err := mc.mgr.RestartDataTransferChannel(mc.ctx, mc.chid) if err != nil { return xerrors.Errorf("%s: failed to send restart message to %s: %w", mc.chid, p, err) } diff --git a/channelmonitor/channelmonitor_test.go b/channelmonitor/channelmonitor_test.go index bb4857f0..aa0acc32 100644 --- a/channelmonitor/channelmonitor_test.go +++ b/channelmonitor/channelmonitor_test.go @@ -7,13 +7,12 @@ import ( "testing" "time" - "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/testutil" ) var ch1 = datatransfer.ChannelID{ @@ -30,9 +29,6 @@ func TestChannelMonitorAutoRestart(t *testing.T) { } testCases := []testCase{{ name: "attempt restart", - }, { - name: "fail to reconnect to peer", - errReconnect: true, }, { name: "fail to send restart message", errSendRestartMsg: true, @@ -41,8 +37,8 @@ func TestChannelMonitorAutoRestart(t *testing.T) { runTest := func(name string, isPush bool) { for _, tc := range testCases { t.Run(name+": "+tc.name, func(t *testing.T) { - ch := &mockChannelState{chid: ch1} - mockAPI := newMockMonitorAPI(ch, tc.errReconnect, tc.errSendRestartMsg) + ch := testutil.NewMockChannelState(testutil.MockChannelStateParams{ChannelID: ch1}) + mockAPI := newMockMonitorAPI(ch, tc.errSendRestartMsg) triggerErrorEvent := func() { if isPush { @@ -60,9 +56,9 @@ func TestChannelMonitorAutoRestart(t *testing.T) { var mch *monitoredChannel if isPush { - mch = m.AddPushChannel(ch1) + mch = m.AddChannel(ch1, false) } else { - mch = m.AddPullChannel(ch1) + mch = m.AddChannel(ch1, true) } // Simulate the responder sending Accept @@ -115,8 +111,8 @@ func TestChannelMonitorAutoRestart(t *testing.T) { func TestChannelMonitorMaxConsecutiveRestarts(t *testing.T) { runTest := func(name string, isPush bool) { t.Run(name, func(t *testing.T) { - ch := &mockChannelState{chid: ch1} - mockAPI := 
newMockMonitorAPI(ch, false, false) + ch := testutil.NewMockChannelState(testutil.MockChannelStateParams{ChannelID: ch1}) + mockAPI := newMockMonitorAPI(ch, false) triggerErrorEvent := func() { if isPush { @@ -135,12 +131,12 @@ func TestChannelMonitorMaxConsecutiveRestarts(t *testing.T) { var mch *monitoredChannel if isPush { - mch = m.AddPushChannel(ch1) + mch = m.AddChannel(ch1, false) mockAPI.dataQueued(10) mockAPI.dataSent(5) } else { - mch = m.AddPullChannel(ch1) + mch = m.AddChannel(ch1, true) mockAPI.dataReceived(5) } @@ -198,8 +194,8 @@ func awaitRestartComplete(mch *monitoredChannel) error { func TestChannelMonitorQueuedRestart(t *testing.T) { runTest := func(name string, isPush bool) { t.Run(name, func(t *testing.T) { - ch := &mockChannelState{chid: ch1} - mockAPI := newMockMonitorAPI(ch, false, false) + ch := testutil.NewMockChannelState(testutil.MockChannelStateParams{ChannelID: ch1}) + mockAPI := newMockMonitorAPI(ch, false) triggerErrorEvent := func() { if isPush { @@ -217,12 +213,12 @@ func TestChannelMonitorQueuedRestart(t *testing.T) { }) if isPush { - m.AddPushChannel(ch1) + m.AddChannel(ch1, false) mockAPI.dataQueued(10) mockAPI.dataSent(5) } else { - m.AddPullChannel(ch1) + m.AddChannel(ch1, true) mockAPI.dataReceived(5) } @@ -285,8 +281,8 @@ func TestChannelMonitorTimeouts(t *testing.T) { runTest := func(name string, isPush bool) { for _, tc := range testCases { t.Run(name+": "+tc.name, func(t *testing.T) { - ch := &mockChannelState{chid: ch1} - mockAPI := newMockMonitorAPI(ch, false, false) + ch := testutil.NewMockChannelState(testutil.MockChannelStateParams{ChannelID: ch1}) + mockAPI := newMockMonitorAPI(ch, false) verifyClosedAndShutdown := func(chCtx context.Context, timeout time.Duration) { mockAPI.verifyChannelClosed(t, true) @@ -311,10 +307,10 @@ func TestChannelMonitorTimeouts(t *testing.T) { var chCtx context.Context if isPush { - mch := m.AddPushChannel(ch1) + mch := m.AddChannel(ch1, false) chCtx = mch.ctx } else { - mch := m.AddPullChannel(ch1) + mch := m.AddChannel(ch1, true) chCtx = mch.ctx } @@ -370,8 +366,7 @@ func verifyChannelShutdown(t *testing.T, shutdownCtx context.Context) { } type mockMonitorAPI struct { - ch *mockChannelState - connectErrors bool + ch *testutil.MockChannelState restartErrors bool restartMessages chan struct{} closeErr chan error @@ -380,10 +375,9 @@ type mockMonitorAPI struct { subscribers map[int]datatransfer.Subscriber } -func newMockMonitorAPI(ch *mockChannelState, errOnReconnect, errOnRestart bool) *mockMonitorAPI { +func newMockMonitorAPI(ch *testutil.MockChannelState, errOnRestart bool) *mockMonitorAPI { return &mockMonitorAPI{ ch: ch, - connectErrors: errOnReconnect, restartErrors: errOnRestart, restartMessages: make(chan struct{}, 100), closeErr: make(chan error, 1), @@ -415,13 +409,6 @@ func (m *mockMonitorAPI) fireEvent(e datatransfer.Event, state datatransfer.Chan } } -func (m *mockMonitorAPI) ConnectTo(ctx context.Context, id peer.ID) error { - if m.connectErrors { - return xerrors.Errorf("connect err") - } - return nil -} - func (m *mockMonitorAPI) PeerID() peer.ID { return "p" } @@ -482,17 +469,17 @@ func (m *mockMonitorAPI) accept() { } func (m *mockMonitorAPI) dataQueued(n uint64) { - m.ch.queued = n + m.ch.SetQueued(n) m.fireEvent(datatransfer.Event{Code: datatransfer.DataQueued}, m.ch) } func (m *mockMonitorAPI) dataSent(n uint64) { - m.ch.sent = n + m.ch.SetSent(n) m.fireEvent(datatransfer.Event{Code: datatransfer.DataSent}, m.ch) } func (m *mockMonitorAPI) dataReceived(n uint64) { - m.ch.received = n + 
m.ch.SetReceived(n) m.fireEvent(datatransfer.Event{Code: datatransfer.DataReceived}, m.ch) } @@ -501,7 +488,7 @@ func (m *mockMonitorAPI) finishTransfer() { } func (m *mockMonitorAPI) completed() { - m.ch.complete = true + m.ch.SetComplete(true) m.fireEvent(datatransfer.Event{Code: datatransfer.Complete}, m.ch) } @@ -512,112 +499,3 @@ func (m *mockMonitorAPI) sendDataErrorEvent() { func (m *mockMonitorAPI) receiveDataErrorEvent() { m.fireEvent(datatransfer.Event{Code: datatransfer.ReceiveDataError}, m.ch) } - -type mockChannelState struct { - chid datatransfer.ChannelID - queued uint64 - sent uint64 - received uint64 - complete bool -} - -var _ datatransfer.ChannelState = (*mockChannelState)(nil) - -func (m *mockChannelState) Queued() uint64 { - return m.queued -} - -func (m *mockChannelState) Sent() uint64 { - return m.sent -} - -func (m *mockChannelState) Received() uint64 { - return m.received -} - -func (m *mockChannelState) ChannelID() datatransfer.ChannelID { - return m.chid -} - -func (m *mockChannelState) Status() datatransfer.Status { - if m.complete { - return datatransfer.Completed - } - return datatransfer.Ongoing -} - -func (m *mockChannelState) TransferID() datatransfer.TransferID { - panic("implement me") -} - -func (m *mockChannelState) BaseCID() cid.Cid { - panic("implement me") -} - -func (m *mockChannelState) Selector() ipld.Node { - panic("implement me") -} - -func (m *mockChannelState) Voucher() datatransfer.Voucher { - panic("implement me") -} - -func (m *mockChannelState) Sender() peer.ID { - panic("implement me") -} - -func (m *mockChannelState) Recipient() peer.ID { - panic("implement me") -} - -func (m *mockChannelState) TotalSize() uint64 { - panic("implement me") -} - -func (m *mockChannelState) IsPull() bool { - panic("implement me") -} - -func (m *mockChannelState) OtherPeer() peer.ID { - panic("implement me") -} - -func (m *mockChannelState) SelfPeer() peer.ID { - panic("implement me") -} - -func (m *mockChannelState) Message() string { - panic("implement me") -} - -func (m *mockChannelState) Vouchers() []datatransfer.Voucher { - panic("implement me") -} - -func (m *mockChannelState) VoucherResults() []datatransfer.VoucherResult { - panic("implement me") -} - -func (m *mockChannelState) LastVoucher() datatransfer.Voucher { - panic("implement me") -} - -func (m *mockChannelState) LastVoucherResult() datatransfer.VoucherResult { - panic("implement me") -} - -func (m *mockChannelState) Stages() *datatransfer.ChannelStages { - panic("implement me") -} - -func (m *mockChannelState) ReceivedCidsTotal() int64 { - panic("implement me") -} - -func (m *mockChannelState) QueuedCidsTotal() int64 { - panic("implement me") -} - -func (m *mockChannelState) SentCidsTotal() int64 { - panic("implement me") -} diff --git a/channels/block_index_cache.go b/channels/block_index_cache.go deleted file mode 100644 index 490f77fd..00000000 --- a/channels/block_index_cache.go +++ /dev/null @@ -1,63 +0,0 @@ -package channels - -import ( - "sync" - "sync/atomic" - - datatransfer "github.com/filecoin-project/go-data-transfer" -) - -type readOriginalFn func(datatransfer.ChannelID) (int64, error) - -type blockIndexKey struct { - evt datatransfer.EventCode - chid datatransfer.ChannelID -} -type blockIndexCache struct { - lk sync.RWMutex - values map[blockIndexKey]*int64 -} - -func newBlockIndexCache() *blockIndexCache { - return &blockIndexCache{ - values: make(map[blockIndexKey]*int64), - } -} - -func (bic *blockIndexCache) getValue(evt datatransfer.EventCode, chid datatransfer.ChannelID, 
readFromOriginal readOriginalFn) (*int64, error) { - idxKey := blockIndexKey{evt, chid} - bic.lk.RLock() - value := bic.values[idxKey] - bic.lk.RUnlock() - if value != nil { - return value, nil - } - bic.lk.Lock() - defer bic.lk.Unlock() - value = bic.values[idxKey] - if value != nil { - return value, nil - } - newValue, err := readFromOriginal(chid) - if err != nil { - return nil, err - } - bic.values[idxKey] = &newValue - return &newValue, nil -} - -func (bic *blockIndexCache) updateIfGreater(evt datatransfer.EventCode, chid datatransfer.ChannelID, newIndex int64, readFromOriginal readOriginalFn) (bool, error) { - value, err := bic.getValue(evt, chid, readFromOriginal) - if err != nil { - return false, err - } - for { - currentIndex := atomic.LoadInt64(value) - if newIndex <= currentIndex { - return false, nil - } - if atomic.CompareAndSwapInt64(value, currentIndex, newIndex) { - return true, nil - } - } -} diff --git a/channels/channel_state.go b/channels/channel_state.go index bf89b85a..efabd74d 100644 --- a/channels/channel_state.go +++ b/channels/channel_state.go @@ -1,195 +1,161 @@ package channels import ( - "bytes" - "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/codec/dagcbor" - basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/ipld/go-ipld-prime/datamodel" peer "github.com/libp2p/go-libp2p-core/peer" - cbg "github.com/whyrusleeping/cbor-gen" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channels/internal" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/channels/internal" ) // channelState is immutable channel data plus mutable state type channelState struct { - // peerId of the manager peer - selfPeer peer.ID - // an identifier for this channel shared by request and responder, set by requester through protocol - transferID datatransfer.TransferID - // base CID for the piece being transferred - baseCid cid.Cid - // portion of Piece to return, specified by an IPLD selector - selector *cbg.Deferred - // the party that is sending the data (not who initiated the request) - sender peer.ID - // the party that is receiving the data (not who initiated the request) - recipient peer.ID - // expected amount of data to be transferred - totalSize uint64 - // current status of this deal - status datatransfer.Status - // isPull indicates if this is a push or pull request - isPull bool - // total bytes read from this node and queued for sending (0 if receiver) - queued uint64 - // total bytes sent from this node (0 if receiver) - sent uint64 - // total bytes received by this node (0 if sender) - received uint64 - // number of blocks that have been received, including blocks that are - // present in more than one place in the DAG - receivedBlocksTotal int64 - // Number of blocks that have been queued, including blocks that are - // present in more than one place in the DAG - queuedBlocksTotal int64 - // Number of blocks that have been sent, including blocks that are - // present in more than one place in the DAG - sentBlocksTotal int64 - // more informative status on a channel - message string - // additional vouchers - vouchers []internal.EncodedVoucher - // additional voucherResults - voucherResults []internal.EncodedVoucherResult - voucherResultDecoder DecoderByTypeFunc - voucherDecoder DecoderByTypeFunc - - // stages tracks the timeline of events related to a data transfer, for - // traceability 
purposes. - stages *datatransfer.ChannelStages + ic internal.ChannelState } // EmptyChannelState is the zero value for channel state, meaning not present var EmptyChannelState = channelState{} // Status is the current status of this channel -func (c channelState) Status() datatransfer.Status { return c.status } +func (c channelState) Status() datatransfer.Status { return c.ic.Status } // Received returns the number of bytes received -func (c channelState) Queued() uint64 { return c.queued } +func (c channelState) Queued() uint64 { return c.ic.Queued } // Sent returns the number of bytes sent -func (c channelState) Sent() uint64 { return c.sent } +func (c channelState) Sent() uint64 { return c.ic.Sent } // Received returns the number of bytes received -func (c channelState) Received() uint64 { return c.received } +func (c channelState) Received() uint64 { return c.ic.Received } // TransferID returns the transfer id for this channel -func (c channelState) TransferID() datatransfer.TransferID { return c.transferID } +func (c channelState) TransferID() datatransfer.TransferID { return c.ic.TransferID } // BaseCID returns the CID that is at the root of this data transfer -func (c channelState) BaseCID() cid.Cid { return c.baseCid } +func (c channelState) BaseCID() cid.Cid { return c.ic.BaseCid } // Selector returns the IPLD selector for this data transfer (represented as // an IPLD node) -func (c channelState) Selector() ipld.Node { - builder := basicnode.Prototype.Any.NewBuilder() - reader := bytes.NewReader(c.selector.Raw) - err := dagcbor.Decode(builder, reader) - if err != nil { - log.Error(err) - } - return builder.Build() +func (c channelState) Selector() datamodel.Node { + return c.ic.Selector.Node } // Voucher returns the voucher for this data transfer -func (c channelState) Voucher() datatransfer.Voucher { - if len(c.vouchers) == 0 { - return nil +func (c channelState) Voucher() datatransfer.TypedVoucher { + if len(c.ic.Vouchers) == 0 { + return datatransfer.TypedVoucher{} } - decoder, _ := c.voucherDecoder(c.vouchers[0].Type) - encodable, _ := decoder.DecodeFromCbor(c.vouchers[0].Voucher.Raw) - return encodable.(datatransfer.Voucher) + ev := c.ic.Vouchers[0] + return datatransfer.TypedVoucher{Voucher: ev.Voucher.Node, Type: ev.Type} } -// ReceivedCidsTotal returns the number of (non-unique) cids received so far -// on the channel - note that a block can exist in more than one place in the DAG -func (c channelState) ReceivedCidsTotal() int64 { - return c.receivedBlocksTotal +// ReceivedIndex returns the index, a transport specific identifier for "where" +// we are in receiving data for a transfer +func (c channelState) ReceivedIndex() datamodel.Node { + return c.ic.ReceivedIndex.Node } -// QueuedCidsTotal returns the number of (non-unique) cids queued so far -// on the channel - note that a block can exist in more than one place in the DAG -func (c channelState) QueuedCidsTotal() int64 { - return c.queuedBlocksTotal +// QueuedIndex returns the index, a transport specific identifier for "where" +// we are in queuing data for a transfer +func (c channelState) QueuedIndex() datamodel.Node { + return c.ic.QueuedIndex.Node } -// SentCidsTotal returns the number of (non-unique) cids sent so far -// on the channel - note that a block can exist in more than one place in the DAG -func (c channelState) SentCidsTotal() int64 { - return c.sentBlocksTotal +// SentIndex returns the index, a transport specific identifier for "where" +// we are in sending data for a transfer +func (c channelState) 
SentIndex() datamodel.Node { + return c.ic.SentIndex.Node } // Sender returns the peer id for the node that is sending data -func (c channelState) Sender() peer.ID { return c.sender } +func (c channelState) Sender() peer.ID { return c.ic.Sender } // Recipient returns the peer id for the node that is receiving data -func (c channelState) Recipient() peer.ID { return c.recipient } +func (c channelState) Recipient() peer.ID { return c.ic.Recipient } // TotalSize returns the total size for the data being transferred -func (c channelState) TotalSize() uint64 { return c.totalSize } +func (c channelState) TotalSize() uint64 { return c.ic.TotalSize } // IsPull returns whether this is a pull request based on who initiated it func (c channelState) IsPull() bool { - return c.isPull + return c.ic.Initiator == c.ic.Recipient } func (c channelState) ChannelID() datatransfer.ChannelID { - if c.isPull { - return datatransfer.ChannelID{ID: c.transferID, Initiator: c.recipient, Responder: c.sender} + if c.IsPull() { + return datatransfer.ChannelID{ID: c.ic.TransferID, Initiator: c.ic.Recipient, Responder: c.ic.Sender} } - return datatransfer.ChannelID{ID: c.transferID, Initiator: c.sender, Responder: c.recipient} + return datatransfer.ChannelID{ID: c.ic.TransferID, Initiator: c.ic.Sender, Responder: c.ic.Recipient} } func (c channelState) Message() string { - return c.message + return c.ic.Message } -func (c channelState) Vouchers() []datatransfer.Voucher { - vouchers := make([]datatransfer.Voucher, 0, len(c.vouchers)) - for _, encoded := range c.vouchers { - decoder, _ := c.voucherDecoder(encoded.Type) - encodable, _ := decoder.DecodeFromCbor(encoded.Voucher.Raw) - vouchers = append(vouchers, encodable.(datatransfer.Voucher)) +func (c channelState) Vouchers() []datatransfer.TypedVoucher { + vouchers := make([]datatransfer.TypedVoucher, 0, len(c.ic.Vouchers)) + for _, encoded := range c.ic.Vouchers { + vouchers = append(vouchers, datatransfer.TypedVoucher{Voucher: encoded.Voucher.Node, Type: encoded.Type}) } return vouchers } -func (c channelState) LastVoucher() datatransfer.Voucher { - decoder, _ := c.voucherDecoder(c.vouchers[len(c.vouchers)-1].Type) - encodable, _ := decoder.DecodeFromCbor(c.vouchers[len(c.vouchers)-1].Voucher.Raw) - return encodable.(datatransfer.Voucher) +func (c channelState) LastVoucher() datatransfer.TypedVoucher { + ev := c.ic.Vouchers[len(c.ic.Vouchers)-1] + + return datatransfer.TypedVoucher{Voucher: ev.Voucher.Node, Type: ev.Type} } -func (c channelState) LastVoucherResult() datatransfer.VoucherResult { - decoder, _ := c.voucherResultDecoder(c.voucherResults[len(c.voucherResults)-1].Type) - encodable, _ := decoder.DecodeFromCbor(c.voucherResults[len(c.voucherResults)-1].VoucherResult.Raw) - return encodable.(datatransfer.VoucherResult) +func (c channelState) LastVoucherResult() datatransfer.TypedVoucher { + evr := c.ic.VoucherResults[len(c.ic.VoucherResults)-1] + return datatransfer.TypedVoucher{Voucher: evr.VoucherResult.Node, Type: evr.Type} } -func (c channelState) VoucherResults() []datatransfer.VoucherResult { - voucherResults := make([]datatransfer.VoucherResult, 0, len(c.voucherResults)) - for _, encoded := range c.voucherResults { - decoder, _ := c.voucherResultDecoder(encoded.Type) - encodable, _ := decoder.DecodeFromCbor(encoded.VoucherResult.Raw) - voucherResults = append(voucherResults, encodable.(datatransfer.VoucherResult)) +func (c channelState) VoucherResults() []datatransfer.TypedVoucher { + voucherResults := make([]datatransfer.TypedVoucher, 0, 
len(c.ic.VoucherResults)) + for _, encoded := range c.ic.VoucherResults { + voucherResults = append(voucherResults, datatransfer.TypedVoucher{Voucher: encoded.VoucherResult.Node, Type: encoded.Type}) } return voucherResults } func (c channelState) SelfPeer() peer.ID { - return c.selfPeer + return c.ic.SelfPeer } func (c channelState) OtherPeer() peer.ID { - if c.sender == c.selfPeer { - return c.recipient + if c.ic.Sender == c.ic.SelfPeer { + return c.ic.Recipient } - return c.sender + return c.ic.Sender +} + +func (c channelState) DataLimit() uint64 { + return c.ic.DataLimit +} + +func (c channelState) RequiresFinalization() bool { + return c.ic.RequiresFinalization +} + +func (c channelState) InitiatorPaused() bool { + return c.ic.InitiatorPaused +} + +func (c channelState) ResponderPaused() bool { + return c.ic.ResponderPaused || c.ic.Status == datatransfer.Finalizing +} + +func (c channelState) BothPaused() bool { + return c.InitiatorPaused() && c.ResponderPaused() +} + +func (c channelState) SelfPaused() bool { + if c.ic.SelfPeer == c.ic.Initiator { + return c.InitiatorPaused() + } + return c.ResponderPaused() } // Stages returns the current ChannelStages object, or an empty object. @@ -198,39 +164,17 @@ func (c channelState) OtherPeer() peer.ID { // // EXPERIMENTAL; subject to change. func (c channelState) Stages() *datatransfer.ChannelStages { - if c.stages == nil { + if c.ic.Stages == nil { // return an empty placeholder; it will be discarded because the caller // is not supposed to mutate the value anyway. return &datatransfer.ChannelStages{} } - return c.stages -} - -func fromInternalChannelState(c internal.ChannelState, voucherDecoder DecoderByTypeFunc, voucherResultDecoder DecoderByTypeFunc) datatransfer.ChannelState { - return channelState{ - selfPeer: c.SelfPeer, - isPull: c.Initiator == c.Recipient, - transferID: c.TransferID, - baseCid: c.BaseCid, - selector: c.Selector, - sender: c.Sender, - recipient: c.Recipient, - totalSize: c.TotalSize, - status: c.Status, - queued: c.Queued, - sent: c.Sent, - received: c.Received, - receivedBlocksTotal: c.ReceivedBlocksTotal, - queuedBlocksTotal: c.QueuedBlocksTotal, - sentBlocksTotal: c.SentBlocksTotal, - message: c.Message, - vouchers: c.Vouchers, - voucherResults: c.VoucherResults, - voucherResultDecoder: voucherResultDecoder, - voucherDecoder: voucherDecoder, - stages: c.Stages, - } + return c.ic.Stages +} + +func fromInternalChannelState(c internal.ChannelState) datatransfer.ChannelState { + return channelState{ic: c} } var _ datatransfer.ChannelState = channelState{} diff --git a/channels/channels.go b/channels/channels.go index dd2ad45d..bc00c420 100644 --- a/channels/channels.go +++ b/channels/channels.go @@ -7,9 +7,8 @@ import ( "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" peer "github.com/libp2p/go-libp2p-core/peer" - cbg "github.com/whyrusleeping/cbor-gen" "golang.org/x/xerrors" versioning "github.com/filecoin-project/go-ds-versioning/pkg" @@ -17,46 +16,25 @@ import ( "github.com/filecoin-project/go-statemachine" "github.com/filecoin-project/go-statemachine/fsm" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channels/internal" - "github.com/filecoin-project/go-data-transfer/channels/internal/migrations" - "github.com/filecoin-project/go-data-transfer/encoding" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + 
"github.com/filecoin-project/go-data-transfer/v2/channels/internal" + "github.com/filecoin-project/go-data-transfer/v2/channels/internal/migrations" ) -type DecoderByTypeFunc func(identifier datatransfer.TypeIdentifier) (encoding.Decoder, bool) - type Notifier func(datatransfer.Event, datatransfer.ChannelState) -// ErrNotFound is returned when a channel cannot be found with a given channel ID -type ErrNotFound struct { - ChannelID datatransfer.ChannelID -} - -func (e *ErrNotFound) Error() string { - return "No channel for channel ID " + e.ChannelID.String() -} - -func NewErrNotFound(chid datatransfer.ChannelID) error { - return &ErrNotFound{ChannelID: chid} -} - // ErrWrongType is returned when a caller attempts to change the type of implementation data after setting it var ErrWrongType = errors.New("Cannot change type of implementation specific data after setting it") // Channels is a thread safe list of channels type Channels struct { notifier Notifier - voucherDecoder DecoderByTypeFunc - voucherResultDecoder DecoderByTypeFunc - blockIndexCache *blockIndexCache stateMachines fsm.Group migrateStateMachines func(context.Context) error } // ChannelEnvironment -- just a proxy for DTNetwork for now type ChannelEnvironment interface { - Protect(id peer.ID, tag string) - Unprotect(id peer.ID, tag string) bool ID() peer.ID CleanupChannel(chid datatransfer.ChannelID) } @@ -64,17 +42,10 @@ type ChannelEnvironment interface { // New returns a new thread safe list of channels func New(ds datastore.Batching, notifier Notifier, - voucherDecoder DecoderByTypeFunc, - voucherResultDecoder DecoderByTypeFunc, env ChannelEnvironment, selfPeer peer.ID) (*Channels, error) { - c := &Channels{ - notifier: notifier, - voucherDecoder: voucherDecoder, - voucherResultDecoder: voucherResultDecoder, - } - c.blockIndexCache = newBlockIndexCache() + c := &Channels{notifier: notifier} channelMigrations, err := migrations.GetChannelStateMigrations(selfPeer) if err != nil { return nil, err @@ -87,7 +58,7 @@ func New(ds datastore.Batching, StateEntryFuncs: ChannelStateEntryFuncs, Notifier: c.dispatch, FinalityStates: ChannelFinalityStates, - }, channelMigrations, versioning.VersionKey("2")) + }, channelMigrations, versioning.VersionKey("3")) if err != nil { return nil, err } @@ -119,7 +90,7 @@ func (c *Channels) dispatch(eventName fsm.EventName, channel fsm.StateType) { // CreateNew creates a new channel id and channel state and saves to channels. // returns error if the channel exists already. 
-func (c *Channels) CreateNew(selfPeer peer.ID, tid datatransfer.TransferID, baseCid cid.Cid, selector ipld.Node, voucher datatransfer.Voucher, initiator, dataSender, dataReceiver peer.ID) (datatransfer.ChannelID, error) { +func (c *Channels) CreateNew(selfPeer peer.ID, tid datatransfer.TransferID, baseCid cid.Cid, selector datamodel.Node, voucher datatransfer.TypedVoucher, initiator, dataSender, dataReceiver peer.ID) (datatransfer.ChannelID, datatransfer.Channel, error) { var responder peer.ID if dataSender == initiator { responder = dataReceiver @@ -127,40 +98,31 @@ func (c *Channels) CreateNew(selfPeer peer.ID, tid datatransfer.TransferID, base responder = dataSender } chid := datatransfer.ChannelID{Initiator: initiator, Responder: responder, ID: tid} - voucherBytes, err := encoding.Encode(voucher) - if err != nil { - return datatransfer.ChannelID{}, err - } - selBytes, err := encoding.Encode(selector) - if err != nil { - return datatransfer.ChannelID{}, err - } - err = c.stateMachines.Begin(chid, &internal.ChannelState{ + channel := &internal.ChannelState{ SelfPeer: selfPeer, TransferID: tid, Initiator: initiator, Responder: responder, BaseCid: baseCid, - Selector: &cbg.Deferred{Raw: selBytes}, + Selector: internal.CborGenCompatibleNode{Node: selector}, Sender: dataSender, Recipient: dataReceiver, Stages: &datatransfer.ChannelStages{}, Vouchers: []internal.EncodedVoucher{ { - Type: voucher.Type(), - Voucher: &cbg.Deferred{ - Raw: voucherBytes, - }, + Type: voucher.Type, + Voucher: internal.CborGenCompatibleNode{voucher.Voucher}, }, }, Status: datatransfer.Requested, - }) + } + err := c.stateMachines.Begin(chid, channel) if err != nil { log.Errorw("failed to create new tracking channel for data-transfer", "channelID", chid, "err", err) - return datatransfer.ChannelID{}, err + return datatransfer.ChannelID{}, nil, err } log.Debugw("created tracking channel for data-transfer, emitting channel Open event", "channelID", chid) - return chid, c.stateMachines.Send(chid, datatransfer.Open) + return chid, c.fromInternalChannelState(*channel), c.stateMachines.Send(chid, datatransfer.Open) } // InProgress returns a list of in progress channels @@ -184,7 +146,7 @@ func (c *Channels) GetByID(ctx context.Context, chid datatransfer.ChannelID) (da var internalChannel internal.ChannelState err := c.stateMachines.GetSync(ctx, chid, &internalChannel) if err != nil { - return nil, NewErrNotFound(chid) + return nil, datatransfer.ErrChannelNotFound } return c.fromInternalChannelState(internalChannel), nil } @@ -198,8 +160,8 @@ func (c *Channels) ChannelOpened(chid datatransfer.ChannelID) error { return c.send(chid, datatransfer.Opened) } -func (c *Channels) TransferRequestQueued(chid datatransfer.ChannelID) error { - return c.send(chid, datatransfer.TransferRequestQueued) +func (c *Channels) TransferInitiated(chid datatransfer.ChannelID) error { + return c.send(chid, datatransfer.TransferInitiated) } // Restart marks a data transfer as restarted @@ -207,46 +169,29 @@ func (c *Channels) Restart(chid datatransfer.ChannelID) error { return c.send(chid, datatransfer.Restart) } +// CompleteCleanupOnRestart tells a channel to restart func (c *Channels) CompleteCleanupOnRestart(chid datatransfer.ChannelID) error { return c.send(chid, datatransfer.CompleteCleanupOnRestart) } -func (c *Channels) getQueuedIndex(chid datatransfer.ChannelID) (int64, error) { - chst, err := c.GetByID(context.TODO(), chid) - if err != nil { - return 0, err - } - return chst.QueuedCidsTotal(), nil -} - -func (c *Channels) 
getReceivedIndex(chid datatransfer.ChannelID) (int64, error) { - chst, err := c.GetByID(context.TODO(), chid) - if err != nil { - return 0, err - } - return chst.ReceivedCidsTotal(), nil -} - -func (c *Channels) getSentIndex(chid datatransfer.ChannelID) (int64, error) { - chst, err := c.GetByID(context.TODO(), chid) - if err != nil { - return 0, err - } - return chst.SentCidsTotal(), nil +// DataSent records data being sent +func (c *Channels) DataSent(chid datatransfer.ChannelID, delta uint64, index datamodel.Node) error { + return c.fireProgressEvent(chid, datatransfer.DataSent, datatransfer.DataSentProgress, delta, index) } -func (c *Channels) DataSent(chid datatransfer.ChannelID, k cid.Cid, delta uint64, index int64, unique bool) (bool, error) { - return c.fireProgressEvent(chid, datatransfer.DataSent, datatransfer.DataSentProgress, k, delta, index, unique, c.getSentIndex) +// DataQueued records data being queued +func (c *Channels) DataQueued(chid datatransfer.ChannelID, delta uint64, index datamodel.Node) error { + return c.fireProgressEvent(chid, datatransfer.DataQueued, datatransfer.DataQueuedProgress, delta, index) } -func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint64, index int64, unique bool) (bool, error) { - return c.fireProgressEvent(chid, datatransfer.DataQueued, datatransfer.DataQueuedProgress, k, delta, index, unique, c.getQueuedIndex) +// DataReceived records data being received +func (c *Channels) DataReceived(chid datatransfer.ChannelID, delta uint64, index datamodel.Node) error { + return c.fireProgressEvent(chid, datatransfer.DataReceived, datatransfer.DataReceivedProgress, delta, index) } -// Returns true if this is the first time the block has been received -func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64, index int64, unique bool) (bool, error) { - new, err := c.fireProgressEvent(chid, datatransfer.DataReceived, datatransfer.DataReceivedProgress, k, delta, index, unique, c.getReceivedIndex) - return new, err +// DataLimitExceeded records a data limit exceeded event +func (c *Channels) DataLimitExceeded(chid datatransfer.ChannelID) error { + return c.send(chid, datatransfer.DataLimitExceeded) } // PauseInitiator pauses the initator of this channel @@ -270,21 +215,13 @@ func (c *Channels) ResumeResponder(chid datatransfer.ChannelID) error { } // NewVoucher records a new voucher for this channel -func (c *Channels) NewVoucher(chid datatransfer.ChannelID, voucher datatransfer.Voucher) error { - voucherBytes, err := encoding.Encode(voucher) - if err != nil { - return err - } - return c.send(chid, datatransfer.NewVoucher, voucher.Type(), voucherBytes) +func (c *Channels) NewVoucher(chid datatransfer.ChannelID, voucher datatransfer.TypedVoucher) error { + return c.send(chid, datatransfer.NewVoucher, voucher) } // NewVoucherResult records a new voucher result for this channel -func (c *Channels) NewVoucherResult(chid datatransfer.ChannelID, voucherResult datatransfer.VoucherResult) error { - voucherResultBytes, err := encoding.Encode(voucherResult) - if err != nil { - return err - } - return c.send(chid, datatransfer.NewVoucherResult, voucherResult.Type(), voucherResultBytes) +func (c *Channels) NewVoucherResult(chid datatransfer.ChannelID, voucherResult datatransfer.TypedVoucher) error { + return c.send(chid, datatransfer.NewVoucherResult, voucherResult) } // Complete indicates responder has completed sending/receiving data @@ -354,6 +291,21 @@ func (c *Channels) ReceiveDataError(chid 
datatransfer.ChannelID, err error) erro return c.send(chid, datatransfer.ReceiveDataError, err) } +// SendMessageError indicates an error sending a message to the transport layer +func (c *Channels) SendMessageError(chid datatransfer.ChannelID, err error) error { + return c.send(chid, datatransfer.SendMessageError, err) +} + +// SetDataLimit means a data limit has been set on this channel +func (c *Channels) SetDataLimit(chid datatransfer.ChannelID, dataLimit uint64) error { + return c.send(chid, datatransfer.SetDataLimit, dataLimit) +} + +// SetRequiresFinalization sets the state of whether a data transfer can complete +func (c *Channels) SetRequiresFinalization(chid datatransfer.ChannelID, RequiresFinalization bool) error { + return c.send(chid, datatransfer.SetRequiresFinalization, RequiresFinalization) +} + // HasChannel returns true if the given channel id is being tracked func (c *Channels) HasChannel(chid datatransfer.ChannelID) (bool, error) { return c.stateMachines.Has(chid) @@ -361,30 +313,19 @@ func (c *Channels) HasChannel(chid datatransfer.ChannelID) (bool, error) { // fireProgressEvent fires // - an event for queuing / sending / receiving blocks -// - a corresponding "progress" event if the block has not been seen before -// For example, if a block is being sent for the first time, the method will -// fire both DataSent AND DataSentProgress. -// If a block is resent, the method will fire DataSent but not DataSentProgress. -// Returns true if the block is new (both the event and a progress event were fired). -func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, k cid.Cid, delta uint64, index int64, unique bool, readFromOriginal readOriginalFn) (bool, error) { +// - a corresponding "progress" event +func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, delta uint64, index datamodel.Node) error { if err := c.checkChannelExists(chid, evt); err != nil { - return false, err - } - - isNewIndex, err := c.blockIndexCache.updateIfGreater(evt, chid, index, readFromOriginal) - if err != nil { - return false, err + return err } - // If the block has not been seen before, fire the progress event - if unique && isNewIndex { - if err := c.stateMachines.Send(chid, progressEvt, delta); err != nil { - return false, err - } + // Fire the progress event + if err := c.stateMachines.Send(chid, progressEvt, delta); err != nil { + return err } // Fire the regular event - return unique && isNewIndex, c.stateMachines.Send(chid, evt, index) + return c.stateMachines.Send(chid, evt, index) } func (c *Channels) send(chid datatransfer.ChannelID, code datatransfer.EventCode, args ...interface{}) error { @@ -403,12 +344,12 @@ func (c *Channels) checkChannelExists(chid datatransfer.ChannelID, code datatran } if !has { return xerrors.Errorf("cannot send FSM event %s to data-transfer channel %s: %w", - datatransfer.Events[code], chid, NewErrNotFound(chid)) + datatransfer.Events[code], chid, datatransfer.ErrChannelNotFound) } return nil } // Convert from the internally used channel state format to the externally exposed ChannelState func (c *Channels) fromInternalChannelState(ch internal.ChannelState) datatransfer.ChannelState { - return fromInternalChannelState(ch, c.voucherDecoder, c.voucherResultDecoder) + return fromInternalChannelState(ch) } diff --git a/channels/channels_fsm.go b/channels/channels_fsm.go index 6fd51967..51dfc6d5 100644 --- 
a/channels/channels_fsm.go +++ b/channels/channels_fsm.go @@ -2,26 +2,16 @@ package channels import ( logging "github.com/ipfs/go-log/v2" - cbg "github.com/whyrusleeping/cbor-gen" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/filecoin-project/go-statemachine/fsm" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channels/internal" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/channels/internal" ) var log = logging.Logger("data-transfer") -var transferringStates = []fsm.StateKey{ - datatransfer.Requested, - datatransfer.Ongoing, - datatransfer.InitiatorPaused, - datatransfer.ResponderPaused, - datatransfer.BothPaused, - datatransfer.ResponderCompleted, - datatransfer.ResponderFinalizing, -} - // ChannelEvents describe the events taht can var ChannelEvents = fsm.Events{ // Open a channel @@ -29,23 +19,32 @@ var ChannelEvents = fsm.Events{ chst.AddLog("") return nil }), + // Remote peer has accepted the Open channel request - fsm.Event(datatransfer.Accept).From(datatransfer.Requested).To(datatransfer.Ongoing).Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + fsm.Event(datatransfer.Accept). + From(datatransfer.Requested).To(datatransfer.Queued). + From(datatransfer.AwaitingAcceptance).To(datatransfer.Ongoing). + Action(func(chst *internal.ChannelState) error { + chst.AddLog("") + return nil + }), - fsm.Event(datatransfer.TransferRequestQueued).FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.Message = "" - chst.AddLog("") - return nil - }), + // The transport has indicated it's begun sending/receiving data + fsm.Event(datatransfer.TransferInitiated). + From(datatransfer.Requested).To(datatransfer.AwaitingAcceptance). + From(datatransfer.Queued).To(datatransfer.Ongoing). + From(datatransfer.Ongoing).ToJustRecord(). + Action(func(chst *internal.ChannelState) error { + chst.AddLog("") + return nil + }), fsm.Event(datatransfer.Restart).FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { chst.Message = "" chst.AddLog("") return nil }), + fsm.Event(datatransfer.Cancel).FromAny().To(datatransfer.Cancelling).Action(func(chst *internal.ChannelState) error { chst.AddLog("") return nil @@ -60,75 +59,90 @@ var ChannelEvents = fsm.Events{ return nil }), - fsm.Event(datatransfer.DataReceived).FromAny().ToNoChange(). - Action(func(chst *internal.ChannelState, rcvdBlocksTotal int64) error { - if rcvdBlocksTotal > chst.ReceivedBlocksTotal { - chst.ReceivedBlocksTotal = rcvdBlocksTotal - } + fsm.Event(datatransfer.DataReceived).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). + Action(func(chst *internal.ChannelState, receivedIndex datamodel.Node) error { + chst.ReceivedIndex = internal.CborGenCompatibleNode{Node: receivedIndex} chst.AddLog("") return nil }), - fsm.Event(datatransfer.DataReceivedProgress).FromMany(transferringStates...).ToNoChange(). + fsm.Event(datatransfer.DataReceivedProgress).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). Action(func(chst *internal.ChannelState, delta uint64) error { chst.Received += delta chst.AddLog("received data") return nil }), - fsm.Event(datatransfer.DataSent). - FromMany(transferringStates...).ToNoChange(). - From(datatransfer.TransferFinished).ToNoChange(). 
- Action(func(chst *internal.ChannelState, sentBlocksTotal int64) error { - if sentBlocksTotal > chst.SentBlocksTotal { - chst.SentBlocksTotal = sentBlocksTotal - } + fsm.Event(datatransfer.DataSent).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). + Action(func(chst *internal.ChannelState, sentIndex datamodel.Node) error { + chst.SentIndex = internal.CborGenCompatibleNode{Node: sentIndex} chst.AddLog("") return nil }), - fsm.Event(datatransfer.DataSentProgress).FromMany(transferringStates...).ToNoChange(). + fsm.Event(datatransfer.DataSentProgress).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). Action(func(chst *internal.ChannelState, delta uint64) error { chst.Sent += delta chst.AddLog("sending data") return nil }), - fsm.Event(datatransfer.DataQueued). - FromMany(transferringStates...).ToNoChange(). - From(datatransfer.TransferFinished).ToNoChange(). - Action(func(chst *internal.ChannelState, queuedBlocksTotal int64) error { - if queuedBlocksTotal > chst.QueuedBlocksTotal { - chst.QueuedBlocksTotal = queuedBlocksTotal - } + fsm.Event(datatransfer.DataQueued).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). + Action(func(chst *internal.ChannelState, queuedIndex datamodel.Node) error { + chst.QueuedIndex = internal.CborGenCompatibleNode{Node: queuedIndex} chst.AddLog("") return nil }), - fsm.Event(datatransfer.DataQueuedProgress).FromMany(transferringStates...).ToNoChange(). + fsm.Event(datatransfer.DataQueuedProgress).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). Action(func(chst *internal.ChannelState, delta uint64) error { chst.Queued += delta chst.AddLog("") return nil }), + + fsm.Event(datatransfer.SetDataLimit).FromAny().ToJustRecord(). + Action(func(chst *internal.ChannelState, dataLimit uint64) error { + chst.DataLimit = dataLimit + chst.AddLog("") + return nil + }), + + fsm.Event(datatransfer.SetRequiresFinalization).FromAny().ToJustRecord(). 
+ Action(func(chst *internal.ChannelState, RequiresFinalization bool) error { + chst.RequiresFinalization = RequiresFinalization + chst.AddLog("") + return nil + }), + fsm.Event(datatransfer.Disconnected).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { chst.Message = err.Error() chst.AddLog("data transfer disconnected: %s", chst.Message) return nil }), + fsm.Event(datatransfer.SendDataError).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { chst.Message = err.Error() chst.AddLog("data transfer send error: %s", chst.Message) return nil }), + fsm.Event(datatransfer.ReceiveDataError).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { chst.Message = err.Error() chst.AddLog("data transfer receive error: %s", chst.Message) return nil }), + + fsm.Event(datatransfer.SendMessageError).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { + chst.Message = err.Error() + chst.AddLog("data transfer errored sending message: %s", chst.Message) + return nil + }), + fsm.Event(datatransfer.RequestCancelled).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { chst.Message = err.Error() chst.AddLog("data transfer request cancelled: %s", chst.Message) return nil }), + fsm.Event(datatransfer.Error).FromAny().To(datatransfer.Failing).Action(func(chst *internal.ChannelState, err error) error { chst.Message = err.Error() chst.AddLog("data transfer erred: %s", chst.Message) @@ -136,51 +150,68 @@ var ChannelEvents = fsm.Events{ }), fsm.Event(datatransfer.NewVoucher).FromAny().ToNoChange(). - Action(func(chst *internal.ChannelState, vtype datatransfer.TypeIdentifier, voucherBytes []byte) error { - chst.Vouchers = append(chst.Vouchers, internal.EncodedVoucher{Type: vtype, Voucher: &cbg.Deferred{Raw: voucherBytes}}) + Action(func(chst *internal.ChannelState, voucher datatransfer.TypedVoucher) error { + chst.Vouchers = append(chst.Vouchers, internal.EncodedVoucher{Type: voucher.Type, Voucher: internal.CborGenCompatibleNode{Node: voucher.Voucher}}) chst.AddLog("got new voucher") return nil }), + fsm.Event(datatransfer.NewVoucherResult).FromAny().ToNoChange(). - Action(func(chst *internal.ChannelState, vtype datatransfer.TypeIdentifier, voucherResultBytes []byte) error { + Action(func(chst *internal.ChannelState, voucherResult datatransfer.TypedVoucher) error { chst.VoucherResults = append(chst.VoucherResults, - internal.EncodedVoucherResult{Type: vtype, VoucherResult: &cbg.Deferred{Raw: voucherResultBytes}}) + internal.EncodedVoucherResult{Type: voucherResult.Type, VoucherResult: internal.CborGenCompatibleNode{Node: voucherResult.Voucher}}) chst.AddLog("got new voucher result") return nil }), + // TODO: There are four states from which the request can be "paused": request, queued, awaiting acceptance + // and ongoing. There are four states of being + // paused (no pause, initiator pause, responder pause, both paused). Until the state machine software + // supports orthogonal regions (https://en.wikipedia.org/wiki/UML_state_machine#Orthogonal_regions) + // we end up with a cartesian product of states and as you can see, fairly complicated state transfers. + // Previously, we had dealt with this by moving directly to the Ongoing state upon return from pause but this + // seems less than ideal. We need some kind of support for pausing being an independent aspect of state. + // Possibly we should just remove whether a state is paused from the state entirely. 
fsm.Event(datatransfer.PauseInitiator). - FromMany(datatransfer.Requested, datatransfer.Ongoing).To(datatransfer.InitiatorPaused). - From(datatransfer.ResponderPaused).To(datatransfer.BothPaused). - FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + FromMany(datatransfer.Ongoing, datatransfer.Requested, datatransfer.Queued, datatransfer.AwaitingAcceptance).ToJustRecord(). + Action(func(chst *internal.ChannelState) error { + chst.InitiatorPaused = true + chst.AddLog("") + return nil + }), fsm.Event(datatransfer.PauseResponder). - FromMany(datatransfer.Requested, datatransfer.Ongoing).To(datatransfer.ResponderPaused). - From(datatransfer.InitiatorPaused).To(datatransfer.BothPaused). - FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + FromMany(datatransfer.Ongoing, datatransfer.Requested, datatransfer.Queued, datatransfer.AwaitingAcceptance, datatransfer.TransferFinished).ToJustRecord(). + Action(func(chst *internal.ChannelState) error { + chst.ResponderPaused = true + chst.AddLog("") + return nil + }), + + fsm.Event(datatransfer.DataLimitExceeded). + FromMany(datatransfer.Ongoing, datatransfer.Requested, datatransfer.Queued, datatransfer.AwaitingAcceptance, datatransfer.ResponderCompleted, datatransfer.ResponderFinalizing).ToJustRecord(). + Action(func(chst *internal.ChannelState) error { + chst.ResponderPaused = true + chst.AddLog("") + return nil + }), fsm.Event(datatransfer.ResumeInitiator). - From(datatransfer.InitiatorPaused).To(datatransfer.Ongoing). - From(datatransfer.BothPaused).To(datatransfer.ResponderPaused). - FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + FromMany(datatransfer.Ongoing, datatransfer.Requested, datatransfer.Queued, datatransfer.AwaitingAcceptance, datatransfer.ResponderCompleted, datatransfer.ResponderFinalizing).ToJustRecord(). + Action(func(chst *internal.ChannelState) error { + chst.InitiatorPaused = false + chst.AddLog("") + return nil + }), fsm.Event(datatransfer.ResumeResponder). - From(datatransfer.ResponderPaused).To(datatransfer.Ongoing). - From(datatransfer.BothPaused).To(datatransfer.InitiatorPaused). + FromMany(datatransfer.Ongoing, datatransfer.Requested, datatransfer.Queued, datatransfer.AwaitingAcceptance, datatransfer.TransferFinished).ToJustRecord(). From(datatransfer.Finalizing).To(datatransfer.Completing). - FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + Action(func(chst *internal.ChannelState) error { + chst.ResponderPaused = false + chst.AddLog("") + return nil + }), // The transfer has finished on the local node - all data was sent / received fsm.Event(datatransfer.FinishTransfer). @@ -188,10 +219,10 @@ var ChannelEvents = fsm.Events{ FromMany(datatransfer.Failing, datatransfer.Cancelling).ToJustRecord(). From(datatransfer.ResponderCompleted).To(datatransfer.Completing). From(datatransfer.ResponderFinalizing).To(datatransfer.ResponderFinalizingTransferFinished). - // If we are in the requested state, it means the other party simply never responded to our + // If we are in the AwaitingAcceptance state, it means the other party simply never responded to our // our data transfer, or we never actually contacted them. In any case, it's safe to skip // the finalization process and complete the transfer - From(datatransfer.Requested).To(datatransfer.Completing). 
+ From(datatransfer.AwaitingAcceptance).To(datatransfer.Completing). Action(func(chst *internal.ChannelState) error { chst.AddLog("") return nil @@ -200,7 +231,8 @@ var ChannelEvents = fsm.Events{ fsm.Event(datatransfer.ResponderBeginsFinalization). FromAny().To(datatransfer.ResponderFinalizing). FromMany(datatransfer.Failing, datatransfer.Cancelling).ToJustRecord(). - From(datatransfer.TransferFinished).To(datatransfer.ResponderFinalizingTransferFinished).Action(func(chst *internal.ChannelState) error { + From(datatransfer.TransferFinished).To(datatransfer.ResponderFinalizingTransferFinished). + FromMany(datatransfer.ResponderFinalizing, datatransfer.ResponderFinalizingTransferFinished).ToJustRecord().Action(func(chst *internal.ChannelState) error { chst.AddLog("") return nil }), @@ -209,9 +241,7 @@ var ChannelEvents = fsm.Events{ fsm.Event(datatransfer.ResponderCompletes). FromAny().To(datatransfer.ResponderCompleted). FromMany(datatransfer.Failing, datatransfer.Cancelling).ToJustRecord(). - From(datatransfer.ResponderPaused).To(datatransfer.ResponderFinalizing). From(datatransfer.TransferFinished).To(datatransfer.Completing). - From(datatransfer.ResponderFinalizing).To(datatransfer.ResponderCompleted). From(datatransfer.ResponderFinalizingTransferFinished).To(datatransfer.Completing).Action(func(chst *internal.ChannelState) error { chst.AddLog("") return nil @@ -257,7 +287,6 @@ func cleanupConnection(ctx fsm.Context, env ChannelEnvironment, channel internal otherParty = channel.Responder } env.CleanupChannel(datatransfer.ChannelID{ID: channel.TransferID, Initiator: channel.Initiator, Responder: channel.Responder}) - env.Unprotect(otherParty, datatransfer.ChannelID{ID: channel.TransferID, Initiator: channel.Initiator, Responder: channel.Responder}.String()) return ctx.Trigger(datatransfer.CleanupComplete) } diff --git a/channels/channels_test.go b/channels/channels_test.go index c6fd7478..3e26af1d 100644 --- a/channels/channels_test.go +++ b/channels/channels_test.go @@ -1,11 +1,14 @@ package channels_test import ( + "bytes" "context" "errors" + "math/rand" "testing" "time" + "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" dss "github.com/ipfs/go-datastore/sync" basicnode "github.com/ipld/go-ipld-prime/node/basic" @@ -14,10 +17,14 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channels" - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/testutil" + versioning "github.com/filecoin-project/go-ds-versioning/pkg" + versionedds "github.com/filecoin-project/go-ds-versioning/pkg/datastore" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/channels" + "github.com/filecoin-project/go-data-transfer/v2/channels/internal" + "github.com/filecoin-project/go-data-transfer/v2/channels/internal/migrations" + "github.com/filecoin-project/go-data-transfer/v2/testutil" ) func TestChannels(t *testing.T) { @@ -32,31 +39,31 @@ func TestChannels(t *testing.T) { tid1 := datatransfer.TransferID(0) tid2 := datatransfer.TransferID(1) - fv1 := &testutil.FakeDTType{} - fv2 := &testutil.FakeDTType{} - cids := testutil.GenerateCids(2) + fv1 := testutil.NewTestTypedVoucher() + fv2 := testutil.NewTestTypedVoucher() + cids := testutil.GenerateCids(4) selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() peers := 
testutil.GeneratePeers(4) - channelList, err := channels.New(ds, notifier, decoderByType, decoderByType, &fakeEnv{}, peers[0]) + channelList, err := channels.New(ds, notifier, &fakeEnv{}, peers[0]) require.NoError(t, err) err = channelList.Start(ctx) require.NoError(t, err) t.Run("adding channels", func(t *testing.T) { - chid, err := channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[0], peers[0], peers[1]) + chid, _, err := channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[0], peers[0], peers[1]) require.NoError(t, err) require.Equal(t, peers[0], chid.Initiator) require.Equal(t, tid1, chid.ID) // cannot add twice for same channel id - _, err = channelList.CreateNew(peers[0], tid1, cids[1], selector, fv2, peers[0], peers[1], peers[0]) + _, _, err = channelList.CreateNew(peers[0], tid1, cids[1], selector, fv2, peers[0], peers[1], peers[0]) require.Error(t, err) state := checkEvent(ctx, t, received, datatransfer.Open) require.Equal(t, datatransfer.Requested, state.Status()) // can add for different id - chid, err = channelList.CreateNew(peers[2], tid2, cids[1], selector, fv2, peers[3], peers[2], peers[3]) + chid, _, err = channelList.CreateNew(peers[2], tid2, cids[1], selector, fv2, peers[3], peers[2], peers[3]) require.NoError(t, err) require.Equal(t, peers[3], chid.Initiator) require.Equal(t, tid2, chid.ID) @@ -80,14 +87,15 @@ func TestChannels(t *testing.T) { require.NotEqual(t, channels.EmptyChannelState, state) require.Equal(t, cids[0], state.BaseCID()) require.Equal(t, selector, state.Selector()) - require.Equal(t, fv1, state.Voucher()) + voucher := state.Voucher() + require.True(t, fv1.Equals(voucher)) require.Equal(t, peers[0], state.Sender()) require.Equal(t, peers[1], state.Recipient()) // empty if channel does not exist state, err = channelList.GetByID(ctx, datatransfer.ChannelID{Initiator: peers[1], Responder: peers[1], ID: tid1}) require.Equal(t, nil, state) - require.True(t, xerrors.As(err, new(*channels.ErrNotFound))) + require.True(t, errors.Is(err, datatransfer.ErrChannelNotFound)) // works for other channel as well state, err = channelList.GetByID(ctx, datatransfer.ChannelID{Initiator: peers[3], Responder: peers[2], ID: tid2}) @@ -104,123 +112,110 @@ func TestChannels(t *testing.T) { err = channelList.Accept(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.Accept) - require.Equal(t, state.Status(), datatransfer.Ongoing) + require.Equal(t, state.Status(), datatransfer.Queued) err = channelList.Accept(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}) - require.True(t, xerrors.As(err, new(*channels.ErrNotFound))) + require.True(t, errors.Is(err, datatransfer.ErrChannelNotFound)) }) - t.Run("transfer queued", func(t *testing.T) { + t.Run("transfer initiated", func(t *testing.T) { state, err := channelList.GetByID(ctx, datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) - require.Equal(t, state.Status(), datatransfer.Ongoing) + require.Equal(t, state.Status(), datatransfer.Queued) - err = channelList.TransferRequestQueued(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) + err = channelList.TransferInitiated(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) - state = checkEvent(ctx, t, received, datatransfer.TransferRequestQueued) + state = checkEvent(ctx, t, received, 
datatransfer.TransferInitiated) require.Equal(t, state.Status(), datatransfer.Ongoing) }) - t.Run("datasent/queued when transfer is already finished", func(t *testing.T) { - ds := dss.MutexWrap(datastore.NewMapDatastore()) - - channelList, err := channels.New(ds, notifier, decoderByType, decoderByType, &fakeEnv{}, peers[0]) - require.NoError(t, err) - err = channelList.Start(ctx) - require.NoError(t, err) - - chid, err := channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[0], peers[0], peers[1]) - require.NoError(t, err) - checkEvent(ctx, t, received, datatransfer.Open) - require.NoError(t, channelList.Accept(chid)) - checkEvent(ctx, t, received, datatransfer.Accept) - - // move the channel to `TransferFinished` state. - require.NoError(t, channelList.FinishTransfer(chid)) - state := checkEvent(ctx, t, received, datatransfer.FinishTransfer) - require.Equal(t, datatransfer.TransferFinished, state.Status()) - - // send a data-sent event and ensure it's a no-op - _, err = channelList.DataSent(chid, cids[1], 1, 1, true) - require.NoError(t, err) - state = checkEvent(ctx, t, received, datatransfer.DataSent) - require.Equal(t, datatransfer.TransferFinished, state.Status()) - - // send a data-queued event and ensure it's a no-op. - _, err = channelList.DataQueued(chid, cids[1], 1, 1, true) - require.NoError(t, err) - state = checkEvent(ctx, t, received, datatransfer.DataQueued) - require.Equal(t, datatransfer.TransferFinished, state.Status()) - }) - t.Run("updating send/receive values", func(t *testing.T) { ds := dss.MutexWrap(datastore.NewMapDatastore()) - channelList, err := channels.New(ds, notifier, decoderByType, decoderByType, &fakeEnv{}, peers[0]) + channelList, err := channels.New(ds, notifier, &fakeEnv{}, peers[0]) require.NoError(t, err) err = channelList.Start(ctx) require.NoError(t, err) - _, err = channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[0], peers[0], peers[1]) + _, _, err = channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[0], peers[0], peers[1]) require.NoError(t, err) state := checkEvent(ctx, t, received, datatransfer.Open) require.Equal(t, datatransfer.Requested, state.Status()) require.Equal(t, uint64(0), state.Received()) require.Equal(t, uint64(0), state.Sent()) - isNew, err := channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50, 1, true) + err = channelList.TransferInitiated(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) + require.NoError(t, err) + _ = checkEvent(ctx, t, received, datatransfer.TransferInitiated) + + err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, 50, basicnode.NewInt(1)) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress) - require.True(t, isNew) state = checkEvent(ctx, t, received, datatransfer.DataReceived) require.Equal(t, uint64(50), state.Received()) require.Equal(t, uint64(0), state.Sent()) - isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100, 1, true) + err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, 100, basicnode.NewInt(1)) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataSentProgress) - require.True(t, isNew) - state = checkEvent(ctx, t, received, datatransfer.DataSent) - require.Equal(t, uint64(50), state.Received()) - require.Equal(t, uint64(100), 
state.Sent()) - - // send block again has no effect - isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100, 1, true) - require.NoError(t, err) - require.False(t, isNew) state = checkEvent(ctx, t, received, datatransfer.DataSent) require.Equal(t, uint64(50), state.Received()) require.Equal(t, uint64(100), state.Sent()) // errors if channel does not exist - isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200, 2, true) - require.True(t, xerrors.As(err, new(*channels.ErrNotFound))) - require.False(t, isNew) - isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200, 2, true) - require.True(t, xerrors.As(err, new(*channels.ErrNotFound))) - require.False(t, isNew) + err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, 200, basicnode.NewInt(2)) + require.True(t, errors.Is(err, datatransfer.ErrChannelNotFound)) + err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, 200, basicnode.NewInt(2)) + require.True(t, errors.Is(err, datatransfer.ErrChannelNotFound)) - isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 50, 2, true) + err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, 50, basicnode.NewInt(2)) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress) - require.True(t, isNew) state = checkEvent(ctx, t, received, datatransfer.DataReceived) require.Equal(t, uint64(100), state.Received()) require.Equal(t, uint64(100), state.Sent()) + }) + + t.Run("data limit", func(t *testing.T) { + ds := dss.MutexWrap(datastore.NewMapDatastore()) - isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25, 2, false) + channelList, err := channels.New(ds, notifier, &fakeEnv{}, peers[0]) + require.NoError(t, err) + err = channelList.Start(ctx) require.NoError(t, err) - require.False(t, isNew) - state = checkEvent(ctx, t, received, datatransfer.DataSent) - require.Equal(t, uint64(100), state.Received()) - require.Equal(t, uint64(100), state.Sent()) - isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50, 3, false) + _, _, err = channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[1], peers[0], peers[1]) require.NoError(t, err) - require.False(t, isNew) - state = checkEvent(ctx, t, received, datatransfer.DataReceived) - require.Equal(t, uint64(100), state.Received()) - require.Equal(t, uint64(100), state.Sent()) + state := checkEvent(ctx, t, received, datatransfer.Open) + + err = channelList.SetDataLimit(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, 400) + require.NoError(t, err) + state = checkEvent(ctx, t, received, datatransfer.SetDataLimit) + require.Equal(t, state.DataLimit(), uint64(400)) + + err = channelList.DataLimitExceeded(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}) + require.NoError(t, err) + state = checkEvent(ctx, t, received, datatransfer.DataLimitExceeded) + require.True(t, state.ResponderPaused()) + + err = channelList.SetDataLimit(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, 700) + 
require.NoError(t, err) + state = checkEvent(ctx, t, received, datatransfer.SetDataLimit) + require.Equal(t, state.DataLimit(), uint64(700)) + + err = channelList.ResumeResponder(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}) + state = checkEvent(ctx, t, received, datatransfer.ResumeResponder) + require.False(t, state.ResponderPaused()) + + err = channelList.PauseInitiator(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}) + state = checkEvent(ctx, t, received, datatransfer.PauseInitiator) + require.True(t, state.InitiatorPaused()) + + err = channelList.DataLimitExceeded(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}) + require.NoError(t, err) + state = checkEvent(ctx, t, received, datatransfer.DataLimitExceeded) + require.True(t, state.BothPaused()) + }) t.Run("pause/resume", func(t *testing.T) { @@ -231,17 +226,19 @@ func TestChannels(t *testing.T) { err = channelList.PauseInitiator(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.PauseInitiator) - require.Equal(t, datatransfer.InitiatorPaused, state.Status()) + require.True(t, state.InitiatorPaused()) + require.False(t, state.BothPaused()) err = channelList.PauseResponder(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.PauseResponder) - require.Equal(t, datatransfer.BothPaused, state.Status()) + require.True(t, state.BothPaused()) err = channelList.ResumeInitiator(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.ResumeInitiator) - require.Equal(t, datatransfer.ResponderPaused, state.Status()) + require.True(t, state.ResponderPaused()) + require.False(t, state.BothPaused()) err = channelList.ResumeResponder(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) @@ -250,31 +247,44 @@ func TestChannels(t *testing.T) { }) t.Run("new vouchers & voucherResults", func(t *testing.T) { - fv3 := testutil.NewFakeDTType() - fvr1 := testutil.NewFakeDTType() + fv3 := testutil.NewTestTypedVoucher() + fvr1 := testutil.NewTestTypedVoucher() state, err := channelList.GetByID(ctx, datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) - require.Equal(t, []datatransfer.Voucher{fv1}, state.Vouchers()) - require.Equal(t, fv1, state.Voucher()) - require.Equal(t, fv1, state.LastVoucher()) + vouchers := state.Vouchers() + require.Len(t, vouchers, 1) + require.True(t, fv1.Equals(vouchers[0])) + voucher := state.Voucher() + require.True(t, fv1.Equals(voucher)) + voucher = state.LastVoucher() + require.True(t, fv1.Equals(voucher)) err = channelList.NewVoucher(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, fv3) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.NewVoucher) - require.Equal(t, []datatransfer.Voucher{fv1, fv3}, state.Vouchers()) - require.Equal(t, fv1, state.Voucher()) - require.Equal(t, fv3, state.LastVoucher()) + vouchers = state.Vouchers() + require.Len(t, vouchers, 2) + require.True(t, fv1.Equals(vouchers[0])) + require.True(t, fv3.Equals(vouchers[1])) + voucher = state.Voucher() + require.True(t, fv1.Equals(voucher)) + voucher = state.LastVoucher() + require.True(t, fv3.Equals(voucher)) state, err = channelList.GetByID(ctx, 
datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) - require.Equal(t, []datatransfer.VoucherResult{}, state.VoucherResults()) + results := state.VoucherResults() + require.Equal(t, []datatransfer.TypedVoucher{}, results) err = channelList.NewVoucherResult(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, fvr1) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.NewVoucherResult) - require.Equal(t, []datatransfer.VoucherResult{fvr1}, state.VoucherResults()) - require.Equal(t, fvr1, state.LastVoucherResult()) + voucherResults := state.VoucherResults() + require.Len(t, voucherResults, 1) + require.True(t, fvr1.Equals(voucherResults[0])) + voucherResult := state.LastVoucherResult() + require.True(t, fvr1.Equals(voucherResult)) }) t.Run("test finality", func(t *testing.T) { @@ -301,7 +311,7 @@ func TestChannels(t *testing.T) { state = checkEvent(ctx, t, received, datatransfer.CleanupComplete) require.Equal(t, datatransfer.Failed, state.Status()) - chid, err := channelList.CreateNew(peers[0], tid2, cids[1], selector, fv2, peers[2], peers[1], peers[2]) + chid, _, err := channelList.CreateNew(peers[0], tid2, cids[1], selector, fv2, peers[2], peers[1], peers[2]) require.NoError(t, err) require.Equal(t, peers[2], chid.Initiator) require.Equal(t, tid2, chid.ID) @@ -318,7 +328,7 @@ func TestChannels(t *testing.T) { t.Run("test self peer and other peer", func(t *testing.T) { // sender is self peer - chid, err := channelList.CreateNew(peers[1], tid1, cids[0], selector, fv1, peers[1], peers[1], peers[2]) + chid, _, err := channelList.CreateNew(peers[1], tid1, cids[0], selector, fv1, peers[1], peers[1], peers[2]) require.NoError(t, err) ch, err := channelList.GetByID(context.Background(), chid) require.NoError(t, err) @@ -326,7 +336,7 @@ func TestChannels(t *testing.T) { require.Equal(t, peers[2], ch.OtherPeer()) // recipient is self peer - chid, err = channelList.CreateNew(peers[2], datatransfer.TransferID(1001), cids[0], selector, fv1, peers[1], peers[2], peers[1]) + chid, _, err = channelList.CreateNew(peers[2], datatransfer.TransferID(1001), cids[0], selector, fv1, peers[1], peers[2], peers[1]) require.NoError(t, err) ch, err = channelList.GetByID(context.Background(), chid) require.NoError(t, err) @@ -340,12 +350,12 @@ func TestChannels(t *testing.T) { notifier := func(evt datatransfer.Event, chst datatransfer.ChannelState) { received <- event{evt, chst} } - channelList, err := channels.New(ds, notifier, decoderByType, decoderByType, &fakeEnv{}, peers[0]) + channelList, err := channels.New(ds, notifier, &fakeEnv{}, peers[0]) require.NoError(t, err) err = channelList.Start(ctx) require.NoError(t, err) - chid, err := channelList.CreateNew(peers[3], tid1, cids[0], selector, fv1, peers[3], peers[0], peers[3]) + chid, _, err := channelList.CreateNew(peers[3], tid1, cids[0], selector, fv1, peers[3], peers[0], peers[3]) require.NoError(t, err) state := checkEvent(ctx, t, received, datatransfer.Open) require.Equal(t, datatransfer.Requested, state.Status()) @@ -360,7 +370,7 @@ func TestChannels(t *testing.T) { t.Run("test self peer and other peer", func(t *testing.T) { peers := testutil.GeneratePeers(3) // sender is self peer - chid, err := channelList.CreateNew(peers[1], tid1, cids[0], selector, fv1, peers[1], peers[1], peers[2]) + chid, _, err := channelList.CreateNew(peers[1], tid1, cids[0], selector, fv1, peers[1], peers[1], peers[2]) require.NoError(t, err) ch, err := channelList.GetByID(context.Background(), 
chid) require.NoError(t, err) @@ -368,7 +378,7 @@ func TestChannels(t *testing.T) { require.Equal(t, peers[2], ch.OtherPeer()) // recipient is self peer - chid, err = channelList.CreateNew(peers[2], datatransfer.TransferID(1001), cids[0], selector, fv1, peers[1], peers[2], peers[1]) + chid, _, err = channelList.CreateNew(peers[2], datatransfer.TransferID(1001), cids[0], selector, fv1, peers[1], peers[2], peers[1]) require.NoError(t, err) ch, err = channelList.GetByID(context.Background(), chid) require.NoError(t, err) @@ -390,6 +400,153 @@ func TestIsChannelCleaningUp(t *testing.T) { require.False(t, channels.IsChannelCleaningUp(datatransfer.Cancelled)) } +func TestMigrations(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + + ds := dss.MutexWrap(datastore.NewMapDatastore()) + received := make(chan event) + notifier := func(evt datatransfer.Event, chst datatransfer.ChannelState) { + received <- event{evt, chst} + } + numChannels := 5 + transferIDs := make([]datatransfer.TransferID, numChannels) + initiators := make([]peer.ID, numChannels) + responders := make([]peer.ID, numChannels) + baseCids := make([]cid.Cid, numChannels) + + totalSizes := make([]uint64, numChannels) + sents := make([]uint64, numChannels) + receiveds := make([]uint64, numChannels) + + messages := make([]string, numChannels) + vouchers := make([]datatransfer.TypedVoucher, numChannels) + voucherResults := make([]datatransfer.TypedVoucher, numChannels) + sentIndex := make([]int64, numChannels) + receivedIndex := make([]int64, numChannels) + queuedIndex := make([]int64, numChannels) + allSelector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() + selfPeer := testutil.GeneratePeers(1)[0] + + list, err := migrations.GetChannelStateMigrations(selfPeer) + require.NoError(t, err) + vds, up := versionedds.NewVersionedDatastore(ds, list, versioning.VersionKey("2")) + require.NoError(t, up(ctx)) + + initialStatuses := []datatransfer.Status{ + datatransfer.Requested, + datatransfer.InitiatorPaused, + datatransfer.ResponderPaused, + datatransfer.BothPaused, + datatransfer.Ongoing, + } + for i := 0; i < numChannels; i++ { + transferIDs[i] = datatransfer.TransferID(rand.Uint64()) + initiators[i] = testutil.GeneratePeers(1)[0] + responders[i] = testutil.GeneratePeers(1)[0] + baseCids[i] = testutil.GenerateCids(1)[0] + totalSizes[i] = rand.Uint64() + sents[i] = rand.Uint64() + receiveds[i] = rand.Uint64() + messages[i] = string(testutil.RandomBytes(20)) + vouchers[i] = testutil.NewTestTypedVoucher() + voucherResults[i] = testutil.NewTestTypedVoucher() + sentIndex[i] = rand.Int63() + receivedIndex[i] = rand.Int63() + queuedIndex[i] = rand.Int63() + channel := migrations.ChannelStateV2{ + TransferID: transferIDs[i], + Initiator: initiators[i], + Responder: responders[i], + BaseCid: baseCids[i], + Selector: internal.CborGenCompatibleNode{ + Node: allSelector, + }, + Sender: initiators[i], + Recipient: responders[i], + TotalSize: totalSizes[i], + Status: initialStatuses[i], + Sent: sents[i], + Received: receiveds[i], + Message: messages[i], + Vouchers: []internal.EncodedVoucher{ + { + Type: vouchers[i].Type, + Voucher: internal.CborGenCompatibleNode{ + Node: vouchers[i].Voucher, + }, + }, + }, + VoucherResults: []internal.EncodedVoucherResult{ + { + Type: voucherResults[i].Type, + VoucherResult: internal.CborGenCompatibleNode{ + Node: voucherResults[i].Voucher, + }, + }, + }, + SentBlocksTotal: sentIndex[i], + ReceivedBlocksTotal: 
receivedIndex[i], + QueuedBlocksTotal: queuedIndex[i], + SelfPeer: selfPeer, + } + buf := new(bytes.Buffer) + err = channel.MarshalCBOR(buf) + require.NoError(t, err) + err = vds.Put(ctx, datastore.NewKey(datatransfer.ChannelID{ + Initiator: initiators[i], + Responder: responders[i], + ID: transferIDs[i], + }.String()), buf.Bytes()) + require.NoError(t, err) + } + + channelList, err := channels.New(ds, notifier, &fakeEnv{}, selfPeer) + require.NoError(t, err) + err = channelList.Start(ctx) + require.NoError(t, err) + + expectedStatuses := []datatransfer.Status{ + datatransfer.Requested, + datatransfer.Ongoing, + datatransfer.Ongoing, + datatransfer.Ongoing, + datatransfer.Ongoing, + } + + expectedInitiatorPaused := []bool{false, true, false, true, false} + expectedResponderPaused := []bool{false, false, true, true, false} + for i := 0; i < numChannels; i++ { + + channel, err := channelList.GetByID(ctx, datatransfer.ChannelID{ + Initiator: initiators[i], + Responder: responders[i], + ID: transferIDs[i], + }) + require.NoError(t, err) + require.Equal(t, selfPeer, channel.SelfPeer()) + require.Equal(t, transferIDs[i], channel.TransferID()) + require.Equal(t, baseCids[i], channel.BaseCID()) + require.Equal(t, allSelector, channel.Selector()) + require.Equal(t, initiators[i], channel.Sender()) + require.Equal(t, responders[i], channel.Recipient()) + require.Equal(t, totalSizes[i], channel.TotalSize()) + require.Equal(t, sents[i], channel.Sent()) + require.Equal(t, receiveds[i], channel.Received()) + require.Equal(t, messages[i], channel.Message()) + require.Equal(t, vouchers[i], channel.LastVoucher()) + require.Equal(t, voucherResults[i], channel.LastVoucherResult()) + require.Equal(t, expectedStatuses[i], channel.Status()) + require.Equal(t, expectedInitiatorPaused[i], channel.InitiatorPaused()) + require.Equal(t, expectedResponderPaused[i], channel.ResponderPaused()) + require.Equal(t, basicnode.NewInt(sentIndex[i]), channel.SentIndex()) + require.Equal(t, basicnode.NewInt(receivedIndex[i]), channel.ReceivedIndex()) + require.Equal(t, basicnode.NewInt(queuedIndex[i]), channel.QueuedIndex()) + + } +} + type event struct { event datatransfer.Event state datatransfer.ChannelState @@ -422,14 +579,3 @@ func (fe *fakeEnv) ID() peer.ID { func (fe *fakeEnv) CleanupChannel(chid datatransfer.ChannelID) { } - -func decoderByType(identifier datatransfer.TypeIdentifier) (encoding.Decoder, bool) { - if identifier == testutil.NewFakeDTType().Type() { - decoder, err := encoding.NewDecoder(testutil.NewFakeDTType()) - if err != nil { - return nil, false - } - return decoder, true - } - return nil, false -} diff --git a/channels/internal/internalchannel.go b/channels/internal/internalchannel.go index f6cf916b..4209edaa 100644 --- a/channels/internal/internalchannel.go +++ b/channels/internal/internalchannel.go @@ -1,15 +1,57 @@ package internal import ( + "bytes" "fmt" + "io" "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime/codec/dagcbor" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/ipld/go-ipld-prime/schema" peer "github.com/libp2p/go-libp2p-core/peer" cbg "github.com/whyrusleeping/cbor-gen" - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) +type CborGenCompatibleNode struct { + Node datamodel.Node +} + +func (sn CborGenCompatibleNode) IsNull() bool { + return sn.Node == nil || sn.Node == datamodel.Null +} + +// UnmarshalCBOR is for cbor-gen compatibility +func 
(sn *CborGenCompatibleNode) UnmarshalCBOR(r io.Reader) error { + // use cbg.Deferred.UnmarshalCBOR to figure out how much to pull + def := cbg.Deferred{} + if err := def.UnmarshalCBOR(r); err != nil { + return err + } + // convert it to a Node + na := basicnode.Prototype.Any.NewBuilder() + if err := dagcbor.Decode(na, bytes.NewReader(def.Raw)); err != nil { + return err + } + sn.Node = na.Build() + return nil +} + +// MarshalCBOR is for cbor-gen compatibility +func (sn *CborGenCompatibleNode) MarshalCBOR(w io.Writer) error { + node := datamodel.Null + if sn != nil && sn.Node != nil { + node = sn.Node + if tn, ok := node.(schema.TypedNode); ok { + node = tn.Representation() + } + } + return dagcbor.Encode(node, w) +} + //go:generate cbor-gen-for --map-encoding ChannelState EncodedVoucher EncodedVoucherResult // EncodedVoucher is how the voucher is stored on disk @@ -17,7 +59,7 @@ type EncodedVoucher struct { // Vouchers identifier for decoding Type datatransfer.TypeIdentifier // used to verify this channel - Voucher *cbg.Deferred + Voucher CborGenCompatibleNode } // EncodedVoucherResult is how the voucher result is stored on disk @@ -25,7 +67,7 @@ type EncodedVoucherResult struct { // Vouchers identifier for decoding Type datatransfer.TypeIdentifier // used to verify this channel - VoucherResult *cbg.Deferred + VoucherResult CborGenCompatibleNode } // ChannelState is the internal representation on disk for the channel fsm @@ -41,7 +83,7 @@ type ChannelState struct { // base CID for the piece being transferred BaseCid cid.Cid // portion of Piece to return, specified by an IPLD selector - Selector *cbg.Deferred + Selector CborGenCompatibleNode // the party that is sending the data (not who initiated the request) Sender peer.ID // the party that is receiving the data (not who initiated the request) @@ -62,13 +104,23 @@ type ChannelState struct { VoucherResults []EncodedVoucherResult // Number of blocks that have been received, including blocks that are // present in more than one place in the DAG - ReceivedBlocksTotal int64 + ReceivedIndex CborGenCompatibleNode // Number of blocks that have been queued, including blocks that are // present in more than one place in the DAG - QueuedBlocksTotal int64 + QueuedIndex CborGenCompatibleNode // Number of blocks that have been sent, including blocks that are // present in more than one place in the DAG - SentBlocksTotal int64 + SentIndex CborGenCompatibleNode + // DataLimit is the maximum data that can be transferred on this channel before + // revalidation. 0 indicates no limit. + DataLimit uint64 + // RequiresFinalization indicates at the end of the transfer, the channel should + // be left open for a final settlement + RequiresFinalization bool + // ResponderPaused indicates whether the responder is in a paused state + ResponderPaused bool + // InitiatorPaused indicates whether the initiator is in a paused state + InitiatorPaused bool // Stages traces the execution fo a data transfer. // // EXPERIMENTAL; subject to change. 
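(Before the regenerated cbor-gen code below, it may help to see the encoding pattern `CborGenCompatibleNode` relies on in isolation: on marshal an arbitrary `datamodel.Node` is written as DAG-CBOR, and on unmarshal it is rebuilt through the "any" prototype. A self-contained round-trip sketch, assuming the same go-ipld-prime packages imported elsewhere in this diff; it is illustrative, not the package's actual code:)

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/ipld/go-ipld-prime/codec/dagcbor"
	"github.com/ipld/go-ipld-prime/node/basicnode"
)

func main() {
	// An arbitrary IPLD node, e.g. a block index like the new ReceivedIndex field holds.
	original := basicnode.NewInt(42)

	// Marshal side: the node is written out as DAG-CBOR bytes.
	var buf bytes.Buffer
	if err := dagcbor.Encode(original, &buf); err != nil {
		panic(err)
	}

	// Unmarshal side: the bytes are rebuilt into a node via the "any" prototype.
	nb := basicnode.Prototype.Any.NewBuilder()
	if err := dagcbor.Decode(nb, bytes.NewReader(buf.Bytes())); err != nil {
		panic(err)
	}

	v, err := nb.Build().AsInt()
	if err != nil {
		panic(err)
	}
	fmt.Println(v) // 42
}
```

(This is why the selector, vouchers, voucher results, and the new ReceivedIndex/QueuedIndex/SentIndex fields can be stored as plain IPLD nodes rather than `cbg.Deferred` byte blobs.)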
diff --git a/channels/internal/internalchannel_cbor_gen.go b/channels/internal/internalchannel_cbor_gen.go index 5a6cfb19..c7b1e958 100644 --- a/channels/internal/internalchannel_cbor_gen.go +++ b/channels/internal/internalchannel_cbor_gen.go @@ -7,7 +7,7 @@ import ( "io" "sort" - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" cid "github.com/ipfs/go-cid" peer "github.com/libp2p/go-libp2p-core/peer" cbg "github.com/whyrusleeping/cbor-gen" @@ -23,7 +23,7 @@ func (t *ChannelState) MarshalCBOR(w io.Writer) error { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{180}); err != nil { + if _, err := w.Write([]byte{184, 24}); err != nil { return err } @@ -130,7 +130,7 @@ func (t *ChannelState) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("failed to write cid field t.BaseCid: %w", err) } - // t.Selector (typegen.Deferred) (struct) + // t.Selector (internal.CborGenCompatibleNode) (struct) if len("Selector") > cbg.MaxLength { return xerrors.Errorf("Value in field \"Selector\" was too long") } @@ -345,70 +345,116 @@ func (t *ChannelState) MarshalCBOR(w io.Writer) error { } } - // t.ReceivedBlocksTotal (int64) (int64) - if len("ReceivedBlocksTotal") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"ReceivedBlocksTotal\" was too long") + // t.ReceivedIndex (internal.CborGenCompatibleNode) (struct) + if len("ReceivedIndex") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ReceivedIndex\" was too long") } - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ReceivedBlocksTotal"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ReceivedIndex"))); err != nil { return err } - if _, err := io.WriteString(w, string("ReceivedBlocksTotal")); err != nil { + if _, err := io.WriteString(w, string("ReceivedIndex")); err != nil { return err } - if t.ReceivedBlocksTotal >= 0 { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.ReceivedBlocksTotal)); err != nil { - return err - } - } else { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.ReceivedBlocksTotal-1)); err != nil { - return err - } + if err := t.ReceivedIndex.MarshalCBOR(w); err != nil { + return err } - // t.QueuedBlocksTotal (int64) (int64) - if len("QueuedBlocksTotal") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"QueuedBlocksTotal\" was too long") + // t.QueuedIndex (internal.CborGenCompatibleNode) (struct) + if len("QueuedIndex") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"QueuedIndex\" was too long") } - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("QueuedBlocksTotal"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("QueuedIndex"))); err != nil { return err } - if _, err := io.WriteString(w, string("QueuedBlocksTotal")); err != nil { + if _, err := io.WriteString(w, string("QueuedIndex")); err != nil { return err } - if t.QueuedBlocksTotal >= 0 { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.QueuedBlocksTotal)); err != nil { - return err - } - } else { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.QueuedBlocksTotal-1)); err != nil { - return err - } + if err := t.QueuedIndex.MarshalCBOR(w); err != nil { + return err } - // t.SentBlocksTotal (int64) (int64) - if len("SentBlocksTotal") > 
cbg.MaxLength { - return xerrors.Errorf("Value in field \"SentBlocksTotal\" was too long") + // t.SentIndex (internal.CborGenCompatibleNode) (struct) + if len("SentIndex") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"SentIndex\" was too long") } - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("SentBlocksTotal"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("SentIndex"))); err != nil { return err } - if _, err := io.WriteString(w, string("SentBlocksTotal")); err != nil { + if _, err := io.WriteString(w, string("SentIndex")); err != nil { return err } - if t.SentBlocksTotal >= 0 { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.SentBlocksTotal)); err != nil { - return err - } - } else { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.SentBlocksTotal-1)); err != nil { - return err - } + if err := t.SentIndex.MarshalCBOR(w); err != nil { + return err + } + + // t.DataLimit (uint64) (uint64) + if len("DataLimit") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"DataLimit\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("DataLimit"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("DataLimit")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.DataLimit)); err != nil { + return err + } + + // t.RequiresFinalization (bool) (bool) + if len("RequiresFinalization") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"RequiresFinalization\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("RequiresFinalization"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("RequiresFinalization")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.RequiresFinalization); err != nil { + return err + } + + // t.ResponderPaused (bool) (bool) + if len("ResponderPaused") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ResponderPaused\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ResponderPaused"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("ResponderPaused")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.ResponderPaused); err != nil { + return err + } + + // t.InitiatorPaused (bool) (bool) + if len("InitiatorPaused") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"InitiatorPaused\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("InitiatorPaused"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("InitiatorPaused")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.InitiatorPaused); err != nil { + return err } // t.Stages (datatransfer.ChannelStages) (struct) @@ -523,16 +569,15 @@ func (t *ChannelState) UnmarshalCBOR(r io.Reader) error { t.BaseCid = c } - // t.Selector (typegen.Deferred) (struct) + // t.Selector (internal.CborGenCompatibleNode) (struct) case "Selector": { - t.Selector = new(cbg.Deferred) - if err := t.Selector.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("failed to read deferred field: %w", err) + return xerrors.Errorf("unmarshaling t.Selector: %w", err) } + } // t.Sender (peer.ID) (string) case "Sender": @@ -702,83 +747,104 @@ func (t *ChannelState) 
UnmarshalCBOR(r io.Reader) error { t.VoucherResults[i] = v } - // t.ReceivedBlocksTotal (int64) (int64) - case "ReceivedBlocksTotal": + // t.ReceivedIndex (internal.CborGenCompatibleNode) (struct) + case "ReceivedIndex": + { - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - var extraI int64 - if err != nil { - return err - } - switch maj { - case cbg.MajUnsignedInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 positive overflow") - } - case cbg.MajNegativeInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 negative oveflow") - } - extraI = -1 - extraI - default: - return fmt.Errorf("wrong type for int64 field: %d", maj) + + if err := t.ReceivedIndex.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.ReceivedIndex: %w", err) } - t.ReceivedBlocksTotal = int64(extraI) } - // t.QueuedBlocksTotal (int64) (int64) - case "QueuedBlocksTotal": + // t.QueuedIndex (internal.CborGenCompatibleNode) (struct) + case "QueuedIndex": + { - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - var extraI int64 - if err != nil { - return err + + if err := t.QueuedIndex.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.QueuedIndex: %w", err) } - switch maj { - case cbg.MajUnsignedInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 positive overflow") - } - case cbg.MajNegativeInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 negative oveflow") - } - extraI = -1 - extraI - default: - return fmt.Errorf("wrong type for int64 field: %d", maj) + + } + // t.SentIndex (internal.CborGenCompatibleNode) (struct) + case "SentIndex": + + { + + if err := t.SentIndex.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.SentIndex: %w", err) } - t.QueuedBlocksTotal = int64(extraI) } - // t.SentBlocksTotal (int64) (int64) - case "SentBlocksTotal": + // t.DataLimit (uint64) (uint64) + case "DataLimit": + { - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - var extraI int64 + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } - switch maj { - case cbg.MajUnsignedInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 positive overflow") - } - case cbg.MajNegativeInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 negative oveflow") - } - extraI = -1 - extraI - default: - return fmt.Errorf("wrong type for int64 field: %d", maj) + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") } + t.DataLimit = uint64(extra) + + } + // t.RequiresFinalization (bool) (bool) + case "RequiresFinalization": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.RequiresFinalization = false + case 21: + t.RequiresFinalization = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.ResponderPaused (bool) (bool) + case "ResponderPaused": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.ResponderPaused = false + case 21: + t.ResponderPaused = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.InitiatorPaused (bool) (bool) + case "InitiatorPaused": - 
t.SentBlocksTotal = int64(extraI) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.InitiatorPaused = false + case 21: + t.InitiatorPaused = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) } // t.Stages (datatransfer.ChannelStages) (struct) case "Stages": @@ -843,7 +909,7 @@ func (t *EncodedVoucher) MarshalCBOR(w io.Writer) error { return err } - // t.Voucher (typegen.Deferred) (struct) + // t.Voucher (internal.CborGenCompatibleNode) (struct) if len("Voucher") > cbg.MaxLength { return xerrors.Errorf("Value in field \"Voucher\" was too long") } @@ -905,16 +971,15 @@ func (t *EncodedVoucher) UnmarshalCBOR(r io.Reader) error { t.Type = datatransfer.TypeIdentifier(sval) } - // t.Voucher (typegen.Deferred) (struct) + // t.Voucher (internal.CborGenCompatibleNode) (struct) case "Voucher": { - t.Voucher = new(cbg.Deferred) - if err := t.Voucher.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("failed to read deferred field: %w", err) + return xerrors.Errorf("unmarshaling t.Voucher: %w", err) } + } default: @@ -959,7 +1024,7 @@ func (t *EncodedVoucherResult) MarshalCBOR(w io.Writer) error { return err } - // t.VoucherResult (typegen.Deferred) (struct) + // t.VoucherResult (internal.CborGenCompatibleNode) (struct) if len("VoucherResult") > cbg.MaxLength { return xerrors.Errorf("Value in field \"VoucherResult\" was too long") } @@ -1021,16 +1086,15 @@ func (t *EncodedVoucherResult) UnmarshalCBOR(r io.Reader) error { t.Type = datatransfer.TypeIdentifier(sval) } - // t.VoucherResult (typegen.Deferred) (struct) + // t.VoucherResult (internal.CborGenCompatibleNode) (struct) case "VoucherResult": { - t.VoucherResult = new(cbg.Deferred) - if err := t.VoucherResult.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("failed to read deferred field: %w", err) + return xerrors.Errorf("unmarshaling t.VoucherResult: %w", err) } + } default: diff --git a/channels/internal/migrations/migrations.go b/channels/internal/migrations/migrations.go index b6a1ed6a..210dfc00 100644 --- a/channels/internal/migrations/migrations.go +++ b/channels/internal/migrations/migrations.go @@ -1,13 +1,119 @@ package migrations import ( + "github.com/ipfs/go-cid" + basicnode "github.com/ipld/go-ipld-prime/node/basic" peer "github.com/libp2p/go-libp2p-core/peer" versioning "github.com/filecoin-project/go-ds-versioning/pkg" "github.com/filecoin-project/go-ds-versioning/pkg/versioned" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/channels/internal" ) +//go:generate cbor-gen-for --map-encoding ChannelStateV2 + +// ChannelStateV2 is the internal representation on disk for the channel fsm, version 2 +type ChannelStateV2 struct { + // PeerId of the manager peer + SelfPeer peer.ID + // an identifier for this channel shared by request and responder, set by requester through protocol + TransferID datatransfer.TransferID + // Initiator is the person who intiated this datatransfer request + Initiator peer.ID + // Responder is the person who is responding to this datatransfer request + Responder peer.ID + // base CID for the piece being transferred + BaseCid cid.Cid + // portion of Piece to return, specified by an IPLD selector + Selector internal.CborGenCompatibleNode + // the party that is sending the data (not who initiated the request) + Sender peer.ID + 
// the party that is receiving the data (not who initiated the request) + Recipient peer.ID + // expected amount of data to be transferred + TotalSize uint64 + // current status of this deal + Status datatransfer.Status + // total bytes read from this node and queued for sending (0 if receiver) + Queued uint64 + // total bytes sent from this node (0 if receiver) + Sent uint64 + // total bytes received by this node (0 if sender) + Received uint64 + // more informative status on a channel + Message string + Vouchers []internal.EncodedVoucher + VoucherResults []internal.EncodedVoucherResult + // Number of blocks that have been received, including blocks that are + // present in more than one place in the DAG + ReceivedBlocksTotal int64 + // Number of blocks that have been queued, including blocks that are + // present in more than one place in the DAG + QueuedBlocksTotal int64 + // Number of blocks that have been sent, including blocks that are + // present in more than one place in the DAG + SentBlocksTotal int64 + // DataLimit is the maximum data that can be transferred on this channel before + // revalidation. 0 indicates no limit. + DataLimit uint64 + // RequiresFinalization indicates at the end of the transfer, the channel should + // be left open for a final settlement + RequiresFinalization bool + // Stages traces the execution fo a data transfer. + // + // EXPERIMENTAL; subject to change. + Stages *datatransfer.ChannelStages +} + +func NoOpChannelState0To2(oldChannelState *ChannelStateV2) (*ChannelStateV2, error) { + return oldChannelState, nil +} + +func MigrateChannelState2To3(oldChannelState *ChannelStateV2) (*internal.ChannelState, error) { + receivedIndex := basicnode.NewInt(oldChannelState.ReceivedBlocksTotal) + sentIndex := basicnode.NewInt(oldChannelState.SentBlocksTotal) + queuedIndex := basicnode.NewInt(oldChannelState.QueuedBlocksTotal) + + responderPaused := oldChannelState.Status == datatransfer.ResponderPaused || oldChannelState.Status == datatransfer.BothPaused + initiatorPaused := oldChannelState.Status == datatransfer.InitiatorPaused || oldChannelState.Status == datatransfer.BothPaused + newStatus := oldChannelState.Status + if newStatus == datatransfer.ResponderPaused || newStatus == datatransfer.InitiatorPaused || newStatus == datatransfer.BothPaused { + newStatus = datatransfer.Ongoing + } + return &internal.ChannelState{ + SelfPeer: oldChannelState.SelfPeer, + TransferID: oldChannelState.TransferID, + Initiator: oldChannelState.Initiator, + Responder: oldChannelState.Responder, + BaseCid: oldChannelState.BaseCid, + Selector: oldChannelState.Selector, + Sender: oldChannelState.Sender, + Recipient: oldChannelState.Recipient, + TotalSize: oldChannelState.TotalSize, + Status: newStatus, + Queued: oldChannelState.Queued, + Sent: oldChannelState.Sent, + Received: oldChannelState.Received, + Message: oldChannelState.Message, + Vouchers: oldChannelState.Vouchers, + VoucherResults: oldChannelState.VoucherResults, + ReceivedIndex: internal.CborGenCompatibleNode{Node: receivedIndex}, + SentIndex: internal.CborGenCompatibleNode{Node: sentIndex}, + QueuedIndex: internal.CborGenCompatibleNode{Node: queuedIndex}, + DataLimit: oldChannelState.DataLimit, + RequiresFinalization: oldChannelState.RequiresFinalization, + InitiatorPaused: initiatorPaused, + ResponderPaused: responderPaused, + Stages: oldChannelState.Stages, + }, nil +} + // GetChannelStateMigrations returns a migration list for the channel states func GetChannelStateMigrations(selfPeer peer.ID) 
(versioning.VersionedMigrationList, error) { - return versioned.BuilderList{}.Build() + return versioned.BuilderList{ + versioned.NewVersionedBuilder(NoOpChannelState0To2, "2"), + versioned.NewVersionedBuilder(MigrateChannelState2To3, "3").OldVersion("2"), + }.Build() } diff --git a/channels/internal/migrations/migrations_cbor_gen.go b/channels/internal/migrations/migrations_cbor_gen.go new file mode 100644 index 00000000..c4ca74fd --- /dev/null +++ b/channels/internal/migrations/migrations_cbor_gen.go @@ -0,0 +1,876 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +package migrations + +import ( + "fmt" + "io" + "sort" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + internal "github.com/filecoin-project/go-data-transfer/v2/channels/internal" + cid "github.com/ipfs/go-cid" + peer "github.com/libp2p/go-libp2p-core/peer" + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +var _ = xerrors.Errorf +var _ = cid.Undef +var _ = sort.Sort + +func (t *ChannelStateV2) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{182}); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.SelfPeer (peer.ID) (string) + if len("SelfPeer") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"SelfPeer\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("SelfPeer"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("SelfPeer")); err != nil { + return err + } + + if len(t.SelfPeer) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.SelfPeer was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.SelfPeer))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.SelfPeer)); err != nil { + return err + } + + // t.TransferID (datatransfer.TransferID) (uint64) + if len("TransferID") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"TransferID\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("TransferID"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("TransferID")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.TransferID)); err != nil { + return err + } + + // t.Initiator (peer.ID) (string) + if len("Initiator") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Initiator\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Initiator"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Initiator")); err != nil { + return err + } + + if len(t.Initiator) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Initiator was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Initiator))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Initiator)); err != nil { + return err + } + + // t.Responder (peer.ID) (string) + if len("Responder") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Responder\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Responder"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Responder")); err != nil { + return err + } + + if len(t.Responder) > cbg.MaxLength { + 
return xerrors.Errorf("Value in field t.Responder was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Responder))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Responder)); err != nil { + return err + } + + // t.BaseCid (cid.Cid) (struct) + if len("BaseCid") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"BaseCid\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("BaseCid"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("BaseCid")); err != nil { + return err + } + + if err := cbg.WriteCidBuf(scratch, w, t.BaseCid); err != nil { + return xerrors.Errorf("failed to write cid field t.BaseCid: %w", err) + } + + // t.Selector (internal.CborGenCompatibleNode) (struct) + if len("Selector") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Selector\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Selector"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Selector")); err != nil { + return err + } + + if err := t.Selector.MarshalCBOR(w); err != nil { + return err + } + + // t.Sender (peer.ID) (string) + if len("Sender") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Sender\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Sender"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Sender")); err != nil { + return err + } + + if len(t.Sender) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Sender was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Sender))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Sender)); err != nil { + return err + } + + // t.Recipient (peer.ID) (string) + if len("Recipient") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Recipient\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Recipient"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Recipient")); err != nil { + return err + } + + if len(t.Recipient) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Recipient was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Recipient))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Recipient)); err != nil { + return err + } + + // t.TotalSize (uint64) (uint64) + if len("TotalSize") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"TotalSize\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("TotalSize"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("TotalSize")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.TotalSize)); err != nil { + return err + } + + // t.Status (datatransfer.Status) (uint64) + if len("Status") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Status\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Status"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Status")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Status)); err != 
nil { + return err + } + + // t.Queued (uint64) (uint64) + if len("Queued") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Queued\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Queued"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Queued")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Queued)); err != nil { + return err + } + + // t.Sent (uint64) (uint64) + if len("Sent") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Sent\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Sent"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Sent")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Sent)); err != nil { + return err + } + + // t.Received (uint64) (uint64) + if len("Received") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Received\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Received"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Received")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Received)); err != nil { + return err + } + + // t.Message (string) (string) + if len("Message") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Message\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Message"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Message")); err != nil { + return err + } + + if len(t.Message) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Message was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Message))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Message)); err != nil { + return err + } + + // t.Vouchers ([]internal.EncodedVoucher) (slice) + if len("Vouchers") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Vouchers\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Vouchers"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Vouchers")); err != nil { + return err + } + + if len(t.Vouchers) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field t.Vouchers was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.Vouchers))); err != nil { + return err + } + for _, v := range t.Vouchers { + if err := v.MarshalCBOR(w); err != nil { + return err + } + } + + // t.VoucherResults ([]internal.EncodedVoucherResult) (slice) + if len("VoucherResults") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"VoucherResults\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("VoucherResults"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("VoucherResults")); err != nil { + return err + } + + if len(t.VoucherResults) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field t.VoucherResults was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.VoucherResults))); err != nil { + return err + } + for _, v := range t.VoucherResults { + if err 
:= v.MarshalCBOR(w); err != nil { + return err + } + } + + // t.ReceivedBlocksTotal (int64) (int64) + if len("ReceivedBlocksTotal") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ReceivedBlocksTotal\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ReceivedBlocksTotal"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("ReceivedBlocksTotal")); err != nil { + return err + } + + if t.ReceivedBlocksTotal >= 0 { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.ReceivedBlocksTotal)); err != nil { + return err + } + } else { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.ReceivedBlocksTotal-1)); err != nil { + return err + } + } + + // t.QueuedBlocksTotal (int64) (int64) + if len("QueuedBlocksTotal") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"QueuedBlocksTotal\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("QueuedBlocksTotal"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("QueuedBlocksTotal")); err != nil { + return err + } + + if t.QueuedBlocksTotal >= 0 { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.QueuedBlocksTotal)); err != nil { + return err + } + } else { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.QueuedBlocksTotal-1)); err != nil { + return err + } + } + + // t.SentBlocksTotal (int64) (int64) + if len("SentBlocksTotal") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"SentBlocksTotal\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("SentBlocksTotal"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("SentBlocksTotal")); err != nil { + return err + } + + if t.SentBlocksTotal >= 0 { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.SentBlocksTotal)); err != nil { + return err + } + } else { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.SentBlocksTotal-1)); err != nil { + return err + } + } + + // t.DataLimit (uint64) (uint64) + if len("DataLimit") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"DataLimit\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("DataLimit"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("DataLimit")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.DataLimit)); err != nil { + return err + } + + // t.RequiresFinalization (bool) (bool) + if len("RequiresFinalization") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"RequiresFinalization\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("RequiresFinalization"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("RequiresFinalization")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.RequiresFinalization); err != nil { + return err + } + + // t.Stages (datatransfer.ChannelStages) (struct) + if len("Stages") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Stages\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Stages"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Stages")); err != 
nil { + return err + } + + if err := t.Stages.MarshalCBOR(w); err != nil { + return err + } + return nil +} + +func (t *ChannelStateV2) UnmarshalCBOR(r io.Reader) error { + *t = ChannelStateV2{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("ChannelStateV2: map struct too large (%d)", extra) + } + + var name string + n := extra + + for i := uint64(0); i < n; i++ { + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + name = string(sval) + } + + switch name { + // t.SelfPeer (peer.ID) (string) + case "SelfPeer": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.SelfPeer = peer.ID(sval) + } + // t.TransferID (datatransfer.TransferID) (uint64) + case "TransferID": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.TransferID = datatransfer.TransferID(extra) + + } + // t.Initiator (peer.ID) (string) + case "Initiator": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Initiator = peer.ID(sval) + } + // t.Responder (peer.ID) (string) + case "Responder": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Responder = peer.ID(sval) + } + // t.BaseCid (cid.Cid) (struct) + case "BaseCid": + + { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.BaseCid: %w", err) + } + + t.BaseCid = c + + } + // t.Selector (internal.CborGenCompatibleNode) (struct) + case "Selector": + + { + + if err := t.Selector.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Selector: %w", err) + } + + } + // t.Sender (peer.ID) (string) + case "Sender": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Sender = peer.ID(sval) + } + // t.Recipient (peer.ID) (string) + case "Recipient": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Recipient = peer.ID(sval) + } + // t.TotalSize (uint64) (uint64) + case "TotalSize": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.TotalSize = uint64(extra) + + } + // t.Status (datatransfer.Status) (uint64) + case "Status": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Status = datatransfer.Status(extra) + + } + // t.Queued (uint64) (uint64) + case "Queued": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Queued = uint64(extra) + + } + // t.Sent (uint64) (uint64) + case "Sent": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Sent = uint64(extra) + + } + // t.Received (uint64) (uint64) + case "Received": + + { + + maj, extra, err = 
cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Received = uint64(extra) + + } + // t.Message (string) (string) + case "Message": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Message = string(sval) + } + // t.Vouchers ([]internal.EncodedVoucher) (slice) + case "Vouchers": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if extra > cbg.MaxLength { + return fmt.Errorf("t.Vouchers: array too large (%d)", extra) + } + + if maj != cbg.MajArray { + return fmt.Errorf("expected cbor array") + } + + if extra > 0 { + t.Vouchers = make([]internal.EncodedVoucher, extra) + } + + for i := 0; i < int(extra); i++ { + + var v internal.EncodedVoucher + if err := v.UnmarshalCBOR(br); err != nil { + return err + } + + t.Vouchers[i] = v + } + + // t.VoucherResults ([]internal.EncodedVoucherResult) (slice) + case "VoucherResults": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if extra > cbg.MaxLength { + return fmt.Errorf("t.VoucherResults: array too large (%d)", extra) + } + + if maj != cbg.MajArray { + return fmt.Errorf("expected cbor array") + } + + if extra > 0 { + t.VoucherResults = make([]internal.EncodedVoucherResult, extra) + } + + for i := 0; i < int(extra); i++ { + + var v internal.EncodedVoucherResult + if err := v.UnmarshalCBOR(br); err != nil { + return err + } + + t.VoucherResults[i] = v + } + + // t.ReceivedBlocksTotal (int64) (int64) + case "ReceivedBlocksTotal": + { + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative oveflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.ReceivedBlocksTotal = int64(extraI) + } + // t.QueuedBlocksTotal (int64) (int64) + case "QueuedBlocksTotal": + { + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative oveflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.QueuedBlocksTotal = int64(extraI) + } + // t.SentBlocksTotal (int64) (int64) + case "SentBlocksTotal": + { + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative oveflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.SentBlocksTotal = int64(extraI) + } + // t.DataLimit (uint64) (uint64) + case "DataLimit": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.DataLimit 
= uint64(extra) + + } + // t.RequiresFinalization (bool) (bool) + case "RequiresFinalization": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.RequiresFinalization = false + case 21: + t.RequiresFinalization = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.Stages (datatransfer.ChannelStages) (struct) + case "Stages": + + { + + b, err := br.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { + return err + } + t.Stages = new(datatransfer.ChannelStages) + if err := t.Stages.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Stages pointer: %w", err) + } + } + + } + + default: + // Field doesn't exist on this type, so ignore it + cbg.ScanForLinks(r, func(cid.Cid) {}) + } + } + + return nil +} diff --git a/encoding/encoding.go b/encoding/encoding.go deleted file mode 100644 index dec7abcd..00000000 --- a/encoding/encoding.go +++ /dev/null @@ -1,171 +0,0 @@ -package encoding - -import ( - "bytes" - "reflect" - - cbor "github.com/ipfs/go-ipld-cbor" - "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/codec/dagcbor" - "github.com/ipld/go-ipld-prime/datamodel" - "github.com/ipld/go-ipld-prime/node/basicnode" - "github.com/ipld/go-ipld-prime/schema" - cborgen "github.com/whyrusleeping/cbor-gen" - "golang.org/x/xerrors" -) - -// Encodable is an object that can be written to CBOR and decoded back -type Encodable interface{} - -// Encode encodes an encodable to CBOR, using the best available path for -// writing to CBOR -func Encode(value Encodable) ([]byte, error) { - if cbgEncodable, ok := value.(cborgen.CBORMarshaler); ok { - buf := new(bytes.Buffer) - err := cbgEncodable.MarshalCBOR(buf) - if err != nil { - return nil, err - } - return buf.Bytes(), nil - } - if ipldEncodable, ok := value.(datamodel.Node); ok { - if tn, ok := ipldEncodable.(schema.TypedNode); ok { - ipldEncodable = tn.Representation() - } - buf := &bytes.Buffer{} - err := dagcbor.Encode(ipldEncodable, buf) - if err != nil { - return nil, err - } - return buf.Bytes(), nil - } - return cbor.DumpObject(value) -} - -func EncodeToNode(encodable Encodable) (datamodel.Node, error) { - byts, err := Encode(encodable) - if err != nil { - return nil, err - } - na := basicnode.Prototype.Any.NewBuilder() - if err := dagcbor.Decode(na, bytes.NewReader(byts)); err != nil { - return nil, err - } - return na.Build(), nil -} - -// Decoder is CBOR decoder for a given encodable type -type Decoder interface { - DecodeFromCbor([]byte) (Encodable, error) - DecodeFromNode(datamodel.Node) (Encodable, error) -} - -// NewDecoder creates a new Decoder that will decode into new instances of the given -// object type. 
It will use the decoding that is optimal for that type -// It returns error if it's not possible to setup a decoder for this type -func NewDecoder(decodeType Encodable) (Decoder, error) { - // check if type is datamodel.Node, if so, just use style - if ipldDecodable, ok := decodeType.(datamodel.Node); ok { - return &ipldDecoder{ipldDecodable.Prototype()}, nil - } - // check if type is a pointer, as we need that to make new copies - // for cborgen types & regular IPLD types - decodeReflectType := reflect.TypeOf(decodeType) - if decodeReflectType.Kind() != reflect.Ptr { - return nil, xerrors.New("type must be a pointer") - } - // check if type is a cbor-gen type - if _, ok := decodeType.(cborgen.CBORUnmarshaler); ok { - return &cbgDecoder{decodeReflectType}, nil - } - // type does is neither ipld-prime nor cbor-gen, so we need to see if it - // can rountrip with oldschool ipld-format - encoded, err := cbor.DumpObject(decodeType) - if err != nil { - return nil, xerrors.New("Object type did not encode") - } - newDecodable := reflect.New(decodeReflectType.Elem()).Interface() - if err := cbor.DecodeInto(encoded, newDecodable); err != nil { - return nil, xerrors.New("Object type did not decode") - } - return &defaultDecoder{decodeReflectType}, nil -} - -type ipldDecoder struct { - style ipld.NodePrototype -} - -func (decoder *ipldDecoder) DecodeFromCbor(encoded []byte) (Encodable, error) { - builder := decoder.style.NewBuilder() - buf := bytes.NewReader(encoded) - err := dagcbor.Decode(builder, buf) - if err != nil { - return nil, err - } - return builder.Build(), nil -} - -func (decoder *ipldDecoder) DecodeFromNode(node datamodel.Node) (Encodable, error) { - builder := decoder.style.NewBuilder() - if err := builder.AssignNode(node); err != nil { - return nil, err - } - return builder.Build(), nil -} - -type cbgDecoder struct { - cbgType reflect.Type -} - -func (decoder *cbgDecoder) DecodeFromCbor(encoded []byte) (Encodable, error) { - decodedValue := reflect.New(decoder.cbgType.Elem()) - decoded, ok := decodedValue.Interface().(cborgen.CBORUnmarshaler) - if !ok || reflect.ValueOf(decoded).IsNil() { - return nil, xerrors.New("problem instantiating decoded value") - } - buf := bytes.NewReader(encoded) - err := decoded.UnmarshalCBOR(buf) - if err != nil { - return nil, err - } - return decoded, nil -} - -func (decoder *cbgDecoder) DecodeFromNode(node datamodel.Node) (Encodable, error) { - if tn, ok := node.(schema.TypedNode); ok { - node = tn.Representation() - } - buf := &bytes.Buffer{} - if err := dagcbor.Encode(node, buf); err != nil { - return nil, err - } - return decoder.DecodeFromCbor(buf.Bytes()) -} - -type defaultDecoder struct { - ptrType reflect.Type -} - -func (decoder *defaultDecoder) DecodeFromCbor(encoded []byte) (Encodable, error) { - decodedValue := reflect.New(decoder.ptrType.Elem()) - decoded, ok := decodedValue.Interface().(Encodable) - if !ok || reflect.ValueOf(decoded).IsNil() { - return nil, xerrors.New("problem instantiating decoded value") - } - err := cbor.DecodeInto(encoded, decoded) - if err != nil { - return nil, err - } - return decoded, nil -} - -func (decoder *defaultDecoder) DecodeFromNode(node datamodel.Node) (Encodable, error) { - if tn, ok := node.(schema.TypedNode); ok { - node = tn.Representation() - } - buf := &bytes.Buffer{} - if err := dagcbor.Encode(node, buf); err != nil { - return nil, err - } - return decoder.DecodeFromCbor(buf.Bytes()) -} diff --git a/encoding/encoding_test.go b/encoding/encoding_test.go deleted file mode 100644 index 43a66f5d..00000000 
--- a/encoding/encoding_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package encoding_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/encoding/testdata" -) - -func TestRoundTrip(t *testing.T) { - testCases := map[string]struct { - val encoding.Encodable - }{ - "can encode/decode IPLD prime types": { - val: testdata.Prime, - }, - "can encode/decode cbor-gen types": { - val: testdata.Cbg, - }, - "can encode/decode old ipld format types": { - val: testdata.Standard, - }, - } - for testCase, data := range testCases { - t.Run(testCase, func(t *testing.T) { - encoded, err := encoding.Encode(data.val) - require.NoError(t, err) - decoder, err := encoding.NewDecoder(data.val) - require.NoError(t, err) - decoded, err := decoder.DecodeFromCbor(encoded) - require.NoError(t, err) - require.Equal(t, data.val, decoded) - }) - } -} diff --git a/encoding/testdata/testdata.go b/encoding/testdata/testdata.go deleted file mode 100644 index 5bed37ba..00000000 --- a/encoding/testdata/testdata.go +++ /dev/null @@ -1,37 +0,0 @@ -package testdata - -import ( - cbor "github.com/ipfs/go-ipld-cbor" - "github.com/ipld/go-ipld-prime/fluent" - basicnode "github.com/ipld/go-ipld-prime/node/basic" -) - -// Prime = an instance of an ipld prime piece of data -var Prime = fluent.MustBuildMap(basicnode.Prototype.Map, 2, func(na fluent.MapAssembler) { - nva := na.AssembleEntry("X") - nva.AssignInt(100) - nva = na.AssembleEntry("Y") - nva.AssignString("appleSauce") -}) - -type standardType struct { - X int - Y string -} - -func init() { - cbor.RegisterCborType(standardType{}) -} - -// Standard = an instance that is neither ipld prime nor cbor -var Standard *standardType = &standardType{X: 100, Y: "appleSauce"} - -//go:generate cbor-gen-for cbgType - -type cbgType struct { - X uint64 - Y string -} - -// Cbg = an instance of a cbor-gen type -var Cbg *cbgType = &cbgType{X: 100, Y: "appleSauce"} diff --git a/encoding/testdata/testdata_cbor_gen.go b/encoding/testdata/testdata_cbor_gen.go deleted file mode 100644 index 67c6c688..00000000 --- a/encoding/testdata/testdata_cbor_gen.go +++ /dev/null @@ -1,84 +0,0 @@ -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. 
- -package testdata - -import ( - "fmt" - "io" - - cbg "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" -) - -var _ = xerrors.Errorf - -func (t *cbgType) MarshalCBOR(w io.Writer) error { - if t == nil { - _, err := w.Write(cbg.CborNull) - return err - } - if _, err := w.Write([]byte{130}); err != nil { - return err - } - - // t.X (uint64) (uint64) - - if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, uint64(t.X))); err != nil { - return err - } - - // t.Y (string) (string) - if len(t.Y) > cbg.MaxLength { - return xerrors.Errorf("Value in field t.Y was too long") - } - - if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.Y)))); err != nil { - return err - } - if _, err := w.Write([]byte(t.Y)); err != nil { - return err - } - return nil -} - -func (t *cbgType) UnmarshalCBOR(r io.Reader) error { - br := cbg.GetPeeker(r) - - maj, extra, err := cbg.CborReadHeader(br) - if err != nil { - return err - } - if maj != cbg.MajArray { - return fmt.Errorf("cbor input should be of type array") - } - - if extra != 2 { - return fmt.Errorf("cbor input had wrong number of fields") - } - - // t.X (uint64) (uint64) - - { - - maj, extra, err = cbg.CborReadHeader(br) - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") - } - t.X = uint64(extra) - - } - // t.Y (string) (string) - - { - sval, err := cbg.ReadString(br) - if err != nil { - return err - } - - t.Y = string(sval) - } - return nil -} diff --git a/errors.go b/errors.go index 0e9903f6..592444c0 100644 --- a/errors.go +++ b/errors.go @@ -17,14 +17,6 @@ const ErrHandlerNotSet = errorType("event handler has not been set") // ErrChannelNotFound means the channel this command was issued for does not exist const ErrChannelNotFound = errorType("channel not found") -// ErrPause is a special error that the DataReceived / DataSent hooks can -// use to pause the channel -const ErrPause = errorType("pause channel") - -// ErrResume is a special error that the RequestReceived / ResponseReceived hooks can -// use to resume the channel -const ErrResume = errorType("resume channel") - // ErrRejected indicates a request was not accepted const ErrRejected = errorType("response rejected") diff --git a/events.go b/events.go index 664579c4..1cc42be8 100644 --- a/events.go +++ b/events.go @@ -61,7 +61,7 @@ const ( // initiator BeginFinalizing - // Disconnected emits when we are not able to connect to the other party + // DEPRECATED in favor of SendMessageError Disconnected // Complete is emitted when a data transfer is complete @@ -91,7 +91,7 @@ const ( // data has been received. 
 	DataReceivedProgress
 
-	// Deprecated in favour of RequestCancelled
+	// DEPRECATED in favour of RequestCancelled
 	RequestTimedOut
 
 	// SendDataError indicates that the transport layer had an error trying
@@ -102,7 +102,7 @@ const (
 	// receiving data from the remote peer
 	ReceiveDataError
 
-	// TransferRequestQueued indicates that a new data transfer request has been queued in the transport layer
+	// DEPRECATED in favor of TransferInitiated
 	TransferRequestQueued
 
 	// RequestCancelled indicates that a transport layer request was cancelled by the request opener
@@ -110,6 +110,25 @@ const (
 
 	// Opened is fired when a request for data is sent from this node to a peer
 	Opened
+
+	// SetDataLimit is fired when a responder sets a limit for data it will allow
+	// before pausing the request
+	SetDataLimit
+
+	// SetRequiresFinalization is fired when a responder sets a limit for data it will allow
+	// before pausing the request
+	SetRequiresFinalization
+
+	// DataLimitExceeded is fired when a request exceeds it's data limit. It has the effect of
+	// pausing the responder, but is distinct from PauseResponder to indicate why the pause
+	// happened
+	DataLimitExceeded
+
+	// TransferInitiated indicates the transport has begun transferring data
+	TransferInitiated
+
+	// SendMessageError indicates an error sending a data transfer message
+	SendMessageError
 )
 
 // Events are human readable names for data transfer events
@@ -144,6 +163,12 @@ var Events = map[EventCode]string{
 	ReceiveDataError:        "ReceiveDataError",
 	TransferRequestQueued:   "TransferRequestQueued",
 	RequestCancelled:        "RequestCancelled",
+	Opened:                  "Opened",
+	SetDataLimit:            "SetDataLimit",
+	SetRequiresFinalization: "SetRequiresFinalization",
+	DataLimitExceeded:       "DataLimitExceeded",
+	TransferInitiated:       "TransferInitiated",
+	SendMessageError:        "SendMessageError",
 }
 
 // Event is a struct containing information about a data transfer event
diff --git a/go.mod b/go.mod
index 38ea9cc6..608fa39f 100644
--- a/go.mod
+++ b/go.mod
@@ -1,4 +1,4 @@
-module github.com/filecoin-project/go-data-transfer
+module github.com/filecoin-project/go-data-transfer/v2
 
 go 1.17
 
@@ -10,22 +10,21 @@ require (
 	github.com/hannahhoward/go-pubsub v0.0.0-20200423002714-8d62886cc36e
 	github.com/ipfs/go-block-format v0.0.3
 	github.com/ipfs/go-blockservice v0.2.1
-	github.com/ipfs/go-cid v0.1.0
+	github.com/ipfs/go-cid v0.2.0
 	github.com/ipfs/go-datastore v0.5.1
 	github.com/ipfs/go-ds-badger v0.3.0
-	github.com/ipfs/go-graphsync v0.13.1
+	github.com/ipfs/go-graphsync v0.13.3-0.20220625074430-a95496cf1534
 	github.com/ipfs/go-ipfs-blockstore v1.1.2
 	github.com/ipfs/go-ipfs-blocksutil v0.0.1
 	github.com/ipfs/go-ipfs-chunker v0.0.5
 	github.com/ipfs/go-ipfs-delay v0.0.1
 	github.com/ipfs/go-ipfs-exchange-offline v0.1.1
 	github.com/ipfs/go-ipfs-files v0.0.8
-	github.com/ipfs/go-ipld-cbor v0.0.5
 	github.com/ipfs/go-ipld-format v0.2.0
 	github.com/ipfs/go-log/v2 v2.5.1
 	github.com/ipfs/go-merkledag v0.5.1
 	github.com/ipfs/go-unixfs v0.3.1
-	github.com/ipld/go-ipld-prime v0.16.0
+	github.com/ipld/go-ipld-prime v0.17.1-0.20220624062450-534ccf82237d
 	github.com/jbenet/go-random v0.0.0-20190219211222-123a90aedc0c
 	github.com/jpillora/backoff v1.0.0
 	github.com/libp2p/go-libp2p v0.19.4
@@ -71,6 +70,7 @@ require (
 	github.com/ipfs/go-ipfs-posinfo v0.0.1 // indirect
 	github.com/ipfs/go-ipfs-pq v0.0.2 // indirect
 	github.com/ipfs/go-ipfs-util v0.0.2 // indirect
+	github.com/ipfs/go-ipld-cbor v0.0.5 // indirect
 	github.com/ipfs/go-ipld-legacy v0.1.0 // indirect
 	github.com/ipfs/go-log v1.0.5 // indirect
github.com/ipfs/go-metrics-interface v0.0.1 // indirect @@ -105,7 +105,7 @@ require ( github.com/multiformats/go-multiaddr-dns v0.3.1 // indirect github.com/multiformats/go-multiaddr-fmt v0.1.0 // indirect github.com/multiformats/go-multibase v0.0.3 // indirect - github.com/multiformats/go-multicodec v0.4.1 // indirect + github.com/multiformats/go-multicodec v0.5.0 // indirect github.com/multiformats/go-multihash v0.1.0 // indirect github.com/multiformats/go-multistream v0.3.0 // indirect github.com/multiformats/go-varint v0.0.6 // indirect @@ -119,7 +119,6 @@ require ( github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/urfave/cli/v2 v2.0.0 // indirect github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f // indirect - go.uber.org/goleak v1.1.12 // indirect go.uber.org/multierr v1.8.0 // indirect go.uber.org/zap v1.21.0 // indirect golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect diff --git a/go.sum b/go.sum index c78f04f5..7da44a75 100644 --- a/go.sum +++ b/go.sum @@ -126,7 +126,6 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI= -github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= @@ -196,7 +195,6 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ= github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= @@ -219,8 +217,9 @@ github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVB github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= github.com/frankban/quicktest v1.14.0/go.mod h1:NeW+ay9A/U67EYXNFA1nPE8e/tnQv/09mUdL/ijj8og= -github.com/frankban/quicktest v1.14.2 h1:SPb1KFFmM+ybpEjPUhCCkZOM5xlovT5UbrMvWnXyBns= github.com/frankban/quicktest v1.14.2/go.mod h1:mgiwOwqx65TmIk1wJ6Q7wvnVMocbUorkibMOrVTHZps= +github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= +github.com/frankban/quicktest v1.14.3/go.mod 
h1:mgiwOwqx65TmIk1wJ6Q7wvnVMocbUorkibMOrVTHZps= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI= @@ -312,8 +311,9 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -389,7 +389,6 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= github.com/huin/goupnp v1.0.0/go.mod h1:n9v9KO1tAxYH82qOn+UTIFQDmx5n1Zxd/ClZDMX7Bnc= -github.com/huin/goupnp v1.0.2/go.mod h1:0dxJBVBHqTMjIUMkESDTNgOOx/Mw5wYIfyFmdzSamkM= github.com/huin/goupnp v1.0.3 h1:N8No57ls+MnjlB+JPiCVSOyy/ot7MJTqlo7rn+NYSqQ= github.com/huin/goupnp v1.0.3/go.mod h1:ZxNlw5WqJj6wSsRK5+YfflQGXYfccj5VgQsMNixHM7Y= github.com/huin/goutil v0.0.0-20170803182201-1ca381bf3150/go.mod h1:PpLOETDnJ0o3iZrZfqZzyLl6l7F3c6L1oWn7OICBi6o= @@ -415,8 +414,9 @@ github.com/ipfs/go-cid v0.0.4/go.mod h1:4LLaPOQwmk5z9LBgQnpkivrx8BJjUyGwTXCd5Xfj github.com/ipfs/go-cid v0.0.5/go.mod h1:plgt+Y5MnOey4vO4UlUazGqdbEXuFYitED67FexhXog= github.com/ipfs/go-cid v0.0.6/go.mod h1:6Ux9z5e+HpkQdckYoX1PG/6xqKspzlEIR5SDmgqgC/I= github.com/ipfs/go-cid v0.0.7/go.mod h1:6Ux9z5e+HpkQdckYoX1PG/6xqKspzlEIR5SDmgqgC/I= -github.com/ipfs/go-cid v0.1.0 h1:YN33LQulcRHjfom/i25yoOZR4Telp1Hr/2RU3d0PnC0= github.com/ipfs/go-cid v0.1.0/go.mod h1:rH5/Xv83Rfy8Rw6xG+id3DYAMUVmem1MowoKwdXmN2o= +github.com/ipfs/go-cid v0.2.0 h1:01JTiihFq9en9Vz0lc0VDWvZe/uBonGpzo4THP0vcQ0= +github.com/ipfs/go-cid v0.2.0/go.mod h1:P+HXFDF4CVhaVayiEb4wkAy7zBHxBwsJyt0Y5U6MLro= github.com/ipfs/go-datastore v0.0.1/go.mod h1:d4KVXhMt913cLBEI/PXAy6ko+W7e9AhyAKBGh803qeE= github.com/ipfs/go-datastore v0.1.1/go.mod h1:w38XXW9kVFNp57Zj5knbKWM2T+KOZCGDRVNdgPHtbHw= github.com/ipfs/go-datastore v0.4.0/go.mod h1:SX/xMIKoCszPqp+z9JhPYCmoOoXTvaa13XEbGtsFUhA= @@ -438,8 +438,8 @@ github.com/ipfs/go-ds-leveldb v0.0.1/go.mod h1:feO8V3kubwsEF22n0YRQCffeb79OOYIyk github.com/ipfs/go-ds-leveldb v0.4.1/go.mod h1:jpbku/YqBSsBc1qgME8BkWS4AxzF2cEu1Ii2r79Hh9s= github.com/ipfs/go-ds-leveldb v0.4.2/go.mod h1:jpbku/YqBSsBc1qgME8BkWS4AxzF2cEu1Ii2r79Hh9s= github.com/ipfs/go-ds-leveldb v0.5.0/go.mod h1:d3XG9RUDzQ6V4SHi8+Xgj9j1XuEk1z82lquxrVbml/Q= -github.com/ipfs/go-graphsync v0.13.1 h1:lWiP/WLycoPUYyj3IDEi1GJNP30kFuYOvimcfeuZyQs= -github.com/ipfs/go-graphsync v0.13.1/go.mod h1:y8e8G6CmZeL9Srvx1l15CtGiRdf3h5JdQuqPz/iYL0A= 
+github.com/ipfs/go-graphsync v0.13.3-0.20220625074430-a95496cf1534 h1:sn7viAPyx3qZVhfRpXhW23mPtzl9rjJKtJ/HM/HsyZg= +github.com/ipfs/go-graphsync v0.13.3-0.20220625074430-a95496cf1534/go.mod h1:RKAui2+/HmlIVnuAXJIn0jltvOAXkl7wz3SYysmYnPI= github.com/ipfs/go-ipfs-blockstore v0.2.1/go.mod h1:jGesd8EtCM3/zPgx+qr0/feTXGUeRai6adgwC+Q+JvE= github.com/ipfs/go-ipfs-blockstore v1.1.2 h1:WCXoZcMYnvOTmlpX+RSSnhVN0uCmbWTeepTGX5lgiXw= github.com/ipfs/go-ipfs-blockstore v1.1.2/go.mod h1:w51tNR9y5+QXB0wkNcHt4O2aSZjTdqaEWaQdSxEyUOY= @@ -518,8 +518,9 @@ github.com/ipld/go-codec-dagpb v1.3.1/go.mod h1:ErNNglIi5KMur/MfFE/svtgQthzVvf+4 github.com/ipld/go-ipld-prime v0.9.1-0.20210324083106-dc342a9917db/go.mod h1:KvBLMr4PX1gWptgkzRjVZCrLmSGcZCb/jioOQwCqZN8= github.com/ipld/go-ipld-prime v0.11.0/go.mod h1:+WIAkokurHmZ/KwzDOMUuoeJgaRQktHtEaLglS3ZeV8= github.com/ipld/go-ipld-prime v0.14.0/go.mod h1:9ASQLwUFLptCov6lIYc70GRB4V7UTyLD0IJtrDJe6ZM= -github.com/ipld/go-ipld-prime v0.16.0 h1:RS5hhjB/mcpeEPJvfyj0qbOj/QL+/j05heZ0qa97dVo= github.com/ipld/go-ipld-prime v0.16.0/go.mod h1:axSCuOCBPqrH+gvXr2w9uAOulJqBPhHPT2PjoiiU1qA= +github.com/ipld/go-ipld-prime v0.17.1-0.20220624062450-534ccf82237d h1:aY4pwcHVHonF+edc4gzRr3HA7vAaindLXz7InFIUgiY= +github.com/ipld/go-ipld-prime v0.17.1-0.20220624062450-534ccf82237d/go.mod h1:aYcKm5TIvGfY8P3QBKz/2gKcLxzJ1zDaD+o0bOowhgs= github.com/ipld/go-ipld-prime/storage/bsadapter v0.0.0-20211210234204-ce2a1c70cd73/go.mod h1:2PJ0JgxyB08t0b2WKrcuqI3di0V+5n6RS/LTUJhkoxY= github.com/jackpal/gateway v1.0.5/go.mod h1:lTpwd4ACLXmpyiCTRtfiNyVnUmqT9RivzCDQetPfnjA= github.com/jackpal/go-nat-pmp v1.0.1/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= @@ -612,7 +613,6 @@ github.com/libp2p/go-libp2p v0.7.0/go.mod h1:hZJf8txWeCduQRDC/WSqBGMxaTHCOYHt2xS github.com/libp2p/go-libp2p v0.7.4/go.mod h1:oXsBlTLF1q7pxr+9w6lqzS1ILpyHsaBPniVO7zIHGMw= github.com/libp2p/go-libp2p v0.8.1/go.mod h1:QRNH9pwdbEBpx5DTJYg+qxcVaDMAz3Ee/qDKwXujH5o= github.com/libp2p/go-libp2p v0.14.3/go.mod h1:d12V4PdKbpL0T1/gsUNN8DfgMuRPDX8bS2QxCZlwRH0= -github.com/libp2p/go-libp2p v0.16.0/go.mod h1:ump42BsirwAWxKzsCiFnTtN1Yc+DuPu76fyMX364/O4= github.com/libp2p/go-libp2p v0.19.4 h1:50YL0YwPhWKDd+qbZQDEdnsmVAAkaCQrWUjpdHv4hNA= github.com/libp2p/go-libp2p v0.19.4/go.mod h1:MIt8y481VDhUe4ErWi1a4bvt/CjjFfOq6kZTothWIXY= github.com/libp2p/go-libp2p-asn-util v0.1.0 h1:rABPCO77SjdbJ/eJ/ynIo8vWICy1VEnL5JAxJbQLo1E= @@ -622,7 +622,6 @@ github.com/libp2p/go-libp2p-autonat v0.2.0/go.mod h1:DX+9teU4pEEoZUqR1PiMlqliONQ github.com/libp2p/go-libp2p-autonat v0.2.1/go.mod h1:MWtAhV5Ko1l6QBsHQNSuM6b1sRkXrpk0/LqCr+vCVxI= github.com/libp2p/go-libp2p-autonat v0.2.2/go.mod h1:HsM62HkqZmHR2k1xgX34WuWDzk/nBwNHoeyyT4IWV6A= github.com/libp2p/go-libp2p-autonat v0.4.2/go.mod h1:YxaJlpr81FhdOv3W3BTconZPfhaYivRdf53g+S2wobk= -github.com/libp2p/go-libp2p-autonat v0.6.0/go.mod h1:bFC6kY8jwzNNWoqc8iGE57vsfwyJ/lP4O4DOV1e0B2o= github.com/libp2p/go-libp2p-blankhost v0.1.1/go.mod h1:pf2fvdLJPsC1FsVrNP3DUUvMzUts2dsLLBEpo1vW1ro= github.com/libp2p/go-libp2p-blankhost v0.1.4/go.mod h1:oJF0saYsAXQCSfDq254GMNmLNz6ZTHTOvtF4ZydUvwU= github.com/libp2p/go-libp2p-blankhost v0.2.0/go.mod h1:eduNKXGTioTuQAUcZ5epXi9vMl+t4d8ugUBRQ4SqaNQ= @@ -655,7 +654,6 @@ github.com/libp2p/go-libp2p-core v0.8.1/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJB github.com/libp2p/go-libp2p-core v0.8.2/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJBt/G1rVvhz5XT8= github.com/libp2p/go-libp2p-core v0.8.5/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJBt/G1rVvhz5XT8= github.com/libp2p/go-libp2p-core v0.8.6/go.mod 
h1:dgHr0l0hIKfWpGpqAMbpo19pen9wJfdCGv51mTmdpmM= -github.com/libp2p/go-libp2p-core v0.9.0/go.mod h1:ESsbz31oC3C1AvMJoGx26RTuCkNhmkSRCqZ0kQtJ2/8= github.com/libp2p/go-libp2p-core v0.10.0/go.mod h1:ECdxehoYosLYHgDDFa2N4yE8Y7aQRAMf0sX9mf2sbGg= github.com/libp2p/go-libp2p-core v0.11.0/go.mod h1:ECdxehoYosLYHgDDFa2N4yE8Y7aQRAMf0sX9mf2sbGg= github.com/libp2p/go-libp2p-core v0.12.0/go.mod h1:ECdxehoYosLYHgDDFa2N4yE8Y7aQRAMf0sX9mf2sbGg= @@ -666,7 +664,6 @@ github.com/libp2p/go-libp2p-crypto v0.1.0/go.mod h1:sPUokVISZiy+nNuTTH/TY+leRSxn github.com/libp2p/go-libp2p-discovery v0.2.0/go.mod h1:s4VGaxYMbw4+4+tsoQTqh7wfxg97AEdo4GYBt6BadWg= github.com/libp2p/go-libp2p-discovery v0.3.0/go.mod h1:o03drFnz9BVAZdzC/QUQ+NeQOu38Fu7LJGEOK2gQltw= github.com/libp2p/go-libp2p-discovery v0.5.0/go.mod h1:+srtPIU9gDaBNu//UHvcdliKBIcr4SfDcm0/PfPJLug= -github.com/libp2p/go-libp2p-discovery v0.6.0/go.mod h1:/u1voHt0tKIe5oIA1RHBKQLVCWPna2dXmPNHc2zR9S8= github.com/libp2p/go-libp2p-loggables v0.1.0 h1:h3w8QFfCt2UJl/0/NW4K829HX/0S4KD31PQ7m8UXXO8= github.com/libp2p/go-libp2p-loggables v0.1.0/go.mod h1:EyumB2Y6PrYjr55Q3/tiJ/o3xoDasoRYM7nOzEpoa90= github.com/libp2p/go-libp2p-mplex v0.2.0/go.mod h1:Ejl9IyjvXJ0T9iqUTE1jpYATQ9NM3g+OtR+EMMODbKo= @@ -683,7 +680,6 @@ github.com/libp2p/go-libp2p-nat v0.1.0/go.mod h1:DQzAG+QbDYjN1/C3B6vXucLtz3u9rEo github.com/libp2p/go-libp2p-netutil v0.1.0 h1:zscYDNVEcGxyUpMd0JReUZTrpMfia8PmLKcKF72EAMQ= github.com/libp2p/go-libp2p-netutil v0.1.0/go.mod h1:3Qv/aDqtMLTUyQeundkKsA+YCThNdbQD54k3TqjpbFU= github.com/libp2p/go-libp2p-noise v0.2.0/go.mod h1:IEbYhBBzGyvdLBoxxULL/SGbJARhUeqlO8lVSREYu2Q= -github.com/libp2p/go-libp2p-noise v0.3.0/go.mod h1:JNjHbociDJKHD64KTkzGnzqJ0FEV5gHJa6AB00kbCNQ= github.com/libp2p/go-libp2p-noise v0.4.0 h1:khcMsGhHNdGqKE5LDLrnHwZvdGVMsrnD4GTkTWkwmLU= github.com/libp2p/go-libp2p-noise v0.4.0/go.mod h1:BzzY5pyzCYSyJbQy9oD8z5oP2idsafjt4/X42h9DjZU= github.com/libp2p/go-libp2p-peer v0.2.0/go.mod h1:RCffaCvUyW2CJmG2gAWVqwePwW7JMgxjsHm7+J5kjWY= @@ -701,7 +697,6 @@ github.com/libp2p/go-libp2p-pnet v0.2.0 h1:J6htxttBipJujEjz1y0a5+eYoiPcFHhSYHH6n github.com/libp2p/go-libp2p-pnet v0.2.0/go.mod h1:Qqvq6JH/oMZGwqs3N1Fqhv8NVhrdYcO0BW4wssv21LA= github.com/libp2p/go-libp2p-quic-transport v0.10.0/go.mod h1:RfJbZ8IqXIhxBRm5hqUEJqjiiY8xmEuq3HUDS993MkA= github.com/libp2p/go-libp2p-quic-transport v0.13.0/go.mod h1:39/ZWJ1TW/jx1iFkKzzUg00W6tDJh73FC0xYudjr7Hc= -github.com/libp2p/go-libp2p-quic-transport v0.15.0/go.mod h1:wv4uGwjcqe8Mhjj7N/Ic0aKjA+/10UnMlSzLO0yRpYQ= github.com/libp2p/go-libp2p-quic-transport v0.16.0/go.mod h1:1BXjVMzr+w7EkPfiHkKnwsWjPjtfaNT0q8RS3tGDvEQ= github.com/libp2p/go-libp2p-quic-transport v0.17.0 h1:yFh4Gf5MlToAYLuw/dRvuzYd1EnE2pX3Lq1N6KDiWRQ= github.com/libp2p/go-libp2p-quic-transport v0.17.0/go.mod h1:x4pw61P3/GRCcSLypcQJE/Q2+E9f4X+5aRcZLXf20LM= @@ -732,7 +727,6 @@ github.com/libp2p/go-libp2p-testing v0.1.1/go.mod h1:xaZWMJrPUM5GlDBxCeGUi7kI4eq github.com/libp2p/go-libp2p-testing v0.1.2-0.20200422005655-8775583591d8/go.mod h1:Qy8sAncLKpwXtS2dSnDOP8ktexIAHKu+J+pnZOFZLTc= github.com/libp2p/go-libp2p-testing v0.3.0/go.mod h1:efZkql4UZ7OVsEfaxNHZPzIehtsBXMrXnCfJIgDti5g= github.com/libp2p/go-libp2p-testing v0.4.0/go.mod h1:Q+PFXYoiYFN5CAEG2w3gLPEzotlKsNSbKQ/lImlOWF0= -github.com/libp2p/go-libp2p-testing v0.4.2/go.mod h1:Q+PFXYoiYFN5CAEG2w3gLPEzotlKsNSbKQ/lImlOWF0= github.com/libp2p/go-libp2p-testing v0.5.0/go.mod h1:QBk8fqIL1XNcno/l3/hhaIEn4aLRijpYOR+zVjjlh+A= github.com/libp2p/go-libp2p-testing v0.7.0/go.mod h1:OLbdn9DbgdMwv00v+tlp1l3oe2Cl+FAjoWIA2pa0X6E= 
github.com/libp2p/go-libp2p-testing v0.9.0/go.mod h1:Td7kbdkWqYTJYQGTwzlgXwaqldraIanyjuRiAbK/XQU= @@ -740,14 +734,12 @@ github.com/libp2p/go-libp2p-testing v0.9.2 h1:dCpODRtRaDZKF8HXT9qqqgON+OMEB423Kn github.com/libp2p/go-libp2p-testing v0.9.2/go.mod h1:Td7kbdkWqYTJYQGTwzlgXwaqldraIanyjuRiAbK/XQU= github.com/libp2p/go-libp2p-tls v0.1.3/go.mod h1:wZfuewxOndz5RTnCAxFliGjvYSDA40sKitV4c50uI1M= github.com/libp2p/go-libp2p-tls v0.3.0/go.mod h1:fwF5X6PWGxm6IDRwF3V8AVCCj/hOd5oFlg+wo2FxJDY= -github.com/libp2p/go-libp2p-tls v0.3.1/go.mod h1:fwF5X6PWGxm6IDRwF3V8AVCCj/hOd5oFlg+wo2FxJDY= github.com/libp2p/go-libp2p-tls v0.4.1 h1:1ByJUbyoMXvYXDoW6lLsMxqMViQNXmt+CfQqlnCpY+M= github.com/libp2p/go-libp2p-tls v0.4.1/go.mod h1:EKCixHEysLNDlLUoKxv+3f/Lp90O2EXNjTr0UQDnrIw= github.com/libp2p/go-libp2p-transport-upgrader v0.1.1/go.mod h1:IEtA6or8JUbsV07qPW4r01GnTenLW4oi3lOPbUMGJJA= github.com/libp2p/go-libp2p-transport-upgrader v0.2.0/go.mod h1:mQcrHj4asu6ArfSoMuyojOdjx73Q47cYD7s5+gZOlns= github.com/libp2p/go-libp2p-transport-upgrader v0.3.0/go.mod h1:i+SKzbRnvXdVbU3D1dwydnTmKRPXiAR/fyvi1dXuL4o= github.com/libp2p/go-libp2p-transport-upgrader v0.4.2/go.mod h1:NR8ne1VwfreD5VIWIU62Agt/J18ekORFU/j1i2y8zvk= -github.com/libp2p/go-libp2p-transport-upgrader v0.4.3/go.mod h1:bpkldbOWXMrXhpZbSV1mQxTrefOg2Fi+k1ClDSA4ppw= github.com/libp2p/go-libp2p-transport-upgrader v0.5.0/go.mod h1:Rc+XODlB3yce7dvFV4q/RmyJGsFcCZRkeZMu/Zdg0mo= github.com/libp2p/go-libp2p-transport-upgrader v0.7.0/go.mod h1:GIR2aTRp1J5yjVlkUoFqMkdobfob6RnAwYg/RZPhrzg= github.com/libp2p/go-libp2p-transport-upgrader v0.7.1 h1:MSMe+tUfxpC9GArTz7a4G5zQKQgGh00Vio87d3j3xIg= @@ -760,7 +752,6 @@ github.com/libp2p/go-libp2p-yamux v0.2.8/go.mod h1:/t6tDqeuZf0INZMTgd0WxIRbtK2Ez github.com/libp2p/go-libp2p-yamux v0.4.0/go.mod h1:+DWDjtFMzoAwYLVkNZftoucn7PelNoy5nm3tZ3/Zw30= github.com/libp2p/go-libp2p-yamux v0.5.0/go.mod h1:AyR8k5EzyM2QN9Bbdg6X1SkVVuqLwTGf0L4DFq9g6po= github.com/libp2p/go-libp2p-yamux v0.5.4/go.mod h1:tfrXbyaTqqSU654GTvK3ocnSZL3BuHoeTSqhcel1wsE= -github.com/libp2p/go-libp2p-yamux v0.6.0/go.mod h1:MRhd6mAYnFRnSISp4M8i0ClV/j+mWHo2mYLifWGw33k= github.com/libp2p/go-libp2p-yamux v0.8.0/go.mod h1:yTkPgN2ib8FHyU1ZcVD7aelzyAqXXwEPbyx+aSKm9h8= github.com/libp2p/go-libp2p-yamux v0.8.1/go.mod h1:rUozF8Jah2dL9LLGyBaBeTQeARdwhefMCTQVQt6QobE= github.com/libp2p/go-libp2p-yamux v0.9.1 h1:oplewiRix8s45SOrI30rCPZG5mM087YZp+VYhXAh4+c= @@ -778,7 +769,6 @@ github.com/libp2p/go-mplex v0.4.0/go.mod h1:y26Lx+wNVtMYMaPu300Cbot5LkEZ4tJaNYeH github.com/libp2p/go-msgio v0.0.2/go.mod h1:63lBBgOTDKQL6EWazRMCwXsEeEeK9O2Cd+0+6OOuipQ= github.com/libp2p/go-msgio v0.0.4/go.mod h1:63lBBgOTDKQL6EWazRMCwXsEeEeK9O2Cd+0+6OOuipQ= github.com/libp2p/go-msgio v0.0.6/go.mod h1:4ecVB6d9f4BDSL5fqvPiC4A3KivjWn+Venn/1ALLMWA= -github.com/libp2p/go-msgio v0.1.0/go.mod h1:eNlv2vy9V2X/kNldcZ+SShFE++o2Yjxwx6RAYsmgJnE= github.com/libp2p/go-msgio v0.2.0 h1:W6shmB+FeynDrUVl2dgFQvzfBZcXiyqY4VmpQLu9FqU= github.com/libp2p/go-msgio v0.2.0/go.mod h1:dBVM1gW3Jk9XqHkU4eKdGvVHdLa51hoGfll6jMJMSlY= github.com/libp2p/go-nat v0.0.4/go.mod h1:Nmw50VAvKuk38jUBcmNh6p9lUJLoODbJRvYAa/+KSDo= @@ -825,7 +815,6 @@ github.com/libp2p/go-tcp-transport v0.5.1/go.mod h1:UPPL0DIjQqiWRwVAb+CEQlaAG0rp github.com/libp2p/go-ws-transport v0.2.0/go.mod h1:9BHJz/4Q5A9ludYWKoGCFC5gUElzlHoKzu0yY9p/klM= github.com/libp2p/go-ws-transport v0.3.0/go.mod h1:bpgTJmRZAvVHrgHybCVyqoBmyLQ1fiZuEaBYusP5zsk= github.com/libp2p/go-ws-transport v0.4.0/go.mod h1:EcIEKqf/7GDjth6ksuS/6p7R49V4CBY6/E7R/iyhYUA= -github.com/libp2p/go-ws-transport v0.5.0/go.mod 
h1:I2juo1dNTbl8BKSBYo98XY85kU2xds1iamArLvl8kNg= github.com/libp2p/go-ws-transport v0.6.0 h1:326XBL6Q+5CQ2KtjXz32+eGu02W/Kz2+Fm4SpXdr0q4= github.com/libp2p/go-ws-transport v0.6.0/go.mod h1:dXqtI9e2JV9FtF1NOtWVZSKXh5zXvnuwPXfj8GPBbYU= github.com/libp2p/go-yamux v1.2.2/go.mod h1:FGTiPvoV/3DVdgWpX+tM0OW3tsM+W5bSE3gZwqQTcow= @@ -837,7 +826,6 @@ github.com/libp2p/go-yamux v1.4.0/go.mod h1:fr7aVgmdNGJK+N1g+b6DW6VxzbRCjCOejR/h github.com/libp2p/go-yamux v1.4.1 h1:P1Fe9vF4th5JOxxgQvfbOHkrGqIZniTLf+ddhZp8YTI= github.com/libp2p/go-yamux v1.4.1/go.mod h1:fr7aVgmdNGJK+N1g+b6DW6VxzbRCjCOejR/hkmpooHE= github.com/libp2p/go-yamux/v2 v2.2.0/go.mod h1:3So6P6TV6r75R9jiBpiIKgU/66lOarCZjqROGxzPpPQ= -github.com/libp2p/go-yamux/v2 v2.3.0/go.mod h1:iTU+lOIn/2h0AgKcL49clNTwfEw+WSfDYrXe05EyKIs= github.com/libp2p/go-yamux/v3 v3.0.1/go.mod h1:s2LsDhHbh+RfCsQoICSYt58U2f8ijtPANFD8BmE74Bo= github.com/libp2p/go-yamux/v3 v3.0.2/go.mod h1:s2LsDhHbh+RfCsQoICSYt58U2f8ijtPANFD8BmE74Bo= github.com/libp2p/go-yamux/v3 v3.1.1/go.mod h1:jeLEQgLXqE2YqX1ilAClIfCMDY+0uXQUKmmb/qp0gT4= @@ -848,7 +836,6 @@ github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-b github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lucas-clemente/quic-go v0.19.3/go.mod h1:ADXpNbTQjq1hIzCpB+y/k5iz4n4z4IwqoLb94Kh5Hu8= github.com/lucas-clemente/quic-go v0.23.0/go.mod h1:paZuzjXCE5mj6sikVLMvqXk8lJV2AsqtJ6bDhjEfxx0= -github.com/lucas-clemente/quic-go v0.24.0/go.mod h1:paZuzjXCE5mj6sikVLMvqXk8lJV2AsqtJ6bDhjEfxx0= github.com/lucas-clemente/quic-go v0.25.0/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg= github.com/lucas-clemente/quic-go v0.27.0/go.mod h1:AzgQoPda7N+3IqMMMkywBKggIFo2KT6pfnlrQ2QieeI= github.com/lucas-clemente/quic-go v0.27.1 h1:sOw+4kFSVrdWOYmUjufQ9GBVPqZ+tu+jMtXxXNmRJyk= @@ -968,8 +955,9 @@ github.com/multiformats/go-multibase v0.0.3/go.mod h1:5+1R4eQrT3PkYZ24C3W2Ue2tPw github.com/multiformats/go-multicodec v0.3.0/go.mod h1:qGGaQmioCDh+TeFOnxrbU0DaIPw8yFgAZgFG0V7p1qQ= github.com/multiformats/go-multicodec v0.3.1-0.20210902112759-1539a079fd61/go.mod h1:1Hj/eHRaVWSXiSNNfcEPcwZleTmdNP81xlxDLnWU9GQ= github.com/multiformats/go-multicodec v0.3.1-0.20211210143421-a526f306ed2c/go.mod h1:1Hj/eHRaVWSXiSNNfcEPcwZleTmdNP81xlxDLnWU9GQ= -github.com/multiformats/go-multicodec v0.4.1 h1:BSJbf+zpghcZMZrwTYBGwy0CPcVZGWiC72Cp8bBd4R4= github.com/multiformats/go-multicodec v0.4.1/go.mod h1:1Hj/eHRaVWSXiSNNfcEPcwZleTmdNP81xlxDLnWU9GQ= +github.com/multiformats/go-multicodec v0.5.0 h1:EgU6cBe/D7WRwQb1KmnBvU7lrcFGMggZVTPtOW9dDHs= +github.com/multiformats/go-multicodec v0.5.0/go.mod h1:DiY2HFaEp5EhEXb/iYzVAunmyX/aSFMxq2KMKfWEues= github.com/multiformats/go-multihash v0.0.1/go.mod h1:w/5tugSrLEbWqlcgJabL3oHFKTwfvkofsjW2Qa1ct4U= github.com/multiformats/go-multihash v0.0.5/go.mod h1:lt/HCbqlQwlPBz7lv0sQCdtfcMtlJvakRUn/0Ual8po= github.com/multiformats/go-multihash v0.0.8/go.mod h1:YSLudS+Pi8NHE7o6tb3D8vrpKa63epEDmG8nTduyAew= @@ -1089,7 +1077,6 @@ github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB8 github.com/prometheus/common v0.15.0/go.mod h1:U+gB1OBLb1lF3O42bTCL+FK18tX9Oar16Clt/msog/s= github.com/prometheus/common v0.18.0/go.mod h1:U+gB1OBLb1lF3O42bTCL+FK18tX9Oar16Clt/msog/s= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.30.0/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= 
github.com/prometheus/common v0.33.0 h1:rHgav/0a6+uYgGdNt3jwz8FNSesO/Hsang3O0T9A5SE= github.com/prometheus/common v0.33.0/go.mod h1:gB3sOl7P0TvJabZpLY5uQMpUqRCPPCyRLCZYc7JZTNE= @@ -1204,8 +1191,9 @@ github.com/urfave/cli/v2 v2.0.0 h1:+HU9SCbu8GnEUFtIBfuUNXN39ofWViIEJIp6SURMpCg= github.com/urfave/cli/v2 v2.0.0/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= -github.com/warpfork/go-testmark v0.3.0 h1:Q81c4u7hT+BR5kNfNQhEF0VT2pmL7+Kk0wD+ORYl7iA= github.com/warpfork/go-testmark v0.3.0/go.mod h1:jhEf8FVxd+F17juRubpmut64NEG6I2rgkUhlcqqXwE0= +github.com/warpfork/go-testmark v0.10.0 h1:E86YlUMYfwIacEsQGlnTvjk1IgYkyTGjPhF0RnwTCmw= +github.com/warpfork/go-testmark v0.10.0/go.mod h1:jhEf8FVxd+F17juRubpmut64NEG6I2rgkUhlcqqXwE0= github.com/warpfork/go-wish v0.0.0-20180510122957-5ad1f5abf436/go.mod h1:x6AKhvSSexNrVSrViXSHUEbICjmGXhtgABaHIySUSGw= github.com/warpfork/go-wish v0.0.0-20190328234359-8b3e70f8e830/go.mod h1:x6AKhvSSexNrVSrViXSHUEbICjmGXhtgABaHIySUSGw= github.com/warpfork/go-wish v0.0.0-20200122115046-b9ea61034e4a h1:G++j5e0OC488te356JvdhaM8YS6nMsjLAYF7JxCv07w= @@ -1289,7 +1277,6 @@ go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= go.uber.org/zap v1.16.0/go.mod h1:MA8QOfq0BHJwdXa996Y4dYkAqRKB8/1K1QMMZVaNZjQ= -go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI= go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= @@ -1325,7 +1312,6 @@ golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210813211128-0a44fdfbc16e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 h1:kUhD7nTDoI3fVd9G4ORWrbV5NY0liEs/Jg2pv5f+bBA= golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= @@ -1421,7 +1407,6 @@ golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net 
v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= @@ -1529,7 +1514,6 @@ golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1703,7 +1687,6 @@ google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.45.0 h1:NEpgUqV3Z+ZjkqMsxMg11IaDrXY4RY6CQukSGK0uI1M= google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/impl/environment.go b/impl/environment.go index 102e1441..4753344f 100644 --- a/impl/environment.go +++ b/impl/environment.go @@ -3,23 +3,15 @@ package impl import ( "github.com/libp2p/go-libp2p-core/peer" - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) type channelEnvironment struct { m *manager } -func (ce *channelEnvironment) Protect(id peer.ID, tag string) { - ce.m.dataTransferNetwork.Protect(id, tag) -} - -func (ce *channelEnvironment) Unprotect(id peer.ID, tag string) bool { - return ce.m.dataTransferNetwork.Unprotect(id, tag) -} - func (ce *channelEnvironment) ID() peer.ID { - return ce.m.dataTransferNetwork.ID() + return ce.m.peerID } func (ce *channelEnvironment) CleanupChannel(chid datatransfer.ChannelID) { diff --git a/impl/events.go b/impl/events.go index 2c6860fb..af2085f1 100644 --- a/impl/events.go +++ b/impl/events.go @@ -3,333 +3,150 @@ package impl import ( "context" "errors" + "fmt" - "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" - "github.com/libp2p/go-libp2p-core/peer" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/registry" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/channels" + "github.com/filecoin-project/go-data-transfer/v2/message" ) -// OnChannelOpened is called when we send a request for data to the other -// peer on the given channel ID -func (m *manager) OnChannelOpened(chid datatransfer.ChannelID) error { - log.Infof("channel %s: opened", chid) - - 
// Check if the channel is being tracked - has, err := m.channels.HasChannel(chid) - if err != nil { - return err - } - if !has { - return datatransfer.ErrChannelNotFound - } - - // Fire an event - return m.channels.ChannelOpened(chid) -} - -// OnDataReceived is called when the transport layer reports that it has -// received some data from the sender. -// It fires an event on the channel, updating the sum of received data and -// calls revalidators so they can pause / resume the channel or send a -// message over the transport. -func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { +// OnTransportEvent is dispatched when an event occurs on the transport +func (m *manager) OnTransportEvent(chid datatransfer.ChannelID, evt datatransfer.TransportEvent) { ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) - ctx, span := otel.Tracer("data-transfer").Start(ctx, "dataReceived", trace.WithAttributes( - attribute.String("channelID", chid.String()), - attribute.String("link", link.String()), - attribute.Int64("index", index), - attribute.Int64("size", int64(size)), - )) - defer span.End() - - isNew, err := m.channels.DataReceived(chid, link.(cidlink.Link).Cid, size, index, unique) + err := m.processTransferEvent(ctx, chid, evt) if err != nil { - return err - } - - // If this block has already been received on the channel, take no further - // action (this can happen when the data-transfer channel is restarted) - if !isNew { - return nil - } - - // If this node initiated the data transfer, there's nothing more to do - if chid.Initiator == m.peerID { - return nil - } - - // Check each revalidator to see if they want to pause / resume, or send - // a message over the transport - var result datatransfer.VoucherResult - var handled bool - _ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error { - revalidator := processor.(datatransfer.Revalidator) - handled, result, err = revalidator.OnPushDataReceived(chid, size) - if handled { - return errors.New("stop processing") - } - return nil - }) - if err != nil || result != nil { - msg, err := m.processRevalidationResult(chid, result, err) - if msg != nil { - ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) - if err := m.dataTransferNetwork.SendMessage(ctx, chid.Initiator, msg); err != nil { - return err - } - } - return err + log.Infof("error on channel: %s, closing channel", err) + err := m.closeChannelWithError(ctx, chid, err) + if err != nil { + log.Errorf("error closing channel: %s", err) + } + } +} + +func (m *manager) processTransferEvent(ctx context.Context, chid datatransfer.ChannelID, transportEvent datatransfer.TransportEvent) error { + switch evt := transportEvent.(type) { + case datatransfer.TransportOpenedChannel: + return m.channels.ChannelOpened(chid) + case datatransfer.TransportInitiatedTransfer: + return m.channels.TransferInitiated(chid) + case datatransfer.TransportReceivedData: + return m.channels.DataReceived(chid, evt.Size, evt.Index) + case datatransfer.TransportSentData: + return m.channels.DataSent(chid, evt.Size, evt.Index) + case datatransfer.TransportQueuedData: + return m.channels.DataQueued(chid, evt.Size, evt.Index) + case datatransfer.TransportReachedDataLimit: + if err := m.channels.DataLimitExceeded(chid); err != nil { + return err + } + msg := message.UpdateResponse(chid.ID, true) + return m.transport.SendMessage(ctx, chid, msg) + case 
datatransfer.TransportTransferCancelled: + log.Warnf("channel %+v was cancelled: %s", chid, evt.ErrorMessage) + return m.channels.RequestCancelled(chid, errors.New(evt.ErrorMessage)) + + case datatransfer.TransportErrorSendingData: + log.Debugf("channel %+v had transport send error: %s", chid, evt.ErrorMessage) + return m.channels.SendDataError(chid, errors.New(evt.ErrorMessage)) + case datatransfer.TransportErrorReceivingData: + log.Debugf("channel %+v had transport receive error: %s", chid, evt.ErrorMessage) + return m.channels.ReceiveDataError(chid, errors.New(evt.ErrorMessage)) + case datatransfer.TransportCompletedTransfer: + return m.channelCompleted(chid, evt.Success, evt.ErrorMessage) + case datatransfer.TransportReceivedRestartExistingChannelRequest: + return m.restartExistingChannelRequestReceived(chid) + case datatransfer.TransportErrorSendingMessage: + return m.channels.SendMessageError(chid, errors.New(evt.ErrorMessage)) + case datatransfer.TransportPaused: + return m.pause(chid) + case datatransfer.TransportResumed: + return m.resume(chid) } - return nil } -// OnDataQueued is called when the transport layer reports that it has queued -// up some data to be sent to the requester. -// It fires an event on the channel, updating the sum of queued data and calls -// revalidators so they can pause / resume or send a message over the transport. -func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) (datatransfer.Message, error) { - // The transport layer reports that some data has been queued up to be sent - // to the requester, so fire a DataQueued event on the channels state - // machine. - - ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) - ctx, span := otel.Tracer("data-transfer").Start(ctx, "dataQueued", trace.WithAttributes( - attribute.String("channelID", chid.String()), - attribute.String("link", link.String()), - attribute.Int64("size", int64(size)), - )) - defer span.End() - - isNew, err := m.channels.DataQueued(chid, link.(cidlink.Link).Cid, size, index, unique) - if err != nil { - return nil, err - } - - // If this block has already been queued on the channel, take no further - // action (this can happen when the data-transfer channel is restarted) - if !isNew { - return nil, nil - } - - // If this node initiated the data transfer, there's nothing more to do - if chid.Initiator == m.peerID { - return nil, nil - } - - // Check each revalidator to see if they want to pause / resume, or send - // a message over the transport. - // For example if the data-sender is waiting for the receiver to pay for - // data they may pause the data-transfer. 
-	var result datatransfer.VoucherResult
-	var handled bool
-	_ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error {
-		revalidator := processor.(datatransfer.Revalidator)
-		handled, result, err = revalidator.OnPullDataSent(chid, size)
-		if handled {
-			return errors.New("stop processing")
-		}
-		return nil
-	})
-	if err != nil || result != nil {
-		return m.processRevalidationResult(chid, result, err)
-	}
-
-	return nil, nil
-}
-
-func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error {
-
-	ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid)
-	ctx, span := otel.Tracer("data-transfer").Start(ctx, "dataSent", trace.WithAttributes(
-		attribute.String("channelID", chid.String()),
-		attribute.String("link", link.String()),
-		attribute.Int64("size", int64(size)),
-	))
-	defer span.End()
-
-	_, err := m.channels.DataSent(chid, link.(cidlink.Link).Cid, size, index, unique)
-	return err
-}
-
-func (m *manager) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) {
-	if request.IsRestart() {
-		return m.receiveRestartRequest(chid, request)
-	}
-
-	if request.IsNew() {
-		return m.receiveNewRequest(chid, request)
-	}
-	if request.IsCancel() {
-		log.Infof("channel %s: received cancel request, cleaning up channel", chid)
-
-		m.transport.CleanupChannel(chid)
-		return nil, m.channels.Cancel(chid)
-	}
-	if request.IsVoucher() {
-		return m.processUpdateVoucher(chid, request)
-	}
-	if request.IsPaused() {
-		return nil, m.pauseOther(chid)
-	}
-	err := m.resumeOther(chid)
-	if err != nil {
-		return nil, err
-	}
-	chst, err := m.channels.GetByID(context.TODO(), chid)
-	if err != nil {
-		return nil, err
-	}
-	if chst.Status() == datatransfer.ResponderPaused ||
-		chst.Status() == datatransfer.ResponderFinalizing {
-		return nil, datatransfer.ErrPause
-	}
-	return nil, nil
-}
-
-func (m *manager) OnTransferQueued(chid datatransfer.ChannelID) {
-	m.channels.TransferRequestQueued(chid)
-}
-
+// OnResponseReceived is called when a Response message is received from the responder
+// on the initiator
 func (m *manager) OnResponseReceived(chid datatransfer.ChannelID, response datatransfer.Response) error {
-	if response.IsComplete() {
-		log.Infow("received complete response", "chid", chid, "isAccepted", response.Accepted())
-	}
+	// if response is cancel, process as cancel
 	if response.IsCancel() {
 		log.Infof("channel %s: received cancel response, cancelling channel", chid)
 		return m.channels.Cancel(chid)
 	}
-	if response.IsVoucherResult() {
+
+	// does this response contain a response to a validation attempt?
+	if response.IsValidationResult() {
+
+		// is there a voucher response in this message?
 		if !response.EmptyVoucherResult() {
-			vresult, err := m.decodeVoucherResult(response)
+			// if so decode and save it
+			vresult, err := response.VoucherResult()
 			if err != nil {
 				return err
 			}
-			err = m.channels.NewVoucherResult(chid, vresult)
+			err = m.channels.NewVoucherResult(chid, datatransfer.TypedVoucher{Voucher: vresult, Type: response.VoucherResultType()})
 			if err != nil {
 				return err
 			}
 		}
+
+		// was the validation attempt successful?
if !response.Accepted() { + // if not, error and fail log.Infof("channel %s: received rejected response, erroring out channel", chid) return m.channels.Error(chid, datatransfer.ErrRejected) } - if response.IsNew() { - log.Infof("channel %s: received new response, accepting channel", chid) - err := m.channels.Accept(chid) - if err != nil { - return err - } + } + + // was this the first response to our initial request + if response.IsNew() { + log.Infof("channel %s: received new response, accepting channel", chid) + // if so, record an accept event (not accepted has already been handled) + err := m.channels.Accept(chid) + if err != nil { + return err } + } - if response.IsRestart() { - log.Infof("channel %s: received restart response, restarting channel", chid) - err := m.channels.Restart(chid) - if err != nil { - return err - } + // was this a response to a restart attempt? + if response.IsRestart() { + log.Infof("channel %s: received restart response, restarting channel", chid) + // if so, record restart + err := m.channels.Restart(chid) + if err != nil { + return err } } - if response.IsComplete() && response.Accepted() { + + // was this response a final status message? + if response.IsComplete() { + // is the responder paused pending final settlement? if !response.IsPaused() { + // if not, mark the responder done and return log.Infow("received complete response,responder not paused, completing channel", "chid", chid) return m.channels.ResponderCompletes(chid) } + // if yes, mark the responder being in final settlement log.Infow("received complete response, responder is paused, not completing channel", "chid", chid) err := m.channels.ResponderBeginsFinalization(chid) if err != nil { - return nil + return err } } + + // handle pause/resume for all response types if response.IsPaused() { return m.pauseOther(chid) } return m.resumeOther(chid) } -func (m *manager) OnRequestCancelled(chid datatransfer.ChannelID, err error) error { - log.Warnf("channel %+v was cancelled: %s", chid, err) - return m.channels.RequestCancelled(chid, err) -} - -func (m *manager) OnRequestDisconnected(chid datatransfer.ChannelID, err error) error { - log.Warnf("channel %+v has stalled or disconnected: %s", chid, err) - return m.channels.Disconnected(chid, err) -} - -func (m *manager) OnSendDataError(chid datatransfer.ChannelID, err error) error { - log.Debugf("channel %+v had transport send error: %s", chid, err) - return m.channels.SendDataError(chid, err) -} - -func (m *manager) OnReceiveDataError(chid datatransfer.ChannelID, err error) error { - log.Debugf("channel %+v had transport receive error: %s", chid, err) - return m.channels.ReceiveDataError(chid, err) -} - -// OnChannelCompleted is called -// - by the requester when all data for a transfer has been received -// - by the responder when all data for a transfer has been sent -func (m *manager) OnChannelCompleted(chid datatransfer.ChannelID, completeErr error) error { - // If the channel completed successfully - if completeErr == nil { - // If the channel was initiated by the other peer - if chid.Initiator != m.peerID { - log.Infow("received OnChannelCompleted, will send completion message to initiator", "chid", chid) - msg, err := m.completeMessage(chid) - if err != nil { - return err - } - if msg != nil { - // Send the other peer a message that the transfer has completed - log.Infow("sending completion message to initiator", "chid", chid) - ctx, _ := m.spansIndex.SpanForChannel(context.Background(), chid) - if err := m.dataTransferNetwork.SendMessage(ctx, 
chid.Initiator, msg); err != nil {
-					err := xerrors.Errorf("channel %s: failed to send completion message to initiator: %w", chid, err)
-					log.Warnw("failed to send completion message to initiator", "chid", chid, "err", err)
-					return m.OnRequestDisconnected(chid, err)
-				}
-				log.Infow("successfully sent completion message to initiator", "chid", chid)
-			}
-			if msg.Accepted() {
-				if msg.IsPaused() {
-					return m.channels.BeginFinalizing(chid)
-				}
-				return m.channels.Complete(chid)
-			}
-			return m.channels.Error(chid, err)
-		}
-
-		// The channel was initiated by this node, so move to the finished state
-		log.Infof("channel %s: transfer initiated by local node is complete", chid)
-		return m.channels.FinishTransfer(chid)
-	}
-
-	// There was an error so fire an Error event
-	chst, err := m.channels.GetByID(context.TODO(), chid)
-	if err != nil {
-		return err
-	}
-	// send an error, but only if we haven't already errored for some reason
-	if chst.Status() != datatransfer.Failing && chst.Status() != datatransfer.Failed {
-		err := xerrors.Errorf("data transfer channel %s failed to transfer data: %w", chid, completeErr)
-		log.Warnf(err.Error())
-		return m.channels.Error(chid, err)
-	}
-	return nil
-}
-
+// OnContextAugment provides an opportunity for transports to have data transfer add data to their context (i.e.
+// to tie into tracing, etc.)
 func (m *manager) OnContextAugment(chid datatransfer.ChannelID) func(context.Context) context.Context {
 	return func(ctx context.Context) context.Context {
 		ctx, _ = m.spansIndex.SpanForChannel(ctx, chid)
@@ -337,254 +154,72 @@ func (m *manager) OnContextAugment(chid datatransfer.ChannelCon
 	}
 }
 
-func (m *manager) receiveRestartRequest(chid datatransfer.ChannelID, incoming datatransfer.Request) (datatransfer.Response, error) {
-	log.Infof("channel %s: received restart request", chid)
-
-	result, err := m.restartRequest(chid, incoming)
-	msg, msgErr := m.response(true, false, err, incoming.TransferID(), result)
-	if msgErr != nil {
-		return nil, msgErr
-	}
-	return msg, err
-}
-
-func (m *manager) receiveNewRequest(chid datatransfer.ChannelID, incoming datatransfer.Request) (datatransfer.Response, error) {
-	log.Infof("channel %s: received new channel request from %s", chid, chid.Initiator)
-
-	result, err := m.acceptRequest(chid, incoming)
-	msg, msgErr := m.response(false, true, err, incoming.TransferID(), result)
-	if msgErr != nil {
-		return nil, msgErr
-	}
-	return msg, err
-}
-
-func (m *manager) restartRequest(chid datatransfer.ChannelID,
-	incoming datatransfer.Request) (datatransfer.VoucherResult, error) {
-
-	initiator := chid.Initiator
-	if m.peerID == initiator {
-		return nil, xerrors.New("initiator cannot be manager peer for a restart request")
-	}
-
-	if err := m.validateRestartRequest(context.Background(), initiator, chid, incoming); err != nil {
-		return nil, xerrors.Errorf("restart request for channel %s failed validation: %w", chid, err)
-	}
-
-	stor, err := incoming.Selector()
-	if err != nil {
-		return nil, err
-	}
-
-	voucher, result, err := m.validateVoucher(true, chid, initiator, incoming, incoming.IsPull(), incoming.BaseCid(), stor)
-	if err != nil && err != datatransfer.ErrPause {
-		return result, xerrors.Errorf("failed to validate voucher: %w", err)
-	}
-	voucherErr := err
-
-	if result != nil {
-		err := m.channels.NewVoucherResult(chid, result)
-		if err != nil {
-			return result, err
-		}
-	}
-	if err := m.channels.Restart(chid); err != nil {
-		return result, xerrors.Errorf("failed to restart channel %s: %w", chid, err)
-	}
-	processor, has := m.transportConfigurers.Processor(voucher.Type())
-	if has {
-		transportConfigurer := processor.(datatransfer.TransportConfigurer)
-		transportConfigurer(chid, voucher, m.transport)
-	}
-	m.dataTransferNetwork.Protect(initiator, chid.String())
-	if voucherErr == datatransfer.ErrPause {
-		err := m.channels.PauseResponder(chid)
-		if err != nil {
-			return result, err
-		}
-	}
-	return result, voucherErr
-}
-
-func (m *manager) acceptRequest(chid datatransfer.ChannelID, incoming datatransfer.Request) (datatransfer.VoucherResult, error) {
+// channelCompleted is called
+// - by the requester when all data for a transfer has been received
+// - by the responder when all data for a transfer has been sent
+func (m *manager) channelCompleted(chid datatransfer.ChannelID, success bool, errorMessage string) error {
 
-	stor, err := incoming.Selector()
+	// read the channel state
+	chst, err := m.channels.GetByID(context.TODO(), chid)
 	if err != nil {
-		return nil, err
-	}
-
-	voucher, result, err := m.validateVoucher(false, chid, chid.Initiator, incoming, incoming.IsPull(), incoming.BaseCid(), stor)
-	if err != nil && err != datatransfer.ErrPause {
-		return result, err
-	}
-	voucherErr := err
-
-	var dataSender, dataReceiver peer.ID
-	if incoming.IsPull() {
-		dataSender = m.peerID
-		dataReceiver = chid.Initiator
-	} else {
-		dataSender = chid.Initiator
-		dataReceiver = m.peerID
+		return err
 	}
 
-	log.Infow("data-transfer request validated, will create & start tracking channel", "channelID", chid, "payloadCid", incoming.BaseCid())
-	_, err = m.channels.CreateNew(m.peerID, incoming.TransferID(), incoming.BaseCid(), stor, voucher, chid.Initiator, dataSender, dataReceiver)
-	if err != nil {
-		log.Errorw("failed to create and start tracking channel", "channelID", chid, "err", err)
-		return result, err
-	}
-	log.Debugw("successfully created and started tracking channel", "channelID", chid)
-	if result != nil {
-		err := m.channels.NewVoucherResult(chid, result)
-		if err != nil {
-			return result, err
-		}
-	}
-	if err := m.channels.Accept(chid); err != nil {
-		return result, err
-	}
-	processor, has := m.transportConfigurers.Processor(voucher.Type())
-	if has {
-		transportConfigurer := processor.(datatransfer.TransportConfigurer)
-		transportConfigurer(chid, voucher, m.transport)
-	}
-	m.dataTransferNetwork.Protect(chid.Initiator, chid.String())
-	if voucherErr == datatransfer.ErrPause {
-		err := m.channels.PauseResponder(chid)
-		if err != nil {
-			return result, err
+	// If the transfer errored on completion
+	if !success {
+		// send an error, but only if we haven't already errored/finished the transfer for some reason
+		if !chst.Status().TransferComplete() {
+			err := fmt.Errorf("data transfer channel %s failed to transfer data: %s", chid, errorMessage)
+			log.Warnf(err.Error())
+			return m.channels.Error(chid, err)
 		}
-	}
-	return result, voucherErr
-}
-
-// validateVoucher converts a voucher in an incoming message to its appropriate
-// voucher struct, then runs the validator and returns the results.
-// returns error if: -// * reading voucher fails -// * deserialization of selector fails -// * validation fails -func (m *manager) validateVoucher( - isRestart bool, - chid datatransfer.ChannelID, - sender peer.ID, - incoming datatransfer.Request, - isPull bool, - baseCid cid.Cid, - stor ipld.Node, -) (datatransfer.Voucher, datatransfer.VoucherResult, error) { - vouch, err := m.decodeVoucher(incoming, m.validatedTypes) - if err != nil { - return nil, nil, err - } - var validatorFunc func(bool, datatransfer.ChannelID, peer.ID, datatransfer.Voucher, cid.Cid, ipld.Node) (datatransfer.VoucherResult, error) - processor, _ := m.validatedTypes.Processor(vouch.Type()) - validator := processor.(datatransfer.RequestValidator) - if isPull { - validatorFunc = validator.ValidatePull - } else { - validatorFunc = validator.ValidatePush + return nil } - result, err := validatorFunc(isRestart, chid, sender, vouch, baseCid, stor) - return vouch, result, err -} - -// revalidateVoucher converts a voucher in an incoming message to its appropriate -// voucher struct, then runs the revalidator and returns the results. -// returns error if: -// * reading voucher fails -// * deserialization of selector fails -// * validation fails -func (m *manager) revalidateVoucher(chid datatransfer.ChannelID, - incoming datatransfer.Request) (datatransfer.Voucher, datatransfer.VoucherResult, error) { - vouch, err := m.decodeVoucher(incoming, m.revalidators) - if err != nil { - return nil, nil, err + // if the channel was initiated by this node, simply record the transfer being finished + if chid.Initiator == m.peerID { + log.Infof("channel %s: transfer initiated by local node is complete", chid) + return m.channels.FinishTransfer(chid) } - processor, _ := m.revalidators.Processor(vouch.Type()) - validator := processor.(datatransfer.Revalidator) - result, err := validator.Revalidate(chid, vouch) - return vouch, result, err -} + // otherwise, process as responder + log.Infow("received OnChannelCompleted, will send completion message to initiator", "chid", chid) -func (m *manager) processUpdateVoucher(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) { - vouch, result, voucherErr := m.revalidateVoucher(chid, request) - if vouch != nil { - err := m.channels.NewVoucher(chid, vouch) - if err != nil { - return nil, err - } + // generate and send the final status message + msg := message.CompleteResponse(chst.TransferID(), true, chst.RequiresFinalization(), nil) + log.Infow("sending completion message to initiator", "chid", chid) + ctx, _ := m.spansIndex.SpanForChannel(context.Background(), chid) + if err := m.transport.SendMessage(ctx, chid, msg); err != nil { + err := xerrors.Errorf("channel %s: failed to send completion message to initiator: %w", chid, err) + log.Warnw("failed to send completion message to initiator", "chid", chid, "err", err) + return m.channels.SendMessageError(chid, err) } - return m.processRevalidationResult(chid, result, voucherErr) -} + log.Infow("successfully sent completion message to initiator", "chid", chid) -func (m *manager) revalidationResponse(chid datatransfer.ChannelID, result datatransfer.VoucherResult, resultErr error) (datatransfer.Response, error) { - chst, err := m.channels.GetByID(context.TODO(), chid) - if err != nil { - return nil, err - } - if chst.Status() == datatransfer.Finalizing { - return m.completeResponse(resultErr, chid.ID, result) + // set the channel state based on whether its paused final settlement + if chst.RequiresFinalization() { + 
return m.channels.BeginFinalizing(chid) } - return m.response(false, false, resultErr, chid.ID, result) + return m.channels.Complete(chid) } -func (m *manager) processRevalidationResult(chid datatransfer.ChannelID, result datatransfer.VoucherResult, resultErr error) (datatransfer.Response, error) { - vresMessage, err := m.revalidationResponse(chid, result, resultErr) - - if err != nil { - return nil, err - } - if result != nil { - err := m.channels.NewVoucherResult(chid, result) - if err != nil { - return nil, err - } - } - - if resultErr == nil { - return vresMessage, nil - } - - if resultErr == datatransfer.ErrPause { - err := m.pause(chid) - if err != nil { - return nil, err - } - return vresMessage, datatransfer.ErrPause +func (m *manager) restartExistingChannelRequestReceived(chid datatransfer.ChannelID) error { + ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) + // validate channel exists -> in non-terminal state and that the sender matches + channel, err := m.channels.GetByID(context.TODO(), chid) + if err != nil || channel == nil { + // nothing to do here, we wont handle the request + return err } - if resultErr == datatransfer.ErrResume { - err = m.resume(chid) - if err != nil { - return nil, err - } - return vresMessage, datatransfer.ErrResume + // channel should NOT be terminated + if channels.IsChannelTerminated(channel.Status()) { + return fmt.Errorf("cannot restart channel %s: channel already terminated", chid) } - return vresMessage, resultErr -} -func (m *manager) completeMessage(chid datatransfer.ChannelID) (datatransfer.Response, error) { - var result datatransfer.VoucherResult - var resultErr error - var handled bool - _ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error { - revalidator := processor.(datatransfer.Revalidator) - handled, result, resultErr = revalidator.OnComplete(chid) - if handled { - return errors.New("stop processing") - } - return nil - }) - if result != nil { - err := m.channels.NewVoucherResult(chid, result) - if err != nil { - return nil, err - } + if err := m.openRestartChannel(ctx, channel); err != nil { + return fmt.Errorf("failed to open restart channel %s: %s", chid, err) } - return m.completeResponse(resultErr, chid.ID, result) + return nil } diff --git a/impl/impl.go b/impl/impl.go index c3f511d4..156869b6 100644 --- a/impl/impl.go +++ b/impl/impl.go @@ -10,8 +10,7 @@ import ( "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" logging "github.com/ipfs/go-log/v2" - "github.com/ipld/go-ipld-prime" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -19,24 +18,20 @@ import ( "go.opentelemetry.io/otel/trace" "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channelmonitor" - "github.com/filecoin-project/go-data-transfer/channels" - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/message" - "github.com/filecoin-project/go-data-transfer/network" - "github.com/filecoin-project/go-data-transfer/registry" - "github.com/filecoin-project/go-data-transfer/tracing" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/channelmonitor" + "github.com/filecoin-project/go-data-transfer/v2/channels" + 
"github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/message/types" + "github.com/filecoin-project/go-data-transfer/v2/registry" + "github.com/filecoin-project/go-data-transfer/v2/tracing" ) var log = logging.Logger("dt-impl") var cancelSendTimeout = 30 * time.Second type manager struct { - dataTransferNetwork network.DataTransferNetwork validatedTypes *registry.Registry - resultTypes *registry.Registry - revalidators *registry.Registry transportConfigurers *registry.Registry pubSub *pubsub.PubSub readySub *pubsub.PubSub @@ -92,22 +87,19 @@ func ChannelRestartConfig(cfg channelmonitor.Config) DataTransferOption { } // NewDataTransfer initializes a new instance of a data transfer manager -func NewDataTransfer(ds datastore.Batching, dataTransferNetwork network.DataTransferNetwork, transport datatransfer.Transport, options ...DataTransferOption) (datatransfer.Manager, error) { +func NewDataTransfer(ds datastore.Batching, peerID peer.ID, transport datatransfer.Transport, options ...DataTransferOption) (datatransfer.Manager, error) { m := &manager{ - dataTransferNetwork: dataTransferNetwork, validatedTypes: registry.NewRegistry(), - resultTypes: registry.NewRegistry(), - revalidators: registry.NewRegistry(), transportConfigurers: registry.NewRegistry(), pubSub: pubsub.New(dispatcher), readySub: pubsub.New(readyDispatcher), - peerID: dataTransferNetwork.ID(), + peerID: peerID, transport: transport, transferIDGen: newTimeCounter(), spansIndex: tracing.NewSpansIndex(), } - channels, err := channels.New(ds, m.notifier, m.voucherDecoder, m.resultTypes.Decoder, &channelEnvironment{m}, dataTransferNetwork.ID()) + channels, err := channels.New(ds, m.notifier, &channelEnvironment{m}, peerID) if err != nil { return nil, err } @@ -125,14 +117,6 @@ func NewDataTransfer(ds datastore.Batching, dataTransferNetwork network.DataTran return m, nil } -func (m *manager) voucherDecoder(voucherType datatransfer.TypeIdentifier) (encoding.Decoder, bool) { - decoder, has := m.validatedTypes.Decoder(voucherType) - if !has { - return m.revalidators.Decoder(voucherType) - } - return decoder, true -} - func (m *manager) notifier(evt datatransfer.Event, chst datatransfer.ChannelState) { err := m.pubSub.Publish(internalEvent{evt, chst}) if err != nil { @@ -155,8 +139,6 @@ func (m *manager) Start(ctx context.Context) error { } }() - dtReceiver := &receiver{m} - m.dataTransferNetwork.SetDelegate(dtReceiver) return m.transport.SetEventHandler(m) } @@ -178,7 +160,7 @@ func (m *manager) Stop(ctx context.Context) error { // * voucher type does not implement voucher // * there is a voucher type registered with an identical identifier // * voucherType's Kind is not reflect.Ptr -func (m *manager) RegisterVoucherType(voucherType datatransfer.Voucher, validator datatransfer.RequestValidator) error { +func (m *manager) RegisterVoucherType(voucherType datatransfer.TypeIdentifier, validator datatransfer.RequestValidator) error { err := m.validatedTypes.Register(voucherType, validator) if err != nil { return xerrors.Errorf("error registering voucher type: %w", err) @@ -188,7 +170,7 @@ func (m *manager) RegisterVoucherType(voucherType datatransfer.Voucher, validato // OpenPushDataChannel opens a data transfer that will send data to the recipient peer and // transfer parts of the piece that match the selector -func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.Voucher, baseCid cid.Cid, selector ipld.Node) (datatransfer.ChannelID, error) { 
+func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (datatransfer.ChannelID, error) { log.Infof("open push channel to %s with base cid %s", requestTo, baseCid) req, err := m.newRequest(ctx, selector, false, voucher, baseCid, requestTo) @@ -196,20 +178,25 @@ func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, vo return datatransfer.ChannelID{}, err } - chid, err := m.channels.CreateNew(m.peerID, req.TransferID(), baseCid, selector, voucher, + chid, channel, err := m.channels.CreateNew(m.peerID, req.TransferID(), baseCid, selector, voucher, m.peerID, m.peerID, requestTo) // initiator = us, sender = us, receiver = them if err != nil { return chid, err } + return chid, m.openChannel(ctx, channel, req) +} + +func (m *manager) openChannel(ctx context.Context, channel datatransfer.Channel, request datatransfer.Request) error { + chid := channel.ChannelID() + voucher := channel.Voucher() ctx, span := m.spansIndex.SpanForChannel(ctx, chid) - processor, has := m.transportConfigurers.Processor(voucher.Type()) + processor, has := m.transportConfigurers.Processor(voucher.Type) if has { transportConfigurer := processor.(datatransfer.TransportConfigurer) transportConfigurer(chid, voucher, m.transport) } - m.dataTransferNetwork.Protect(requestTo, chid.String()) - monitoredChan := m.channelMonitor.AddPushChannel(chid) - if err := m.dataTransferNetwork.SendMessage(ctx, requestTo, req); err != nil { + monitoredChan := m.channelMonitor.AddChannel(chid, channel.IsPull()) + if err := m.transport.OpenChannel(ctx, channel, request); err != nil { err = fmt.Errorf("Unable to send request: %w", err) span.RecordError(err) span.SetStatus(codes.Error, err.Error()) @@ -221,15 +208,15 @@ func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, vo monitoredChan.Shutdown() } - return chid, err + return err } - return chid, nil + return nil } // OpenPullDataChannel opens a data transfer that will request data from the sending peer and // transfer parts of the piece that match the selector -func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.Voucher, baseCid cid.Cid, selector ipld.Node) (datatransfer.ChannelID, error) { +func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (datatransfer.ChannelID, error) { log.Infof("open pull channel to %s with base cid %s", requestTo, baseCid) req, err := m.newRequest(ctx, selector, true, voucher, baseCid, requestTo) @@ -237,45 +224,26 @@ func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, vo return datatransfer.ChannelID{}, err } // initiator = us, sender = them, receiver = us - chid, err := m.channels.CreateNew(m.peerID, req.TransferID(), baseCid, selector, voucher, + chid, channel, err := m.channels.CreateNew(m.peerID, req.TransferID(), baseCid, selector, voucher, m.peerID, requestTo, m.peerID) if err != nil { return chid, err } - ctx, span := m.spansIndex.SpanForChannel(ctx, chid) - processor, has := m.transportConfigurers.Processor(voucher.Type()) - if has { - transportConfigurer := processor.(datatransfer.TransportConfigurer) - transportConfigurer(chid, voucher, m.transport) - } - m.dataTransferNetwork.Protect(requestTo, chid.String()) - monitoredChan := m.channelMonitor.AddPullChannel(chid) - if err := m.transport.OpenChannel(ctx, requestTo, chid, 
cidlink.Link{Cid: baseCid}, selector, nil, req); err != nil { - err = fmt.Errorf("Unable to send request: %w", err) - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - _ = m.channels.Error(chid, err) - - // If pull channel monitoring is enabled, shutdown the monitor as it - // wasn't possible to start the data transfer - if monitoredChan != nil { - monitoredChan.Shutdown() - } - return chid, err - } - return chid, nil + return chid, m.openChannel(ctx, channel, req) } // SendVoucher sends an intermediate voucher as needed when the receiver sends a request for revalidation -func (m *manager) SendVoucher(ctx context.Context, channelID datatransfer.ChannelID, voucher datatransfer.Voucher) error { - chst, err := m.channels.GetByID(ctx, channelID) +func (m *manager) SendVoucher(ctx context.Context, channelID datatransfer.ChannelID, voucher datatransfer.TypedVoucher) error { + has, err := m.channels.HasChannel(channelID) + if !has { + return datatransfer.ErrChannelNotFound + } if err != nil { return err } ctx, _ = m.spansIndex.SpanForChannel(ctx, channelID) ctx, span := otel.Tracer("data-transfer").Start(ctx, "sendVoucher", trace.WithAttributes( attribute.String("channelID", channelID.String()), - attribute.String("voucherType", string(voucher.Type())), )) defer span.End() if channelID.Initiator != m.peerID { @@ -284,112 +252,191 @@ func (m *manager) SendVoucher(ctx context.Context, channelID datatransfer.Channe span.SetStatus(codes.Error, err.Error()) return err } - updateRequest, err := message.VoucherRequest(channelID.ID, voucher.Type(), voucher) + updateRequest := message.VoucherRequest(channelID.ID, &voucher) + if err := m.transport.SendMessage(ctx, channelID, updateRequest); err != nil { + err = fmt.Errorf("Unable to send request: %w", err) + _ = m.channels.SendMessageError(channelID, err) + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return err + } + return m.channels.NewVoucher(channelID, voucher) +} + +func (m *manager) SendVoucherResult(ctx context.Context, channelID datatransfer.ChannelID, voucherResult datatransfer.TypedVoucher) error { + chst, err := m.channels.GetByID(ctx, channelID) if err != nil { + return err + } + ctx, _ = m.spansIndex.SpanForChannel(ctx, channelID) + ctx, span := otel.Tracer("data-transfer").Start(ctx, "sendVoucherResult", trace.WithAttributes( + attribute.String("channelID", channelID.String()), + )) + defer span.End() + if channelID.Initiator == m.peerID { + err := errors.New("cannot send voucher result for request we initiated") span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return err } - if err := m.dataTransferNetwork.SendMessage(ctx, chst.OtherPeer(), updateRequest); err != nil { + + var updateResponse datatransfer.Response + if chst.Status().InFinalization() { + updateResponse = message.CompleteResponse(channelID.ID, chst.Status().IsAccepted(), chst.ResponderPaused(), &voucherResult) + } else { + updateResponse = message.VoucherResultResponse(channelID.ID, chst.Status().IsAccepted(), chst.ResponderPaused(), &voucherResult) + } + + if err := m.transport.SendMessage(ctx, channelID, updateResponse); err != nil { err = fmt.Errorf("Unable to send request: %w", err) - _ = m.OnRequestDisconnected(channelID, err) + _ = m.channels.SendMessageError(channelID, err) span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return err } - return m.channels.NewVoucher(channelID, voucher) + return m.channels.NewVoucherResult(channelID, voucherResult) } -// close an open channel (effectively a cancel) -func 
(m *manager) CloseDataTransferChannel(ctx context.Context, chid datatransfer.ChannelID) error { - log.Infof("close channel %s", chid) +func (m *manager) UpdateValidationStatus(ctx context.Context, chid datatransfer.ChannelID, result datatransfer.ValidationResult) error { + ctx, _ = m.spansIndex.SpanForChannel(ctx, chid) + ctx, span := otel.Tracer("data-transfer").Start(ctx, "updateValidationStatus", trace.WithAttributes( + attribute.String("channelID", chid.String()), + )) + err := m.updateValidationStatus(ctx, chid, result) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } + span.End() + return err +} - chst, err := m.channels.GetByID(ctx, chid) +// updateValidationStatus is the implementation of the public method, which wraps this private method +// in a trace +func (m *manager) updateValidationStatus(ctx context.Context, chid datatransfer.ChannelID, result datatransfer.ValidationResult) error { + + // first check if we are the responder -- only the responder can call UpdateValidationStatus + if chid.Initiator == m.peerID { + err := errors.New("cannot send voucher result for request we initiated") + return err + } + + // read the channel state + chst, err := m.channels.GetByID(context.TODO(), chid) + if err != nil { + return err + } + + // dispatch channel events and generate a response message + err = m.processValidationUpdate(ctx, chst, result) if err != nil { return err } + + // generate a response message + messageType := types.VoucherResultMessage + if chst.Status() == datatransfer.Finalizing { + messageType = types.CompleteMessage + } + response := message.ValidationResultResponse(messageType, chid.ID, result, err, + result.LeaveRequestPaused(chst)) + + // dispatch transport updates + return m.transport.ChannelUpdated(ctx, chid, response) +} + +func (m *manager) processValidationUpdate(ctx context.Context, chst datatransfer.ChannelState, result datatransfer.ValidationResult) error { + // if the request is now rejected, error the channel + if !result.Accepted { + return m.recordRejectedValidationEvents(chst.ChannelID(), result) + } + return m.recordAcceptedValidationEvents(chst, result) + +} + +// close an open channel (effectively a cancel) +func (m *manager) CloseDataTransferChannel(ctx context.Context, chid datatransfer.ChannelID) error { + log.Infof("close channel %s", chid) + ctx, _ = m.spansIndex.SpanForChannel(ctx, chid) ctx, span := otel.Tracer("data-transfer").Start(ctx, "closeChannel", trace.WithAttributes( attribute.String("channelID", chid.String()), )) defer span.End() - // Close the channel on the local transport - err = m.transport.CloseChannel(ctx, chid) + + err := m.closeChannel(ctx, chid) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) log.Warnf("unable to close channel %s: %s", chid, err) } + return err +} + +func (m *manager) closeChannel(ctx context.Context, chid datatransfer.ChannelID) error { + // Fire a cancel event + err := m.channels.Cancel(chid) + if err != nil { + return xerrors.Errorf("unable to send cancel to channel FSM: %w", err) + } + + // Close the channel on the local transport + err = m.transport.ChannelUpdated(ctx, chid, nil) // Send a cancel message to the remote peer async go func() { sctx, cancel := context.WithTimeout(context.Background(), cancelSendTimeout) defer cancel() - log.Infof("%s: sending cancel channel to %s for channel %s", m.peerID, chst.OtherPeer(), chid) - err = m.dataTransferNetwork.SendMessage(sctx, chst.OtherPeer(), m.cancelMessage(chid)) + log.Infof("%s: 
sending cancel channel to %s for channel %s", m.peerID, m.otherPeer(chid), chid)
+		err = m.transport.SendMessage(sctx, chid, m.cancelMessage(chid))
 		if err != nil {
 			err = fmt.Errorf("unable to send cancel message for channel %s to peer %s: %w", chid, m.peerID, err)
-			_ = m.OnRequestDisconnected(chid, err)
 			log.Warn(err)
 		}
 	}()
 
-	// Fire a cancel event
-	fsmerr := m.channels.Cancel(chid)
-	if fsmerr != nil {
-		return xerrors.Errorf("unable to send cancel to channel FSM: %w", fsmerr)
-	}
-
-	return nil
-}
-
-// ConnectTo opens a connection to a peer on the data-transfer protocol,
-// retrying if necessary
-func (m *manager) ConnectTo(ctx context.Context, p peer.ID) error {
-	return m.dataTransferNetwork.ConnectWithRetry(ctx, p)
+	return err
 }
 
 // close an open channel and fire an error event
 func (m *manager) CloseDataTransferChannelWithError(ctx context.Context, chid datatransfer.ChannelID, cherr error) error {
 	log.Infof("close channel %s with error %s", chid, cherr)
 
-	chst, err := m.channels.GetByID(ctx, chid)
-	if err != nil {
-		return err
-	}
 	ctx, _ = m.spansIndex.SpanForChannel(ctx, chid)
 	ctx, span := otel.Tracer("data-transfer").Start(ctx, "closeChannel", trace.WithAttributes(
 		attribute.String("channelID", chid.String()),
 	))
 	defer span.End()
 
-	// Cancel the channel on the local transport
-	err = m.transport.CloseChannel(ctx, chid)
+	err := m.closeChannelWithError(ctx, chid, cherr)
 	if err != nil {
+		span.RecordError(err)
+		span.SetStatus(codes.Error, err.Error())
 		log.Warnf("unable to close channel %s: %s", chid, err)
 	}
+	return err
+}
 
-	// Try to send a cancel message to the remote peer. It's quite likely
-	// we aren't able to send the message to the peer because the channel
-	// is already in an error state, which is probably because of connection
-	// issues, so if we cant send the message just log a warning.
-	log.Infof("%s: sending cancel channel to %s for channel %s", m.peerID, chst.OtherPeer(), chid)
-	err = m.dataTransferNetwork.SendMessage(ctx, chst.OtherPeer(), m.cancelMessage(chid))
-	if err != nil {
-		// Just log a warning here because it's important that we fire the
-		// error event with the original error so that it doesn't get masked
-		// by subsequent errors.
-		log.Warnf("unable to send cancel message for channel %s to peer %s: %w",
-			chid, m.peerID, err)
-	}
+
+func (m *manager) closeChannelWithError(ctx context.Context, chid datatransfer.ChannelID, cherr error) error {
 
 	// Fire an error event
-	err = m.channels.Error(chid, cherr)
-	if err != nil {
+	if err := m.channels.Error(chid, cherr); err != nil {
 		return xerrors.Errorf("unable to send error %s to channel FSM: %w", cherr, err)
 	}
+	// Close transport and try to send a cancel message to the remote peer.
+	// It's quite likely we aren't able to send the message to the peer because
+	// the channel is already in an error state, which is probably because of
+	// connection issues, so if we can't send the message just log a warning.
+	log.Infof("%s: sending cancel channel to %s for channel %s", m.peerID, m.otherPeer(chid), chid)
+
+	if err := m.transport.ChannelUpdated(ctx, chid, m.cancelMessage(chid)); err != nil {
+		// Just log a warning here because it's important that we fire the
+		// error event with the original error so that it doesn't get masked
+		// by subsequent errors.
+ log.Warnf("unable to close channel %s: %s", chid, err) + } return nil } @@ -397,44 +444,44 @@ func (m *manager) CloseDataTransferChannelWithError(ctx context.Context, chid da func (m *manager) PauseDataTransferChannel(ctx context.Context, chid datatransfer.ChannelID) error { log.Infof("pause channel %s", chid) - pausable, ok := m.transport.(datatransfer.PauseableTransport) - if !ok { + if !m.transport.Capabilities().Pausable { return datatransfer.ErrUnsupported } ctx, _ = m.spansIndex.SpanForChannel(ctx, chid) - err := pausable.PauseChannel(ctx, chid) - if err != nil { - log.Warnf("Error attempting to pause at transport level: %s", err.Error()) - } - - if err := m.dataTransferNetwork.SendMessage(ctx, chid.OtherParty(m.peerID), m.pauseMessage(chid)); err != nil { - err = fmt.Errorf("Unable to send pause message: %w", err) - _ = m.OnRequestDisconnected(chid, err) + // fire the pause + if err := m.pause(chid); err != nil { return err } - return m.pause(chid) + // update transport + if err := m.transport.ChannelUpdated(ctx, chid, m.pauseMessage(chid)); err != nil { + log.Warnf("Error attempting to pause at transport level: %s", err.Error()) + } + return nil } // resume a running data transfer channel func (m *manager) ResumeDataTransferChannel(ctx context.Context, chid datatransfer.ChannelID) error { log.Infof("resume channel %s", chid) - pausable, ok := m.transport.(datatransfer.PauseableTransport) - if !ok { + if !m.transport.Capabilities().Pausable { return datatransfer.ErrUnsupported } ctx, _ = m.spansIndex.SpanForChannel(ctx, chid) - err := pausable.ResumeChannel(ctx, m.resumeMessage(chid), chid) - if err != nil { - log.Warnf("Error attempting to resume at transport level: %s", err.Error()) + // fire the resume + if err := m.resume(chid); err != nil { + return err } - return m.resume(chid) + // update transport + if err := m.transport.ChannelUpdated(ctx, chid, m.resumeMessage(chid)); err != nil { + log.Warnf("Error attempting to resume at transport level: %s", err.Error()) + } + return nil } // get channel state @@ -461,32 +508,9 @@ func (m *manager) InProgressChannels(ctx context.Context) (map[datatransfer.Chan return m.channels.InProgress() } -// RegisterRevalidator registers a revalidator for the given voucher type -// Note: this is the voucher type used to revalidate. It can share a name -// with the initial validator type and CAN be the same type, or a different type. -// The revalidator can simply be the sampe as the original request validator, -// or a different validator that satisfies the revalidator interface. 
-func (m *manager) RegisterRevalidator(voucherType datatransfer.Voucher, revalidator datatransfer.Revalidator) error { - err := m.revalidators.Register(voucherType, revalidator) - if err != nil { - return xerrors.Errorf("error registering revalidator type: %w", err) - } - return nil -} - -// RegisterVoucherResultType allows deserialization of a voucher result, -// so that a listener can read the metadata -func (m *manager) RegisterVoucherResultType(resultType datatransfer.VoucherResult) error { - err := m.resultTypes.Register(resultType, nil) - if err != nil { - return xerrors.Errorf("error registering voucher type: %w", err) - } - return nil -} - // RegisterTransportConfigurer registers the given transport configurer to be run on requests with the given voucher // type -func (m *manager) RegisterTransportConfigurer(voucherType datatransfer.Voucher, configurer datatransfer.TransportConfigurer) error { +func (m *manager) RegisterTransportConfigurer(voucherType datatransfer.TypeIdentifier, configurer datatransfer.TransportConfigurer) error { err := m.transportConfigurers.Register(voucherType, configurer) if err != nil { return xerrors.Errorf("error registering transport configurer: %w", err) @@ -520,40 +544,10 @@ func (m *manager) RestartDataTransferChannel(ctx context.Context, chid datatrans )) defer span.End() // initiate restart - chType := m.channelDataTransferType(channel) - switch chType { - case ManagerPeerReceivePush: - return m.restartManagerPeerReceivePush(ctx, channel) - case ManagerPeerReceivePull: - return m.restartManagerPeerReceivePull(ctx, channel) - case ManagerPeerCreatePull: - return m.openPullRestartChannel(ctx, channel) - case ManagerPeerCreatePush: - return m.openPushRestartChannel(ctx, channel) - } - - return nil -} - -func (m *manager) channelDataTransferType(channel datatransfer.ChannelState) ChannelDataTransferType { - initiator := channel.ChannelID().Initiator - if channel.IsPull() { - // we created a pull channel - if initiator == m.peerID { - return ManagerPeerCreatePull - } - - // we received a pull channel - return ManagerPeerReceivePull - } - - // we created a push channel - if initiator == m.peerID { - return ManagerPeerCreatePush + if chid.Initiator == m.peerID { + return m.openRestartChannel(ctx, channel) } - - // we received a push channel - return ManagerPeerReceivePush + return m.restartManagerPeerReceive(ctx, channel) } func (m *manager) PeerID() peer.ID { diff --git a/impl/initiating_test.go b/impl/initiating_test.go index 50401014..a8da316e 100644 --- a/impl/initiating_test.go +++ b/impl/initiating_test.go @@ -9,17 +9,16 @@ import ( "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" dss "github.com/ipfs/go-datastore/sync" - "github.com/ipld/go-ipld-prime" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/node/basicnode" + selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" - "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channels" - . "github.com/filecoin-project/go-data-transfer/impl" - "github.com/filecoin-project/go-data-transfer/message" - "github.com/filecoin-project/go-data-transfer/testutil" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + . 
"github.com/filecoin-project/go-data-transfer/v2/impl" + "github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/testutil" ) func TestDataTransferInitiating(t *testing.T) { @@ -37,13 +36,13 @@ func TestDataTransferInitiating(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, channelID) require.Equal(t, channelID.Initiator, h.peers[0]) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 1) - messageReceived := h.network.SentMessages[0] - require.Equal(t, messageReceived.PeerID, h.peers[1]) - received := messageReceived.Message - require.True(t, received.IsRequest()) - receivedRequest, ok := received.(datatransfer.Request) + require.Len(t, h.transport.OpenedChannels, 1) + openChannel := h.transport.OpenedChannels[0] + require.Equal(t, openChannel.Channel.ChannelID(), channelID) + require.Equal(t, openChannel.Channel.Sender(), h.peers[0]) + require.Equal(t, openChannel.Channel.BaseCID(), h.baseCid) + require.Equal(t, openChannel.Channel.Selector(), h.stor) + receivedRequest, ok := openChannel.Message.(datatransfer.Request) require.True(t, ok) require.Equal(t, receivedRequest.TransferID(), channelID.ID) require.Equal(t, receivedRequest.BaseCid(), h.baseCid) @@ -52,7 +51,7 @@ func TestDataTransferInitiating(t *testing.T) { receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher) + testutil.AssertTestVoucher(t, receivedRequest, h.voucher) }, }, "OpenPullDataTransfer": { @@ -62,13 +61,12 @@ func TestDataTransferInitiating(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, channelID) require.Equal(t, channelID.Initiator, h.peers[0]) - require.Len(t, h.network.SentMessages, 0) require.Len(t, h.transport.OpenedChannels, 1) openChannel := h.transport.OpenedChannels[0] - require.Equal(t, openChannel.ChannelID, channelID) - require.Equal(t, openChannel.DataSender, h.peers[1]) - require.Equal(t, openChannel.Root, cidlink.Link{Cid: h.baseCid}) - require.Equal(t, openChannel.Selector, h.stor) + require.Equal(t, openChannel.Channel.ChannelID(), channelID) + require.Equal(t, openChannel.Channel.Sender(), h.peers[1]) + require.Equal(t, openChannel.Channel.BaseCID(), h.baseCid) + require.Equal(t, openChannel.Channel.Selector(), h.stor) require.True(t, openChannel.Message.IsRequest()) receivedRequest, ok := openChannel.Message.(datatransfer.Request) require.True(t, ok) @@ -79,13 +77,13 @@ func TestDataTransferInitiating(t *testing.T) { receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher) + testutil.AssertTestVoucher(t, receivedRequest, h.voucher) }, }, "SendVoucher with no channel open": { verify: func(t *testing.T, h *harness) { err := h.dt.SendVoucher(h.ctx, datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: 999999}, h.voucher) - require.True(t, xerrors.As(err, new(*channels.ErrNotFound))) + require.EqualError(t, err, datatransfer.ErrChannelNotFound.Error()) }, }, "SendVoucher with channel open, push succeeds": { @@ -93,17 +91,18 @@ func TestDataTransferInitiating(t *testing.T) { verify: func(t *testing.T, h *harness) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) - voucher := testutil.NewFakeDTType() + voucher := testutil.NewTestTypedVoucher() err = 
h.dt.SendVoucher(ctx, channelID, voucher) require.NoError(t, err) - require.Len(t, h.network.SentMessages, 2) - received := h.network.SentMessages[1].Message + require.Len(t, h.transport.OpenedChannels, 1) + require.Len(t, h.transport.MessagesSent, 1) + received := h.transport.MessagesSent[0].Message require.True(t, received.IsRequest()) receivedRequest, ok := received.(datatransfer.Request) require.True(t, ok) require.True(t, receivedRequest.IsVoucher()) require.False(t, receivedRequest.IsCancel()) - testutil.AssertFakeDTVoucher(t, receivedRequest, voucher) + testutil.AssertTestVoucher(t, receivedRequest, voucher) }, }, "SendVoucher with channel open, pull succeeds": { @@ -111,37 +110,27 @@ func TestDataTransferInitiating(t *testing.T) { verify: func(t *testing.T, h *harness) { channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) - voucher := testutil.NewFakeDTType() + voucher := testutil.NewTestTypedVoucher() err = h.dt.SendVoucher(ctx, channelID, voucher) require.NoError(t, err) require.Len(t, h.transport.OpenedChannels, 1) - require.Len(t, h.network.SentMessages, 1) - received := h.network.SentMessages[0].Message + require.Len(t, h.transport.MessagesSent, 1) + received := h.transport.MessagesSent[0].Message require.True(t, received.IsRequest()) receivedRequest, ok := received.(datatransfer.Request) require.True(t, ok) require.False(t, receivedRequest.IsCancel()) require.True(t, receivedRequest.IsVoucher()) - testutil.AssertFakeDTVoucher(t, receivedRequest, voucher) + testutil.AssertTestVoucher(t, receivedRequest, voucher) }, }, "reregister voucher type again errors": { verify: func(t *testing.T, h *harness) { - voucher := testutil.NewFakeDTType() sv := testutil.NewStubbedValidator() - err := h.dt.RegisterVoucherType(h.voucher, sv) + err := h.dt.RegisterVoucherType(h.voucher.Type, sv) require.NoError(t, err) - err = h.dt.RegisterVoucherType(voucher, sv) - require.EqualError(t, err, "error registering voucher type: identifier already registered: FakeDTType") - }, - }, - "reregister non pointer errors": { - verify: func(t *testing.T, h *harness) { - sv := testutil.NewStubbedValidator() - err := h.dt.RegisterVoucherType(h.voucher, sv) - require.NoError(t, err) - err = h.dt.RegisterVoucherType(testutil.FakeDTType{}, sv) - require.EqualError(t, err, "error registering voucher type: registering entry type FakeDTType: type must be a pointer") + err = h.dt.RegisterVoucherType(testutil.TestVoucherType, sv) + require.EqualError(t, err, "error registering voucher type: identifier already registered: TestVoucher") }, }, "success response": { @@ -150,8 +139,7 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, datatransfer.EmptyTypeIdentifier, nil) - require.NoError(t, err) + response := message.NewResponse(channelID.ID, true, false, nil) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) }, @@ -162,8 +150,7 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, h.voucherResult.Type(), h.voucherResult) - require.NoError(t, err) + response := message.NewResponse(channelID.ID, true, 
false, &h.voucherResult) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) }, @@ -174,16 +161,16 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, datatransfer.EmptyTypeIdentifier, nil) - require.NoError(t, err) + response := message.NewResponse(channelID.ID, true, false, nil) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) err = h.dt.PauseDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.PausedChannels, 1) - require.Equal(t, h.transport.PausedChannels[0], channelID) - require.Len(t, h.network.SentMessages, 2) - pauseMessage := h.network.SentMessages[1].Message + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID) + require.Len(t, h.transport.OpenedChannels, 1) + require.Len(t, h.transport.MessagesSent, 1) + pauseMessage := h.transport.MessagesSent[0].Message require.True(t, pauseMessage.IsUpdate()) require.True(t, pauseMessage.IsPaused()) require.True(t, pauseMessage.IsRequest()) @@ -191,10 +178,11 @@ func TestDataTransferInitiating(t *testing.T) { require.Equal(t, pauseMessage.TransferID(), channelID.ID) err = h.dt.ResumeDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.ResumedChannels, 1) - resumedChannel := h.transport.ResumedChannels[0] - require.Equal(t, resumedChannel.ChannelID, channelID) - resumeMessage := resumedChannel.Message + require.Len(t, h.transport.ChannelsUpdated, 2) + resumedChannel := h.transport.ChannelsUpdated[1] + require.Equal(t, resumedChannel, channelID) + require.Len(t, h.transport.MessagesSent, 2) + resumeMessage := h.transport.MessagesSent[1].Message require.True(t, resumeMessage.IsUpdate()) require.False(t, resumeMessage.IsPaused()) require.True(t, resumeMessage.IsRequest()) @@ -210,13 +198,13 @@ func TestDataTransferInitiating(t *testing.T) { require.NotEmpty(t, channelID) err = h.dt.CloseDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.ClosedChannels, 1) - require.Equal(t, h.transport.ClosedChannels[0], channelID) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID) require.Eventually(t, func() bool { - return len(h.network.SentMessages) == 2 + return len(h.transport.MessagesSent) == 1 }, 5*time.Second, 200*time.Millisecond) - cancelMessage := h.network.SentMessages[1].Message + cancelMessage := h.transport.MessagesSent[0].Message require.False(t, cancelMessage.IsUpdate()) require.False(t, cancelMessage.IsPaused()) require.True(t, cancelMessage.IsRequest()) @@ -230,16 +218,15 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, datatransfer.EmptyTypeIdentifier, nil) - require.NoError(t, err) + response := message.NewResponse(channelID.ID, true, false, nil) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) err = h.dt.PauseDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.PausedChannels, 1) - require.Equal(t, 
h.transport.PausedChannels[0], channelID) - require.Len(t, h.network.SentMessages, 1) - pauseMessage := h.network.SentMessages[0].Message + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Len(t, h.transport.OpenedChannels, 1) + require.Len(t, h.transport.MessagesSent, 1) + pauseMessage := h.transport.MessagesSent[0].Message require.True(t, pauseMessage.IsUpdate()) require.True(t, pauseMessage.IsPaused()) require.True(t, pauseMessage.IsRequest()) @@ -247,10 +234,11 @@ func TestDataTransferInitiating(t *testing.T) { require.Equal(t, pauseMessage.TransferID(), channelID.ID) err = h.dt.ResumeDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.ResumedChannels, 1) - resumedChannel := h.transport.ResumedChannels[0] - require.Equal(t, resumedChannel.ChannelID, channelID) - resumeMessage := resumedChannel.Message + require.Len(t, h.transport.ChannelsUpdated, 2) + resumedChannel := h.transport.ChannelsUpdated[1] + require.Equal(t, resumedChannel, channelID) + require.Len(t, h.transport.MessagesSent, 2) + resumeMessage := h.transport.MessagesSent[1].Message require.True(t, resumeMessage.IsUpdate()) require.False(t, resumeMessage.IsPaused()) require.True(t, resumeMessage.IsRequest()) @@ -266,14 +254,14 @@ func TestDataTransferInitiating(t *testing.T) { require.NotEmpty(t, channelID) err = h.dt.CloseDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.ClosedChannels, 1) - require.Equal(t, h.transport.ClosedChannels[0], channelID) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID) require.Eventually(t, func() bool { - return len(h.network.SentMessages) == 1 + return len(h.transport.MessagesSent) == 1 }, 5*time.Second, 200*time.Millisecond) - cancelMessage := h.network.SentMessages[0].Message + cancelMessage := h.transport.MessagesSent[0].Message require.False(t, cancelMessage.IsUpdate()) require.False(t, cancelMessage.IsPaused()) require.True(t, cancelMessage.IsRequest()) @@ -284,7 +272,7 @@ func TestDataTransferInitiating(t *testing.T) { "customizing push transfer": { expectedEvents: []datatransfer.EventCode{datatransfer.Open}, verify: func(t *testing.T, h *harness) { - err := h.dt.RegisterTransportConfigurer(h.voucher, func(channelID datatransfer.ChannelID, voucher datatransfer.Voucher, transport datatransfer.Transport) { + err := h.dt.RegisterTransportConfigurer(h.voucher.Type, func(channelID datatransfer.ChannelID, voucher datatransfer.TypedVoucher, transport datatransfer.Transport) { ft, ok := transport.(*testutil.FakeTransport) if !ok { return @@ -304,7 +292,7 @@ func TestDataTransferInitiating(t *testing.T) { "customizing pull transfer": { expectedEvents: []datatransfer.EventCode{datatransfer.Open}, verify: func(t *testing.T, h *harness) { - err := h.dt.RegisterTransportConfigurer(h.voucher, func(channelID datatransfer.ChannelID, voucher datatransfer.Voucher, transport datatransfer.Transport) { + err := h.dt.RegisterTransportConfigurer(h.voucher.Type, func(channelID datatransfer.ChannelID, voucher datatransfer.TypedVoucher, transport datatransfer.Transport) { ft, ok := transport.(*testutil.FakeTransport) if !ok { return @@ -331,10 +319,9 @@ func TestDataTransferInitiating(t *testing.T) { defer cancel() h.ctx = ctx h.peers = testutil.GeneratePeers(2) - h.network = testutil.NewFakeNetwork(h.peers[0]) h.transport = testutil.NewFakeTransport() h.ds = dss.MutexWrap(datastore.NewMapDatastore()) - dt, err := NewDataTransfer(h.ds, h.network, h.transport, 
verify.options...) + dt, err := NewDataTransfer(h.ds, h.peers[0], h.transport, verify.options...) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt) h.dt = dt @@ -343,10 +330,9 @@ func TestDataTransferInitiating(t *testing.T) { events: make(chan datatransfer.EventCode, len(verify.expectedEvents)), } ev.setup(t, dt) - h.stor = testutil.AllSelector() - h.voucher = testutil.NewFakeDTType() - h.voucherResult = testutil.NewFakeDTType() - err = h.dt.RegisterVoucherResultType(h.voucherResult) + h.stor = selectorparse.CommonSelector_ExploreAllRecursively + h.voucher = testutil.NewTestTypedVoucher() + h.voucherResult = testutil.NewTestTypedVoucher() require.NoError(t, err) h.baseCid = testutil.GenerateCids(1)[0] verify.verify(t, h) @@ -363,36 +349,34 @@ func TestDataTransferRestartInitiating(t *testing.T) { verify func(t *testing.T, h *harness) }{ "RestartDataTransferChannel: Manager Peer Create Pull Restart works": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.DataReceivedProgress, datatransfer.DataReceived, datatransfer.DataReceivedProgress, datatransfer.DataReceived}, + expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.TransferInitiated, datatransfer.DataReceivedProgress, datatransfer.DataReceived, datatransfer.DataReceivedProgress, datatransfer.DataReceived}, verify: func(t *testing.T, h *harness) { // open a pull channel channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) require.Len(t, h.transport.OpenedChannels, 1) - require.Len(t, h.network.SentMessages, 0) // some cids should already be received - testCids := testutil.GenerateCids(2) ev, ok := h.dt.(datatransfer.EventsHandler) require.True(t, ok) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[0]}, 12345, 1, true)) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[1]}, 12345, 2, true)) + ev.OnTransportEvent(channelID, datatransfer.TransportInitiatedTransfer{}) + ev.OnTransportEvent(channelID, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(1)}) + ev.OnTransportEvent(channelID, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(2)}) // restart that pull channel err = h.dt.RestartDataTransferChannel(ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.OpenedChannels, 2) - require.Len(t, h.network.SentMessages, 0) + require.Len(t, h.transport.RestartedChannels, 1) - openChannel := h.transport.OpenedChannels[1] - require.Equal(t, openChannel.ChannelID, channelID) - require.Equal(t, openChannel.DataSender, h.peers[1]) - require.Equal(t, openChannel.Root, cidlink.Link{Cid: h.baseCid}) - require.Equal(t, openChannel.Selector, h.stor) - require.True(t, openChannel.Message.IsRequest()) + restartedChannel := h.transport.RestartedChannels[0] + require.Equal(t, restartedChannel.Channel.ChannelID(), channelID) + require.Equal(t, restartedChannel.Channel.Sender(), h.peers[1]) + require.Equal(t, restartedChannel.Channel.BaseCID(), h.baseCid) + require.Equal(t, restartedChannel.Channel.Selector(), h.stor) + require.True(t, restartedChannel.Message.IsRequest()) - receivedRequest, ok := openChannel.Message.(datatransfer.Request) + receivedRequest := restartedChannel.Message require.True(t, ok) require.Equal(t, receivedRequest.TransferID(), channelID.ID) require.Equal(t, receivedRequest.BaseCid(), h.baseCid) @@ -405,7 +389,7 @@ func TestDataTransferRestartInitiating(t *testing.T) { 
receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher) + testutil.AssertTestVoucher(t, receivedRequest, h.voucher) }, }, "RestartDataTransferChannel: Manager Peer Create Push Restart works": { @@ -415,22 +399,17 @@ func TestDataTransferRestartInitiating(t *testing.T) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 1) + require.Len(t, h.transport.OpenedChannels, 1) // restart that push channel err = h.dt.RestartDataTransferChannel(ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 2) + require.Len(t, h.transport.RestartedChannels, 1) // assert restart request is well formed - messageReceived := h.network.SentMessages[1] - require.Equal(t, messageReceived.PeerID, h.peers[1]) - received := messageReceived.Message - require.True(t, received.IsRequest()) - receivedRequest, ok := received.(datatransfer.Request) - require.True(t, ok) + restartedChannel := h.transport.RestartedChannels[0] + require.Equal(t, restartedChannel.Channel.ChannelID(), channelID) + receivedRequest := restartedChannel.Message require.Equal(t, receivedRequest.TransferID(), channelID.ID) require.Equal(t, receivedRequest.BaseCid(), h.baseCid) require.False(t, receivedRequest.IsCancel()) @@ -441,37 +420,38 @@ func TestDataTransferRestartInitiating(t *testing.T) { receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher) + testutil.AssertTestVoucher(t, receivedRequest, h.voucher) }, }, "RestartDataTransferChannel: Manager Peer Receive Push Restart works ": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + }, verify: func(t *testing.T, h *harness) { ctx := context.Background() + + h.voucherValidator.ExpectSuccessPush() + h.voucherValidator.StubResult(datatransfer.ValidationResult{Accepted: true}) + // receive a push request - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - require.Len(t, h.transport.OpenedChannels, 1) - require.Len(t, h.network.SentMessages, 0) + chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pushRequest.TransferID()} + h.transport.EventHandler.OnRequestReceived(chid, h.pushRequest) require.Len(t, h.voucherValidator.ValidationsReceived, 1) // restart the push request received above and validate it - chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pushRequest.TransferID()} + h.voucherValidator.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) require.NoError(t, h.dt.RestartDataTransferChannel(ctx, chid)) - require.Len(t, h.voucherValidator.ValidationsReceived, 2) - require.Len(t, h.transport.OpenedChannels, 1) - require.Len(t, h.network.SentMessages, 1) + require.Len(t, h.voucherValidator.RevalidationsReceived, 1) + require.Len(t, h.transport.MessagesSent, 1) // assert validation on restart - vmsg := h.voucherValidator.ValidationsReceived[1] - require.Equal(t, h.voucher, vmsg.Voucher) - require.False(t, vmsg.IsPull) - require.Equal(t, h.stor, vmsg.Selector) - require.Equal(t, 
h.baseCid, vmsg.BaseCid) - require.Equal(t, h.peers[1], vmsg.Other) + vmsg := h.voucherValidator.RevalidationsReceived[0] + require.Equal(t, channelID(h.id, h.peers), vmsg.ChannelID) // assert req was sent correctly - req := h.network.SentMessages[0] - require.Equal(t, req.PeerID, h.peers[1]) + req := h.transport.MessagesSent[0] + require.Equal(t, chid, req.ChannelID) received := req.Message require.True(t, received.IsRequest()) receivedRequest, ok := received.(datatransfer.Request) @@ -480,39 +460,37 @@ func TestDataTransferRestartInitiating(t *testing.T) { achId, err := receivedRequest.RestartChannelId() require.NoError(t, err) require.Equal(t, chid, achId) - - h.voucherValidator.ExpectSuccessPush() }, }, "RestartDataTransferChannel: Manager Peer Receive Pull Restart works ": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + }, verify: func(t *testing.T, h *harness) { ctx := context.Background() // receive a pull request - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pullRequest) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 1) + h.voucherValidator.ExpectSuccessPull() + h.voucherValidator.StubResult(datatransfer.ValidationResult{Accepted: true}) + + chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pushRequest.TransferID()} + h.transport.EventHandler.OnRequestReceived(chid, h.pullRequest) require.Len(t, h.voucherValidator.ValidationsReceived, 1) // restart the pull request received above - h.voucherValidator.ExpectSuccessPull() - chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pullRequest.TransferID()} + h.voucherValidator.ExpectSuccessValidateRestart() + h.voucherValidator.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) require.NoError(t, h.dt.RestartDataTransferChannel(ctx, chid)) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 2) - require.Len(t, h.voucherValidator.ValidationsReceived, 2) + require.Len(t, h.voucherValidator.RevalidationsReceived, 1) + require.Len(t, h.transport.MessagesSent, 1) // assert validation on restart - vmsg := h.voucherValidator.ValidationsReceived[1] - require.Equal(t, h.voucher, vmsg.Voucher) - require.True(t, vmsg.IsPull) - require.Equal(t, h.stor, vmsg.Selector) - require.Equal(t, h.baseCid, vmsg.BaseCid) - require.Equal(t, h.peers[1], vmsg.Other) + vmsg := h.voucherValidator.RevalidationsReceived[0] + require.Equal(t, channelID(h.id, h.peers), vmsg.ChannelID) // assert req was sent correctly - req := h.network.SentMessages[1] - require.Equal(t, req.PeerID, h.peers[1]) + req := h.transport.MessagesSent[0] + require.Equal(t, chid, req.ChannelID) received := req.Message require.True(t, received.IsRequest()) receivedRequest, ok := received.(datatransfer.Request) @@ -524,35 +502,43 @@ func TestDataTransferRestartInitiating(t *testing.T) { }, }, "RestartDataTransferChannel: Manager Peer Receive Pull Restart fails if validation fails ": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + }, verify: func(t *testing.T, h *harness) { ctx := context.Background() // receive a pull request - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pullRequest) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 
1) + h.voucherValidator.ExpectSuccessPull() + h.voucherValidator.StubResult(datatransfer.ValidationResult{Accepted: true}) + chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pushRequest.TransferID()} + h.transport.EventHandler.OnRequestReceived(chid, h.pullRequest) require.Len(t, h.voucherValidator.ValidationsReceived, 1) // restart the pull request received above - h.voucherValidator.ExpectErrorPull() - chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pullRequest.TransferID()} - require.EqualError(t, h.dt.RestartDataTransferChannel(ctx, chid), "failed to restart channel, validation error: something went wrong") + h.voucherValidator.ExpectSuccessValidateRestart() + h.voucherValidator.StubRestartResult(datatransfer.ValidationResult{Accepted: false}) + require.EqualError(t, h.dt.RestartDataTransferChannel(ctx, chid), datatransfer.ErrRejected.Error()) }, }, "RestartDataTransferChannel: Manager Peer Receive Push Restart fails if validation fails ": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + }, verify: func(t *testing.T, h *harness) { ctx := context.Background() // receive a push request - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - require.Len(t, h.transport.OpenedChannels, 1) - require.Len(t, h.network.SentMessages, 0) + h.voucherValidator.ExpectSuccessPush() + h.voucherValidator.StubResult(datatransfer.ValidationResult{Accepted: true}) + chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pushRequest.TransferID()} + h.transport.EventHandler.OnRequestReceived(chid, h.pushRequest) require.Len(t, h.voucherValidator.ValidationsReceived, 1) // restart the pull request received above - h.voucherValidator.ExpectErrorPush() - chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pushRequest.TransferID()} - require.EqualError(t, h.dt.RestartDataTransferChannel(ctx, chid), "failed to restart channel, validation error: something went wrong") + h.voucherValidator.ExpectSuccessValidateRestart() + h.voucherValidator.StubRestartResult(datatransfer.ValidationResult{Accepted: false}) + require.EqualError(t, h.dt.RestartDataTransferChannel(ctx, chid), datatransfer.ErrRejected.Error()) }, }, "Fails if channel does not exist": { @@ -573,13 +559,12 @@ func TestDataTransferRestartInitiating(t *testing.T) { // create the harness h.ctx = ctx h.peers = testutil.GeneratePeers(2) - h.network = testutil.NewFakeNetwork(h.peers[0]) h.transport = testutil.NewFakeTransport() h.ds = dss.MutexWrap(datastore.NewMapDatastore()) h.voucherValidator = testutil.NewStubbedValidator() // setup data transfer`` - dt, err := NewDataTransfer(h.ds, h.network, h.transport) + dt, err := NewDataTransfer(h.ds, h.peers[0], h.transport) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt) h.dt = dt @@ -592,18 +577,17 @@ func TestDataTransferRestartInitiating(t *testing.T) { ev.setup(t, dt) // setup voucher processing - h.stor = testutil.AllSelector() - h.voucher = testutil.NewFakeDTType() - require.NoError(t, h.dt.RegisterVoucherType(h.voucher, h.voucherValidator)) - h.voucherResult = testutil.NewFakeDTType() - err = h.dt.RegisterVoucherResultType(h.voucherResult) + h.stor = selectorparse.CommonSelector_ExploreAllRecursively + h.voucher = testutil.NewTestTypedVoucher() + require.NoError(t, h.dt.RegisterVoucherType(h.voucher.Type, 
h.voucherValidator)) + h.voucherResult = testutil.NewTestTypedVoucher() require.NoError(t, err) h.baseCid = testutil.GenerateCids(1)[0] h.id = datatransfer.TransferID(rand.Int31()) - h.pushRequest, err = message.NewRequest(h.id, false, false, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pushRequest, err = message.NewRequest(h.id, false, false, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) - h.pullRequest, err = message.NewRequest(h.id, false, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pullRequest, err = message.NewRequest(h.id, false, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) // run tests steps and verify @@ -617,14 +601,13 @@ func TestDataTransferRestartInitiating(t *testing.T) { type harness struct { ctx context.Context peers []peer.ID - network *testutil.FakeNetwork transport *testutil.FakeTransport ds datastore.Batching dt datatransfer.Manager voucherValidator *testutil.StubbedValidator - stor ipld.Node - voucher *testutil.FakeDTType - voucherResult *testutil.FakeDTType + stor datamodel.Node + voucher datatransfer.TypedVoucher + voucherResult datatransfer.TypedVoucher baseCid cid.Cid id datatransfer.TransferID @@ -640,12 +623,8 @@ type eventVerifier struct { func (e eventVerifier) setup(t *testing.T, dt datatransfer.Manager) { if len(e.expectedEvents) > 0 { received := 0 - max := len(e.expectedEvents) dt.SubscribeToEvents(func(evt datatransfer.Event, state datatransfer.ChannelState) { received++ - if received > max { - t.Fatalf("received too many events: %s", datatransfer.Events[evt.Code]) - } e.events <- evt.Code }) } @@ -662,6 +641,12 @@ func (e eventVerifier) verify(ctx context.Context, t *testing.T) { receivedEvents = append(receivedEvents, event) } } + timer := time.NewTimer(50 * time.Millisecond) + select { + case event := <-e.events: + t.Fatalf("received extra event: %s", datatransfer.Events[event]) + case <-timer.C: + } require.Equal(t, e.expectedEvents, receivedEvents) } } diff --git a/impl/receiver.go b/impl/receiver.go deleted file mode 100644 index 6f861704..00000000 --- a/impl/receiver.go +++ /dev/null @@ -1,192 +0,0 @@ -package impl - -import ( - "context" - - cidlink "github.com/ipld/go-ipld-prime/linking/cid" - "github.com/libp2p/go-libp2p-core/peer" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channels" -) - -type receiver struct { - manager *manager -} - -// ReceiveRequest takes an incoming data transfer request, validates the voucher and -// processes the message. 
-func (r *receiver) ReceiveRequest( - ctx context.Context, - initiator peer.ID, - incoming datatransfer.Request) { - err := r.receiveRequest(ctx, initiator, incoming) - if err != nil { - log.Warnf("error processing request from %s: %s", initiator, err) - } -} - -func (r *receiver) receiveRequest(ctx context.Context, initiator peer.ID, incoming datatransfer.Request) error { - chid := datatransfer.ChannelID{Initiator: initiator, Responder: r.manager.peerID, ID: incoming.TransferID()} - ctx, _ = r.manager.spansIndex.SpanForChannel(ctx, chid) - ctx, span := otel.Tracer("data-transfer").Start(ctx, "receiveRequest", trace.WithAttributes( - attribute.String("channelID", chid.String()), - attribute.String("baseCid", incoming.BaseCid().String()), - attribute.Bool("isNew", incoming.IsNew()), - attribute.Bool("isRestart", incoming.IsRestart()), - attribute.Bool("isUpdate", incoming.IsUpdate()), - attribute.Bool("isCancel", incoming.IsCancel()), - attribute.Bool("isPaused", incoming.IsPaused()), - )) - defer span.End() - response, receiveErr := r.manager.OnRequestReceived(chid, incoming) - - if receiveErr == datatransfer.ErrResume { - chst, err := r.manager.channels.GetByID(ctx, chid) - if err != nil { - return err - } - if resumeTransportStatesResponder.Contains(chst.Status()) { - return r.manager.transport.(datatransfer.PauseableTransport).ResumeChannel(ctx, response, chid) - } - receiveErr = nil - } - - if response != nil { - if (response.IsNew() || response.IsRestart()) && response.Accepted() && !incoming.IsPull() { - var channel datatransfer.ChannelState - if response.IsRestart() { - var err error - channel, err = r.manager.channels.GetByID(ctx, chid) - if err != nil { - return err - } - } - - stor, _ := incoming.Selector() - if err := r.manager.transport.OpenChannel(ctx, initiator, chid, cidlink.Link{Cid: incoming.BaseCid()}, stor, channel, response); err != nil { - return err - } - } else { - if err := r.manager.dataTransferNetwork.SendMessage(ctx, initiator, response); err != nil { - return err - } - } - } - - if receiveErr == datatransfer.ErrPause { - return r.manager.transport.(datatransfer.PauseableTransport).PauseChannel(ctx, chid) - } - - if receiveErr != nil { - _ = r.manager.transport.CloseChannel(ctx, chid) - return receiveErr - } - - return nil -} - -// ReceiveResponse handles responses to our Push or Pull data transfer request. -// It schedules a transfer only if our Pull Request is accepted. 
-func (r *receiver) ReceiveResponse( - ctx context.Context, - sender peer.ID, - incoming datatransfer.Response) { - err := r.receiveResponse(ctx, sender, incoming) - if err != nil { - log.Error(err) - } -} -func (r *receiver) receiveResponse( - ctx context.Context, - sender peer.ID, - incoming datatransfer.Response) error { - chid := datatransfer.ChannelID{Initiator: r.manager.peerID, Responder: sender, ID: incoming.TransferID()} - ctx, _ = r.manager.spansIndex.SpanForChannel(ctx, chid) - ctx, span := otel.Tracer("data-transfer").Start(ctx, "receiveResponse", trace.WithAttributes( - attribute.String("channelID", chid.String()), - attribute.Bool("accepted", incoming.Accepted()), - attribute.Bool("isComplete", incoming.IsComplete()), - attribute.Bool("isNew", incoming.IsNew()), - attribute.Bool("isRestart", incoming.IsRestart()), - attribute.Bool("isUpdate", incoming.IsUpdate()), - attribute.Bool("isCancel", incoming.IsCancel()), - attribute.Bool("isPaused", incoming.IsPaused()), - )) - defer span.End() - err := r.manager.OnResponseReceived(chid, incoming) - if err == datatransfer.ErrPause { - return r.manager.transport.(datatransfer.PauseableTransport).PauseChannel(ctx, chid) - } - if err != nil { - log.Warnf("closing channel %s after getting error processing response from %s: %s", - chid, sender, err) - - _ = r.manager.transport.CloseChannel(ctx, chid) - return err - } - return nil -} - -func (r *receiver) ReceiveError(err error) { - log.Errorf("received error message on data transfer: %s", err.Error()) -} - -func (r *receiver) ReceiveRestartExistingChannelRequest(ctx context.Context, - sender peer.ID, - incoming datatransfer.Request) { - - ch, err := incoming.RestartChannelId() - if err != nil { - log.Errorf("cannot restart channel: failed to fetch channel Id: %w", err) - return - } - - ctx, _ = r.manager.spansIndex.SpanForChannel(ctx, ch) - ctx, span := otel.Tracer("data-transfer").Start(ctx, "receiveRequest", trace.WithAttributes( - attribute.String("channelID", ch.String()), - )) - defer span.End() - log.Infof("channel %s: received restart existing channel request from %s", ch, sender) - - // validate channel exists -> in non-terminal state and that the sender matches - channel, err := r.manager.channels.GetByID(ctx, ch) - if err != nil || channel == nil { - // nothing to do here, we wont handle the request - return - } - - // initiator should be me - if channel.ChannelID().Initiator != r.manager.peerID { - log.Errorf("cannot restart channel %s: channel initiator is not the manager peer", ch) - return - } - - // other peer should be the counter party on the channel - if channel.OtherPeer() != sender { - log.Errorf("cannot restart channel %s: channel counterparty is not the sender peer", ch) - return - } - - // channel should NOT be terminated - if channels.IsChannelTerminated(channel.Status()) { - log.Errorf("cannot restart channel %s: channel already terminated", ch) - return - } - - switch r.manager.channelDataTransferType(channel) { - case ManagerPeerCreatePush: - if err := r.manager.openPushRestartChannel(ctx, channel); err != nil { - log.Errorf("failed to open push restart channel %s: %s", ch, err) - } - case ManagerPeerCreatePull: - if err := r.manager.openPullRestartChannel(ctx, channel); err != nil { - log.Errorf("failed to open pull restart channel %s: %s", ch, err) - } - default: - log.Error("peer is not the creator of the channel") - } -} diff --git a/impl/receiving_requests.go b/impl/receiving_requests.go new file mode 100644 index 00000000..825ed3f2 --- /dev/null +++ 
b/impl/receiving_requests.go @@ -0,0 +1,310 @@ +package impl + +import ( + "context" + + "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/libp2p/go-libp2p-core/peer" + "golang.org/x/xerrors" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/message/types" +) + +func (m *manager) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) { + + // if request is restart request, process as restart + if request.IsRestart() { + return m.receiveRestartRequest(chid, request) + } + + // if request is new request, process as new + if request.IsNew() { + return m.receiveNewRequest(chid, request) + } + + // if request is cancel request, process as cancel + if request.IsCancel() { + log.Infof("channel %s: received cancel request, cleaning up channel", chid) + + return nil, m.channels.Cancel(chid) + } + // if request contains a new voucher, process updated voucher + if request.IsVoucher() { + return m.processUpdateVoucher(chid, request) + } + // otherwise process as an "update" message (i.e. a pause or resume) + return m.receiveUpdateRequest(chid, request) +} + +// receiveNewRequest handles an incoming new request message +func (m *manager) receiveNewRequest(chid datatransfer.ChannelID, incoming datatransfer.Request) (datatransfer.Response, error) { + log.Infof("channel %s: received new channel request from %s", chid, chid.Initiator) + + // process the new message, including validations + result, err := m.acceptRequest(chid, incoming) + + // generate a response message + msg := message.ValidationResultResponse(types.NewMessage, incoming.TransferID(), result, err, result.ForcePause) + + // return the channel update + if err == nil && !result.Accepted { + err = datatransfer.ErrRejected + } + return msg, err +} + +// acceptRequest performs processing (including validation) on a new incoming request +func (m *manager) acceptRequest(chid datatransfer.ChannelID, incoming datatransfer.Request) (datatransfer.ValidationResult, error) { + + // read the voucher and validate the request + stor, err := incoming.Selector() + if err != nil { + return datatransfer.ValidationResult{}, err + } + + voucher, err := incoming.TypedVoucher() + if err != nil { + return datatransfer.ValidationResult{}, err + } + processor, ok := m.validatedTypes.Processor(voucher.Type) + if !ok { + return datatransfer.ValidationResult{}, xerrors.Errorf("unknown voucher type: %s", voucher.Type) + } + + var validatorFunc func(datatransfer.ChannelID, peer.ID, datamodel.Node, cid.Cid, datamodel.Node) (datatransfer.ValidationResult, error) + validator := processor.(datatransfer.RequestValidator) + if incoming.IsPull() { + validatorFunc = validator.ValidatePull + } else { + validatorFunc = validator.ValidatePush + } + + result, err := validatorFunc(chid, chid.Initiator, voucher.Voucher, incoming.BaseCid(), stor) + + // if an error occurred during validation or the request was not accepted, return + if err != nil || !result.Accepted { + return result, err + } + + // create the channel + var dataSender, dataReceiver peer.ID + if incoming.IsPull() { + dataSender = m.peerID + dataReceiver = chid.Initiator + } else { + dataSender = chid.Initiator + dataReceiver = m.peerID + } + + log.Infow("data-transfer request validated, will create & start tracking channel", "channelID", chid, "payloadCid", incoming.BaseCid()) + _, _, err = 
m.channels.CreateNew(
+ m.peerID,
+ incoming.TransferID(),
+ incoming.BaseCid(),
+ stor,
+ voucher,
+ chid.Initiator,
+ dataSender,
+ dataReceiver,
+ )
+ if err != nil {
+ log.Errorw("failed to create and start tracking channel", "channelID", chid, "err", err)
+ return result, err
+ }
+
+ // record that the channel was accepted
+ log.Debugw("successfully created and started tracking channel", "channelID", chid)
+ if err := m.channels.Accept(chid); err != nil {
+ return result, err
+ }
+
+ // read the channel state
+ chst, err := m.channels.GetByID(context.TODO(), chid)
+ if err != nil {
+ return datatransfer.ValidationResult{}, err
+ }
+
+ // record validation events
+ if err := m.recordAcceptedValidationEvents(chst, result); err != nil {
+ return result, err
+ }
+
+ // configure the transport
+ processor, has := m.transportConfigurers.Processor(voucher.Type)
+ if has {
+ transportConfigurer := processor.(datatransfer.TransportConfigurer)
+ transportConfigurer(chid, voucher, m.transport)
+ }
+
+ return result, nil
+}
+
+// receiveRestartRequest handles an incoming restart request message
+func (m *manager) receiveRestartRequest(chid datatransfer.ChannelID, incoming datatransfer.Request) (datatransfer.Response, error) {
+ log.Infof("channel %s: received restart request", chid)
+
+ // process the restart message, including validations
+ stayPaused, result, err := m.restartRequest(chid, incoming)
+
+ // generate a response message
+ msg := message.ValidationResultResponse(types.RestartMessage, incoming.TransferID(), result, err, stayPaused)
+
+ // return the response message and any errors
+ if err == nil && !result.Accepted {
+ err = datatransfer.ErrRejected
+ }
+ return msg, err
+}
+
+// restartRequest performs processing (including validation) on an incoming restart request
+func (m *manager) restartRequest(chid datatransfer.ChannelID,
+ incoming datatransfer.Request) (bool, datatransfer.ValidationResult, error) {
+
+ // restart requests are invalid if we are the initiator
+ // (the responder must send a "restart existing channel request")
+ initiator := chid.Initiator
+ if m.peerID == initiator {
+ return false, datatransfer.ValidationResult{}, xerrors.New("initiator cannot be manager peer for a restart request")
+ }
+
+ // read the channel state
+ chst, err := m.channels.GetByID(context.TODO(), chid)
+ if err != nil {
+ return false, datatransfer.ValidationResult{}, err
+ }
+
+ // validate that the request parameters match the original request
+ // TODO: not sure this is needed -- the request parameters cannot change,
+ // so perhaps the solution is just to ignore them in the message
+ if err := m.validateRestartRequest(context.Background(), initiator, chst, incoming); err != nil {
+ return false, datatransfer.ValidationResult{}, xerrors.Errorf("restart request for channel %s failed validation: %w", chid, err)
+ }
+
+ // perform a revalidation against the last voucher
+ result, err := m.validateRestart(chst)
+ stayPaused := result.LeaveRequestPaused(chst)
+
+ // if an error occurred during validation, return
+ if err != nil {
+ return stayPaused, result, err
+ }
+
+ // if the request is now rejected, error the channel
+ if !result.Accepted {
+ return stayPaused, result, m.recordRejectedValidationEvents(chid, result)
+ }
+
+ // record the restart events
+ if err := m.channels.Restart(chid); err != nil {
+ return stayPaused, result, xerrors.Errorf("failed to restart channel %s: %w", chid, err)
+ }
+
+ // record validation events
+ if err := m.recordAcceptedValidationEvents(chst, result); err != nil {
+ return stayPaused, result, err
+ }
+
+ // configure the transport
+ voucher, err := incoming.Voucher()
+ if err != nil {
+ return stayPaused, result, err
+ }
+ voucherType := incoming.VoucherType()
+ typedVoucher := datatransfer.TypedVoucher{Voucher: voucher, Type: voucherType}
+ processor, has := m.transportConfigurers.Processor(voucherType)
+ if has {
+ transportConfigurer := processor.(datatransfer.TransportConfigurer)
+ transportConfigurer(chid, typedVoucher, m.transport)
+ }
+ return stayPaused, result, nil
+}
+
+// recordRejectedValidationEvents sends changes based on a rejected validation to the state machine
+func (m *manager) recordRejectedValidationEvents(chid datatransfer.ChannelID, result datatransfer.ValidationResult) error {
+ if result.VoucherResult != nil {
+ if err := m.channels.NewVoucherResult(chid, *result.VoucherResult); err != nil {
+ return err
+ }
+ }
+
+ return m.channels.Error(chid, datatransfer.ErrRejected)
+}
+
+// recordAcceptedValidationEvents sends changes based on an accepted validation to the state machine
+func (m *manager) recordAcceptedValidationEvents(chst datatransfer.ChannelState, result datatransfer.ValidationResult) error {
+ chid := chst.ChannelID()
+
+ // record the voucher result if present
+ if result.VoucherResult != nil && result.VoucherResult.Voucher != nil {
+ err := m.channels.NewVoucherResult(chid, *result.VoucherResult)
+ if err != nil {
+ return err
+ }
+ }
+
+ // record the change in data limit if different
+ if result.DataLimit != chst.DataLimit() {
+ err := m.channels.SetDataLimit(chid, result.DataLimit)
+ if err != nil {
+ return err
+ }
+ }
+
+ // record the finalization state if different
+ if result.RequiresFinalization != chst.RequiresFinalization() {
+ err := m.channels.SetRequiresFinalization(chid, result.RequiresFinalization)
+ if err != nil {
+ return err
+ }
+ }
+
+ // pause or resume the request as necessary
+ if result.LeaveRequestPaused(chst) {
+ if !chst.ResponderPaused() {
+ err := m.channels.PauseResponder(chid)
+ if err != nil {
+ return err
+ }
+ }
+ } else {
+ if chst.ResponderPaused() {
+ err := m.channels.ResumeResponder(chid)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// processUpdateVoucher handles an incoming request message with an updated voucher
+func (m *manager) processUpdateVoucher(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) {
+ // decode the voucher and save it on the channel
+ voucher, err := request.TypedVoucher()
+ if err != nil {
+ return nil, err
+ }
+ return nil, m.channels.NewVoucher(chid, voucher)
+}
+
+// receiveUpdateRequest handles an incoming request message that changes pause status
+func (m *manager) receiveUpdateRequest(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) {
+ if request.IsPaused() {
+ return nil, m.pauseOther(chid)
+ }
+ return nil, m.resumeOther(chid)
+}
+
+// validateRestart looks up the appropriate validator and validates a restart
+func (m *manager) validateRestart(chst datatransfer.ChannelState) (datatransfer.ValidationResult, error) {
+ chv := chst.Voucher()
+
+ processor, _ := m.validatedTypes.Processor(chv.Type)
+ validator := processor.(datatransfer.RequestValidator)
+
+ return validator.ValidateRestart(chst.ChannelID(), chst)
+}
diff --git a/impl/responding_test.go b/impl/responding_test.go
index 2566886f..92abd7e1 100644
--- a/impl/responding_test.go
+++ b/impl/responding_test.go
@@ -11,115 +11,119 @@ import ( "github.com/ipfs/go-datastore" dss
"github.com/ipfs/go-datastore/sync" "github.com/ipld/go-ipld-prime" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/node/basicnode" + selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channels" - . "github.com/filecoin-project/go-data-transfer/impl" - "github.com/filecoin-project/go-data-transfer/message" - "github.com/filecoin-project/go-data-transfer/testutil" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + . "github.com/filecoin-project/go-data-transfer/v2/impl" + "github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/testutil" ) func TestDataTransferResponding(t *testing.T) { // create network ctx := context.Background() testCases := map[string]struct { - expectedEvents []datatransfer.EventCode - configureValidator func(sv *testutil.StubbedValidator) - configureRevalidator func(sv *testutil.StubbedRevalidator) - verify func(t *testing.T, h *receiverHarness) + expectedEvents []datatransfer.EventCode + configureValidator func(sv *testutil.StubbedValidator) + verify func(t *testing.T, h *receiverHarness) }{ "new push request validates": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + datatransfer.NewVoucherResult, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) + response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) + require.NoError(t, err) require.Len(t, h.sv.ValidationsReceived, 1) validation := h.sv.ValidationsReceived[0] assert.False(t, validation.IsPull) assert.Equal(t, h.peers[1], validation.Other) - assert.Equal(t, h.voucher, validation.Voucher) + assert.True(t, ipld.DeepEqual(h.voucher.Voucher, validation.Voucher)) assert.Equal(t, h.baseCid, validation.BaseCid) assert.Equal(t, h.stor, validation.Selector) - - require.Len(t, h.transport.OpenedChannels, 1) - openChannel := h.transport.OpenedChannels[0] - require.Equal(t, openChannel.ChannelID, channelID(h.id, h.peers)) - require.Equal(t, openChannel.DataSender, h.peers[1]) - require.Equal(t, openChannel.Root, cidlink.Link{Cid: h.baseCid}) - require.Equal(t, openChannel.Selector, h.stor) - require.False(t, openChannel.Message.IsRequest()) - response, ok := openChannel.Message.(datatransfer.Response) - require.True(t, ok) require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) require.False(t, response.IsUpdate()) require.False(t, response.IsCancel()) require.False(t, response.IsPaused()) require.True(t, response.IsNew()) - require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) + }, + }, + "new push request rejects": { + configureValidator: func(sv *testutil.StubbedValidator) { + sv.ExpectSuccessPush() + vr := 
testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: false, VoucherResult: &vr}) + }, + verify: func(t *testing.T, h *receiverHarness) { + response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) + require.EqualError(t, err, datatransfer.ErrRejected.Error()) + require.False(t, response.Accepted()) + require.Equal(t, response.TransferID(), h.id) + require.False(t, response.IsUpdate()) + require.False(t, response.IsCancel()) + require.False(t, response.IsPaused()) + require.True(t, response.IsNew()) + require.True(t, response.IsValidationResult()) }, }, "new push request errors": { configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectErrorPush() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - require.Len(t, h.network.SentMessages, 1) - responseMessage := h.network.SentMessages[0].Message - require.False(t, responseMessage.IsRequest()) - response, ok := responseMessage.(datatransfer.Response) - require.True(t, ok) + response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) + require.Error(t, err) require.False(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) require.False(t, response.IsUpdate()) require.False(t, response.IsCancel()) require.False(t, response.IsPaused()) require.True(t, response.IsNew()) - require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) }, }, "new push request pauses": { configureValidator: func(sv *testutil.StubbedValidator) { - sv.ExpectPausePush() - sv.StubResult(testutil.NewFakeDTType()) + sv.ExpectSuccessPush() + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, ForcePause: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - - require.Len(t, h.transport.OpenedChannels, 1) - openChannel := h.transport.OpenedChannels[0] - require.Equal(t, openChannel.ChannelID, channelID(h.id, h.peers)) - require.Equal(t, openChannel.DataSender, h.peers[1]) - require.Equal(t, openChannel.Root, cidlink.Link{Cid: h.baseCid}) - require.Equal(t, openChannel.Selector, h.stor) - require.False(t, openChannel.Message.IsRequest()) - response, ok := openChannel.Message.(datatransfer.Response) - require.True(t, ok) + response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) + require.NoError(t, err) require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) require.False(t, response.IsUpdate()) require.False(t, response.IsCancel()) require.True(t, response.IsPaused()) require.True(t, response.IsNew()) - require.True(t, response.IsVoucherResult()) - require.Len(t, h.transport.PausedChannels, 1) - require.Equal(t, channelID(h.id, h.peers), h.transport.PausedChannels[0]) + require.True(t, response.IsValidationResult()) }, }, "new pull request validates": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) }, 
verify: func(t *testing.T, h *receiverHarness) { response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) @@ -128,7 +132,7 @@ func TestDataTransferResponding(t *testing.T) { validation := h.sv.ValidationsReceived[0] assert.True(t, validation.IsPull) assert.Equal(t, h.peers[1], validation.Other) - assert.Equal(t, h.voucher, validation.Voucher) + assert.True(t, ipld.DeepEqual(h.voucher.Voucher, validation.Voucher)) assert.Equal(t, h.baseCid, validation.BaseCid) assert.Equal(t, h.stor, validation.Selector) require.True(t, response.Accepted()) @@ -137,7 +141,24 @@ func TestDataTransferResponding(t *testing.T) { require.False(t, response.IsCancel()) require.False(t, response.IsPaused()) require.True(t, response.IsNew()) - require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) + }, + }, + "new pull request rejects": { + configureValidator: func(sv *testutil.StubbedValidator) { + sv.ExpectSuccessPull() + sv.StubResult(datatransfer.ValidationResult{Accepted: false}) + }, + verify: func(t *testing.T, h *receiverHarness) { + response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) + require.EqualError(t, err, datatransfer.ErrRejected.Error()) + require.False(t, response.Accepted()) + require.Equal(t, response.TransferID(), h.id) + require.False(t, response.IsUpdate()) + require.False(t, response.IsCancel()) + require.False(t, response.IsPaused()) + require.True(t, response.IsNew()) + require.True(t, response.IsValidationResult()) }, }, "new pull request errors": { @@ -153,72 +174,107 @@ func TestDataTransferResponding(t *testing.T) { require.False(t, response.IsCancel()) require.False(t, response.IsPaused()) require.True(t, response.IsNew()) - require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) }, }, "new pull request pauses": { configureValidator: func(sv *testutil.StubbedValidator) { - sv.ExpectPausePull() + sv.ExpectSuccessPull() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, ForcePause: true}) }, verify: func(t *testing.T, h *receiverHarness) { response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) - require.EqualError(t, err, datatransfer.ErrPause.Error()) - + require.NoError(t, err) require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) require.False(t, response.IsUpdate()) require.False(t, response.IsCancel()) require.True(t, response.IsPaused()) require.True(t, response.IsNew()) - require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) require.True(t, response.EmptyVoucherResult()) }, }, + "send voucher results from responder succeeds, push request": { + configureValidator: func(sv *testutil.StubbedValidator) { + sv.ExpectSuccessPush() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) + }, + verify: func(t *testing.T, h *receiverHarness) { + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) + newVoucherResult := testutil.NewTestTypedVoucher() + err := h.dt.SendVoucherResult(h.ctx, channelID(h.id, h.peers), newVoucherResult) + require.NoError(t, err) + }, + }, + "send voucher results from responder succeeds, pull request": { + configureValidator: func(sv *testutil.StubbedValidator) { + sv.ExpectSuccessPull() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) + }, + verify: func(t *testing.T, h *receiverHarness) { + _, _ = 
h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) + newVoucherResult := testutil.NewTestTypedVoucher() + err := h.dt.SendVoucherResult(h.ctx, channelID(h.id, h.peers), newVoucherResult) + require.NoError(t, err) + }, + }, "send vouchers from responder fails, push request": { + configureValidator: func(sv *testutil.StubbedValidator) { + sv.ExpectSuccessPush() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) + }, verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - newVoucher := testutil.NewFakeDTType() + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) + newVoucher := testutil.NewTestTypedVoucher() err := h.dt.SendVoucher(h.ctx, channelID(h.id, h.peers), newVoucher) require.EqualError(t, err, "cannot send voucher for request we did not initiate") }, }, "send vouchers from responder fails, pull request": { + configureValidator: func(sv *testutil.StubbedValidator) { + sv.ExpectSuccessPull() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) + }, verify: func(t *testing.T, h *receiverHarness) { _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) - newVoucher := testutil.NewFakeDTType() + newVoucher := testutil.NewTestTypedVoucher() err := h.dt.SendVoucher(h.ctx, channelID(h.id, h.peers), newVoucher) require.EqualError(t, err, "cannot send voucher for request we did not initiate") }, }, "receive voucher": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept, datatransfer.NewVoucher, datatransfer.ResumeResponder}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + datatransfer.NewVoucherResult, + datatransfer.NewVoucher, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, - configureRevalidator: func(sv *testutil.StubbedRevalidator) { - sv.ExpectSuccessErrResume() - }, - verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) - require.EqualError(t, err, datatransfer.ErrResume.Error()) + require.NoError(t, err) }, }, "receive pause, unpause": { expectedEvents: []datatransfer.EventCode{ datatransfer.Open, - datatransfer.NewVoucherResult, datatransfer.Accept, + datatransfer.NewVoucherResult, datatransfer.PauseInitiator, datatransfer.ResumeInitiator}, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pauseUpdate) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.resumeUpdate) @@ -228,289 +284,282 @@ func 
TestDataTransferResponding(t *testing.T) { "receive pause, set pause local, receive unpause": { expectedEvents: []datatransfer.EventCode{ datatransfer.Open, - datatransfer.NewVoucherResult, datatransfer.Accept, + datatransfer.NewVoucherResult, datatransfer.PauseInitiator, datatransfer.PauseResponder, datatransfer.ResumeInitiator}, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pauseUpdate) require.NoError(t, err) err = h.dt.PauseDataTransferChannel(h.ctx, channelID(h.id, h.peers)) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.resumeUpdate) - require.EqualError(t, err, datatransfer.ErrPause.Error()) + require.NoError(t, err) }, }, "receive cancel": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept, datatransfer.Cancel, datatransfer.CleanupComplete}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + datatransfer.NewVoucherResult, + datatransfer.Cancel, + datatransfer.CleanupComplete, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.cancelUpdate) require.NoError(t, err) - require.Len(t, h.transport.CleanedUpChannels, 1) - require.Equal(t, channelID(h.id, h.peers), h.transport.CleanedUpChannels[0]) }, }, "validate and revalidate successfully, push": { expectedEvents: []datatransfer.EventCode{ datatransfer.Open, - datatransfer.NewVoucherResult, datatransfer.Accept, + datatransfer.NewVoucherResult, + datatransfer.SetDataLimit, + datatransfer.TransferInitiated, datatransfer.DataReceivedProgress, datatransfer.DataReceived, - datatransfer.NewVoucherResult, - datatransfer.PauseResponder, + datatransfer.DataLimitExceeded, datatransfer.NewVoucher, datatransfer.NewVoucherResult, + datatransfer.SetDataLimit, datatransfer.ResumeResponder, }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(testutil.NewFakeDTType()) - }, - configureRevalidator: func(srv *testutil.StubbedRevalidator) { - srv.ExpectPausePushCheck() - srv.StubRevalidationResult(testutil.NewFakeDTType()) - srv.ExpectSuccessErrResume() - srv.StubCheckResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, DataLimit: 1000, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 
12345, 1, true) - require.EqualError(t, err, datatransfer.ErrPause.Error()) - require.Len(t, h.network.SentMessages, 1) - response, ok := h.network.SentMessages[0].Message.(datatransfer.Response) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(1)}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportReachedDataLimit{}) + require.Len(t, h.transport.MessagesSent, 1) + response, ok := h.transport.MessagesSent[0].Message.(datatransfer.Response) require.True(t, ok) - require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) - require.False(t, response.IsUpdate()) + require.True(t, response.IsUpdate()) require.False(t, response.IsCancel()) require.True(t, response.IsPaused()) - require.True(t, response.IsVoucherResult()) - require.False(t, response.EmptyVoucherResult()) - response, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) - require.EqualError(t, err, datatransfer.ErrResume.Error()) + require.False(t, response.IsValidationResult()) + response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) + require.NoError(t, err) + require.Nil(t, response) + vr := testutil.NewTestTypedVoucher() + err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: true, DataLimit: 50000, VoucherResult: &vr}) + require.NoError(t, err) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID(h.id, h.peers)) + require.Len(t, h.transport.MessagesSent, 2) + response, ok = h.transport.MessagesSent[1].Message.(datatransfer.Response) + require.True(t, ok) require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) require.False(t, response.IsUpdate()) require.False(t, response.IsCancel()) require.False(t, response.IsPaused()) - require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) require.False(t, response.EmptyVoucherResult()) }, }, - "validate and revalidate with err": { + "validate and revalidate with rejection on second voucher": { expectedEvents: []datatransfer.EventCode{ datatransfer.Open, - datatransfer.NewVoucherResult, datatransfer.Accept, - datatransfer.DataReceivedProgress, - datatransfer.DataReceived, datatransfer.NewVoucherResult, - }, - configureValidator: func(sv *testutil.StubbedValidator) { - sv.ExpectSuccessPush() - sv.StubResult(testutil.NewFakeDTType()) - }, - configureRevalidator: func(srv *testutil.StubbedRevalidator) { - srv.ExpectErrorPushCheck() - srv.StubRevalidationResult(testutil.NewFakeDTType()) - }, - verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 12345, 1, true) - require.Error(t, err) - require.Len(t, h.network.SentMessages, 1) - response, ok := h.network.SentMessages[0].Message.(datatransfer.Response) - require.True(t, ok) - require.False(t, response.Accepted()) - require.Equal(t, response.TransferID(), h.id) - require.False(t, response.IsUpdate()) - require.False(t, response.IsCancel()) - 
require.False(t, response.IsPaused()) - require.True(t, response.IsVoucherResult()) - require.False(t, response.EmptyVoucherResult()) - }, - }, - "validate and revalidate with err with second voucher": { - expectedEvents: []datatransfer.EventCode{ - datatransfer.Open, - datatransfer.NewVoucherResult, - datatransfer.Accept, + datatransfer.SetDataLimit, + datatransfer.TransferInitiated, datatransfer.DataReceivedProgress, datatransfer.DataReceived, - datatransfer.NewVoucherResult, - datatransfer.PauseResponder, + datatransfer.DataLimitExceeded, datatransfer.NewVoucher, datatransfer.NewVoucherResult, + datatransfer.Error, + datatransfer.CleanupComplete, }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(testutil.NewFakeDTType()) - }, - configureRevalidator: func(srv *testutil.StubbedRevalidator) { - srv.ExpectPausePushCheck() - srv.StubRevalidationResult(testutil.NewFakeDTType()) - srv.ExpectErrorRevalidation() - srv.StubCheckResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, DataLimit: 1000, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 12345, 1, true) - require.EqualError(t, err, datatransfer.ErrPause.Error()) - require.Len(t, h.network.SentMessages, 1) - response, ok := h.network.SentMessages[0].Message.(datatransfer.Response) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(1)}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportReachedDataLimit{}) + require.Len(t, h.transport.MessagesSent, 1) + response, ok := h.transport.MessagesSent[0].Message.(datatransfer.Response) require.True(t, ok) - require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) - require.False(t, response.IsUpdate()) + require.True(t, response.IsUpdate()) require.False(t, response.IsCancel()) require.True(t, response.IsPaused()) - require.True(t, response.IsVoucherResult()) - require.False(t, response.EmptyVoucherResult()) - response, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) - require.Error(t, err) + require.False(t, response.IsValidationResult()) + response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) + require.NoError(t, err) + require.Nil(t, response) + vr := testutil.NewTestTypedVoucher() + err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: false, VoucherResult: &vr}) + require.NoError(t, err) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID(h.id, h.peers)) + require.Len(t, h.transport.MessagesSent, 2) + response, ok = h.transport.MessagesSent[1].Message.(datatransfer.Response) + require.True(t, ok) require.False(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) require.False(t, response.IsUpdate()) require.False(t, response.IsCancel()) require.False(t, response.IsPaused()) - 
require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) require.False(t, response.EmptyVoucherResult()) }, }, "validate and revalidate successfully, pull": { expectedEvents: []datatransfer.EventCode{ datatransfer.Open, - datatransfer.NewVoucherResult, datatransfer.Accept, + datatransfer.NewVoucherResult, + datatransfer.SetDataLimit, + datatransfer.TransferInitiated, datatransfer.DataQueuedProgress, datatransfer.DataQueued, - datatransfer.NewVoucherResult, - datatransfer.PauseResponder, + datatransfer.DataLimitExceeded, datatransfer.NewVoucher, datatransfer.NewVoucherResult, + datatransfer.SetDataLimit, datatransfer.ResumeResponder, }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(testutil.NewFakeDTType()) - }, - configureRevalidator: func(srv *testutil.StubbedRevalidator) { - srv.ExpectPausePullCheck() - srv.StubRevalidationResult(testutil.NewFakeDTType()) - srv.ExpectSuccessErrResume() - srv.StubCheckResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, DataLimit: 1000, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) require.NoError(t, err) - msg, err := h.transport.EventHandler.OnDataQueued( - channelID(h.id, h.peers), - cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, - 12345, 1, true) - require.EqualError(t, err, datatransfer.ErrPause.Error()) - response, ok := msg.(datatransfer.Response) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportQueuedData{Size: 12345, Index: basicnode.NewInt(1)}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportReachedDataLimit{}) + require.Len(t, h.transport.MessagesSent, 1) + response, ok := h.transport.MessagesSent[0].Message.(datatransfer.Response) require.True(t, ok) - require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) - require.False(t, response.IsUpdate()) + require.True(t, response.IsUpdate()) require.False(t, response.IsCancel()) require.True(t, response.IsPaused()) - require.True(t, response.IsVoucherResult()) - require.False(t, response.EmptyVoucherResult()) + require.False(t, response.IsValidationResult()) response, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) - require.EqualError(t, err, datatransfer.ErrResume.Error()) + require.NoError(t, err, nil) + require.Nil(t, response) + vr := testutil.NewTestTypedVoucher() + err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: true, DataLimit: 50000, VoucherResult: &vr}) + require.NoError(t, err) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID(h.id, h.peers)) + require.Len(t, h.transport.MessagesSent, 2) + response, ok = h.transport.MessagesSent[1].Message.(datatransfer.Response) + require.True(t, ok) require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) require.False(t, response.IsUpdate()) require.False(t, response.IsCancel()) require.False(t, response.IsPaused()) - require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) require.False(t, response.EmptyVoucherResult()) }, }, 
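The push and pull revalidation cases above exercise the new data-limit flow end to end: the stubbed validator's `ValidationResult` sets a `DataLimit`, the transport reports `TransportReachedDataLimit`, and the responder stays paused until `UpdateValidationStatus` supplies a larger limit and triggers `ResumeResponder`. Outside the test harness that decision would normally live in an event subscriber. The snippet below is a minimal sketch of that pattern, assuming the v2 surface these tests imply (`SubscribeToEvents`, `Event.Code`, `ChannelState.ChannelID`, `UpdateValidationStatus`); `resumeAfterDataLimit` and the 50000-byte cap are illustrative, not part of this change.

```go
package example

import (
	"context"

	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
)

// resumeAfterDataLimit does programmatically what the tests above do by hand:
// when a responder-side channel pauses with DataLimitExceeded, it grants a
// higher limit through UpdateValidationStatus, which queues the new voucher
// result for the other peer and resumes the responder.
func resumeAfterDataLimit(ctx context.Context, dt datatransfer.Manager) datatransfer.Unsubscribe {
	return dt.SubscribeToEvents(func(event datatransfer.Event, state datatransfer.ChannelState) {
		if event.Code != datatransfer.DataLimitExceeded {
			return
		}
		chid := state.ChannelID()
		// Hand the update off to a goroutine so the subscriber returns
		// promptly; subscribers should not block event dispatch.
		go func() {
			result := datatransfer.ValidationResult{
				Accepted:  true,
				DataLimit: 50000, // illustrative cap, mirroring the figure used in the tests
			}
			_ = dt.UpdateValidationStatus(ctx, chid, result)
		}()
	})
}
```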
"validated, finalize, and complete successfully": { expectedEvents: []datatransfer.EventCode{ datatransfer.Open, - datatransfer.NewVoucherResult, datatransfer.Accept, datatransfer.NewVoucherResult, + datatransfer.SetRequiresFinalization, + datatransfer.TransferInitiated, datatransfer.BeginFinalizing, datatransfer.NewVoucher, + datatransfer.NewVoucherResult, + datatransfer.SetRequiresFinalization, datatransfer.ResumeResponder, datatransfer.CleanupComplete, }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(testutil.NewFakeDTType()) - }, - configureRevalidator: func(srv *testutil.StubbedRevalidator) { - srv.ExpectPauseComplete() - srv.StubRevalidationResult(testutil.NewFakeDTType()) - srv.ExpectSuccessErrResume() + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, RequiresFinalization: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) require.NoError(t, err) - err = h.transport.EventHandler.OnChannelCompleted(channelID(h.id, h.peers), nil) - require.NoError(t, err) - require.Len(t, h.network.SentMessages, 1) - response, ok := h.network.SentMessages[0].Message.(datatransfer.Response) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportCompletedTransfer{Success: true}) + require.Len(t, h.transport.MessagesSent, 1) + response, ok := h.transport.MessagesSent[0].Message.(datatransfer.Response) require.True(t, ok) require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) require.False(t, response.IsUpdate()) require.False(t, response.IsCancel()) + require.True(t, response.IsComplete()) require.True(t, response.IsPaused()) - require.True(t, response.IsVoucherResult()) - require.False(t, response.EmptyVoucherResult()) + require.True(t, response.IsValidationResult()) + require.True(t, response.EmptyVoucherResult()) response, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) - require.EqualError(t, err, datatransfer.ErrResume.Error()) + require.NoError(t, err, nil) + require.Nil(t, response) + vr := testutil.NewTestTypedVoucher() + err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) + require.NoError(t, err) + require.Len(t, h.transport.MessagesSent, 2) + response, ok = h.transport.MessagesSent[1].Message.(datatransfer.Response) + require.True(t, ok) require.Equal(t, response.TransferID(), h.id) - require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) require.False(t, response.IsPaused()) }, }, "validated, incomplete response": { expectedEvents: []datatransfer.EventCode{ datatransfer.Open, - datatransfer.NewVoucherResult, datatransfer.Accept, datatransfer.Error, datatransfer.CleanupComplete, }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(testutil.NewFakeDTType()) - }, - configureRevalidator: func(srv *testutil.StubbedRevalidator) { + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) }, verify: func(t *testing.T, h *receiverHarness) { _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) require.NoError(t, err) - err = 
h.transport.EventHandler.OnChannelCompleted(channelID(h.id, h.peers), xerrors.Errorf("err")) - require.NoError(t, err) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportCompletedTransfer{Success: false, ErrorMessage: "something went wrong"}) }, }, "new push request, customized transport": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(testutil.NewFakeDTType()) + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) }, verify: func(t *testing.T, h *receiverHarness) { - err := h.dt.RegisterTransportConfigurer(h.voucher, func(channelID datatransfer.ChannelID, voucher datatransfer.Voucher, transport datatransfer.Transport) { + err := h.dt.RegisterTransportConfigurer(h.voucher.Type, func(channelID datatransfer.ChannelID, voucher datatransfer.TypedVoucher, transport datatransfer.Transport) { ft, ok := transport.(*testutil.FakeTransport) if !ok { return @@ -518,7 +567,7 @@ func TestDataTransferResponding(t *testing.T) { ft.RecordCustomizedTransfer(channelID, voucher) }) require.NoError(t, err) - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) require.Len(t, h.transport.CustomizedTransfers, 1) customizedTransfer := h.transport.CustomizedTransfers[0] require.Equal(t, channelID(h.id, h.peers), customizedTransfer.ChannelID) @@ -526,12 +575,16 @@ func TestDataTransferResponding(t *testing.T) { }, }, "new pull request, customized transport": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) }, verify: func(t *testing.T, h *receiverHarness) { - err := h.dt.RegisterTransportConfigurer(h.voucher, func(channelID datatransfer.ChannelID, voucher datatransfer.Voucher, transport datatransfer.Transport) { + err := h.dt.RegisterTransportConfigurer(h.voucher.Type, func(channelID datatransfer.ChannelID, voucher datatransfer.TypedVoucher, transport datatransfer.Transport) { ft, ok := transport.(*testutil.FakeTransport) if !ok { return @@ -555,10 +608,9 @@ func TestDataTransferResponding(t *testing.T) { defer cancel() h.ctx = ctx h.peers = testutil.GeneratePeers(2) - h.network = testutil.NewFakeNetwork(h.peers[0]) h.transport = testutil.NewFakeTransport() h.ds = dss.MutexWrap(datastore.NewMapDatastore()) - dt, err := NewDataTransfer(h.ds, h.network, h.transport) + dt, err := NewDataTransfer(h.ds, h.peers[0], h.transport) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt) h.dt = dt @@ -567,36 +619,30 @@ func TestDataTransferResponding(t *testing.T) { events: make(chan datatransfer.EventCode, len(verify.expectedEvents)), } ev.setup(t, dt) - h.stor = testutil.AllSelector() - h.voucher = testutil.NewFakeDTType() + h.stor = selectorparse.CommonSelector_ExploreAllRecursively + h.voucher = testutil.NewTestTypedVoucher() h.baseCid = testutil.GenerateCids(1)[0] h.id = datatransfer.TransferID(rand.Int31()) - h.pullRequest, err = message.NewRequest(h.id, false, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pullRequest, 
err = message.NewRequest(h.id, false, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) - h.pushRequest, err = message.NewRequest(h.id, false, false, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pushRequest, err = message.NewRequest(h.id, false, false, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) h.pauseUpdate = message.UpdateRequest(h.id, true) require.NoError(t, err) h.resumeUpdate = message.UpdateRequest(h.id, false) require.NoError(t, err) - updateVoucher := testutil.NewFakeDTType() - h.voucherUpdate, err = message.VoucherRequest(h.id, updateVoucher.Type(), updateVoucher) + updateVoucher := testutil.NewTestTypedVoucher() + h.voucherUpdate = message.VoucherRequest(h.id, &updateVoucher) h.cancelUpdate = message.CancelRequest(h.id) require.NoError(t, err) h.sv = testutil.NewStubbedValidator() if verify.configureValidator != nil { verify.configureValidator(h.sv) } - require.NoError(t, h.dt.RegisterVoucherType(h.voucher, h.sv)) - h.srv = testutil.NewStubbedRevalidator() - if verify.configureRevalidator != nil { - verify.configureRevalidator(h.srv) - } - err = h.dt.RegisterRevalidator(updateVoucher, h.srv) + require.NoError(t, h.dt.RegisterVoucherType(h.voucher.Type, h.sv)) require.NoError(t, err) verify.verify(t, h) h.sv.VerifyExpectations(t) - h.srv.VerifyExpectations(t) ev.verify(ctx, t) }) } @@ -606,10 +652,9 @@ func TestDataTransferRestartResponding(t *testing.T) { // create network ctx := context.Background() testCases := map[string]struct { - expectedEvents []datatransfer.EventCode - configureValidator func(sv *testutil.StubbedValidator) - configureRevalidator func(sv *testutil.StubbedRevalidator) - verify func(t *testing.T, h *receiverHarness) + expectedEvents []datatransfer.EventCode + configureValidator func(sv *testutil.StubbedValidator) + verify func(t *testing.T, h *receiverHarness) }{ "receiving a pull restart response": { expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Restart, datatransfer.ResumeResponder}, @@ -618,54 +663,49 @@ func TestDataTransferRestartResponding(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.RestartResponse(channelID.ID, true, false, datatransfer.EmptyTypeIdentifier, nil) - require.NoError(t, err) + response := message.RestartResponse(channelID.ID, true, false, nil) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) }, }, "receiving a push restart request validates and opens a channel for pull": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept, - datatransfer.DataReceivedProgress, datatransfer.DataReceived, datatransfer.DataReceivedProgress, datatransfer.DataReceived, - datatransfer.NewVoucherResult, datatransfer.Restart}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + datatransfer.NewVoucherResult, + datatransfer.TransferInitiated, + datatransfer.DataReceivedProgress, + datatransfer.DataReceived, + datatransfer.DataReceivedProgress, + datatransfer.DataReceived, + datatransfer.Restart, + datatransfer.NewVoucherResult, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) + sv.ExpectSuccessValidateRestart() + vr = testutil.NewTestTypedVoucher() + sv.StubRestartResult(datatransfer.ValidationResult{Accepted: 
true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming push - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) require.Len(t, h.sv.ValidationsReceived, 1) - require.Len(t, h.transport.OpenedChannels, 1) - require.Len(t, h.network.SentMessages, 0) // some cids are received chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pushRequest.TransferID()} - testCids := testutil.GenerateCids(2) - ev, ok := h.dt.(datatransfer.EventsHandler) - require.True(t, ok) - require.NoError(t, ev.OnDataReceived(chid, cidlink.Link{Cid: testCids[0]}, 12345, 1, true)) - require.NoError(t, ev.OnDataReceived(chid, cidlink.Link{Cid: testCids[1]}, 12345, 2, true)) + h.transport.EventHandler.OnTransportEvent(chid, datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(chid, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(1)}) + h.transport.EventHandler.OnTransportEvent(chid, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(2)}) // receive restart push request - req, err := message.NewRequest(h.pushRequest.TransferID(), true, false, h.voucher.Type(), h.voucher, - h.baseCid, h.stor) + req, err := message.NewRequest(h.pushRequest.TransferID(), true, false, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], req) - require.Len(t, h.sv.ValidationsReceived, 2) - require.Len(t, h.transport.OpenedChannels, 2) - require.Len(t, h.network.SentMessages, 0) - - // validate channel that is opened second time - openChannel := h.transport.OpenedChannels[1] - require.Equal(t, openChannel.ChannelID, channelID(h.id, h.peers)) - require.Equal(t, openChannel.DataSender, h.peers[1]) - require.Equal(t, openChannel.Root, cidlink.Link{Cid: h.baseCid}) - require.Equal(t, openChannel.Selector, h.stor) - // assert do not send cids are sent - require.False(t, openChannel.Message.IsRequest()) - response, ok := openChannel.Message.(datatransfer.Response) - require.True(t, ok) + response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), req) + require.NoError(t, err) + require.Len(t, h.sv.RevalidationsReceived, 1) require.True(t, response.IsRestart()) require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) @@ -673,40 +713,39 @@ func TestDataTransferRestartResponding(t *testing.T) { require.False(t, response.IsCancel()) require.False(t, response.IsPaused()) require.False(t, response.IsNew()) - require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) - // validate the voucher that is validated the second time - vmsg := h.sv.ValidationsReceived[1] - require.Equal(t, h.voucher, vmsg.Voucher) - require.False(t, vmsg.IsPull) - require.Equal(t, h.stor, vmsg.Selector) - require.Equal(t, h.baseCid, vmsg.BaseCid) - require.Equal(t, h.peers[1], vmsg.Other) + vmsg := h.sv.RevalidationsReceived[0] + require.Equal(t, channelID(h.id, h.peers), vmsg.ChannelID) }, }, "receiving a pull restart request validates and sends a success response": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept, - datatransfer.NewVoucherResult, datatransfer.Restart}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + datatransfer.NewVoucherResult, + datatransfer.Restart, + 
datatransfer.NewVoucherResult, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) + sv.ExpectSuccessValidateRestart() + vr = testutil.NewTestTypedVoucher() + sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull - _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) - require.NoError(t, err) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) require.Len(t, h.sv.ValidationsReceived, 1) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 0) // receive restart pull request - restartReq, err := message.NewRequest(h.id, true, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + restartReq, err := message.NewRequest(h.id, true, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), restartReq) require.NoError(t, err) - require.Len(t, h.sv.ValidationsReceived, 2) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 0) + require.Len(t, h.sv.RevalidationsReceived, 1) // validate response require.True(t, response.IsRestart()) @@ -716,111 +755,125 @@ func TestDataTransferRestartResponding(t *testing.T) { require.False(t, response.IsCancel()) require.False(t, response.IsPaused()) require.False(t, response.IsNew()) - require.True(t, response.IsVoucherResult()) + require.True(t, response.IsValidationResult()) - // validate the voucher that is validated the second time - vmsg := h.sv.ValidationsReceived[1] - require.Equal(t, h.voucher, vmsg.Voucher) - require.True(t, vmsg.IsPull) - require.Equal(t, h.stor, vmsg.Selector) - require.Equal(t, h.baseCid, vmsg.BaseCid) - require.Equal(t, h.peers[1], vmsg.Other) + vmsg := h.sv.RevalidationsReceived[0] + require.Equal(t, channelID(h.id, h.peers), vmsg.ChannelID) }, }, "restart request fails if channel does not exist": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + datatransfer.NewVoucherResult, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull - _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) - require.NoError(t, err) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) require.Len(t, h.sv.ValidationsReceived, 1) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 0) // receive restart pull request - restartReq, err := message.NewRequest(h.id, true, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + restartReq, err := message.NewRequest(h.id, true, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) p := testutil.GeneratePeers(1)[0] chid := datatransfer.ChannelID{ID: h.pullRequest.TransferID(), Initiator: p, Responder: h.peers[0]} _, err = 
h.transport.EventHandler.OnRequestReceived(chid, restartReq) - require.True(t, xerrors.As(err, new(*channels.ErrNotFound))) + require.EqualError(t, err, datatransfer.ErrChannelNotFound.Error()) }, }, "restart request fails if voucher validation fails": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + datatransfer.NewVoucherResult, + datatransfer.Error, + datatransfer.CleanupComplete, + }, + configureValidator: func(sv *testutil.StubbedValidator) { + sv.ExpectSuccessPull() + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) + sv.ExpectSuccessValidateRestart() + sv.StubRestartResult(datatransfer.ValidationResult{Accepted: false}) + }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull - h.sv.ExpectSuccessPull() - h.sv.StubResult(testutil.NewFakeDTType()) - _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) - require.NoError(t, err) + _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) require.Len(t, h.sv.ValidationsReceived, 1) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 0) // receive restart pull request h.sv.ExpectErrorPull() - restartReq, err := message.NewRequest(h.id, true, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + restartReq, err := message.NewRequest(h.id, true, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), restartReq) - require.EqualError(t, err, "failed to validate voucher: something went wrong") + require.EqualError(t, err, datatransfer.ErrRejected.Error()) }, }, "restart request fails if base cid does not match": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + datatransfer.NewVoucherResult, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull chid := channelID(h.id, h.peers) - _, err := h.transport.EventHandler.OnRequestReceived(chid, h.pullRequest) - require.NoError(t, err) + _, _ = h.transport.EventHandler.OnRequestReceived(chid, h.pullRequest) require.Len(t, h.sv.ValidationsReceived, 1) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 0) // receive restart pull request randCid := testutil.GenerateCids(1)[0] - restartReq, err := message.NewRequest(h.id, true, true, h.voucher.Type(), h.voucher, randCid, h.stor) + restartReq, err := message.NewRequest(h.id, true, true, &h.voucher, randCid, h.stor) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(chid, restartReq) require.EqualError(t, err, fmt.Sprintf("restart request for channel %s failed validation: base cid does not match", chid)) }, }, "restart request fails if voucher type is not decodable": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + 
datatransfer.Open, + datatransfer.Accept, + datatransfer.NewVoucherResult, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull chid := channelID(h.id, h.peers) - _, err := h.transport.EventHandler.OnRequestReceived(chid, h.pullRequest) - require.NoError(t, err) + _, _ = h.transport.EventHandler.OnRequestReceived(chid, h.pullRequest) require.Len(t, h.sv.ValidationsReceived, 1) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 0) // receive restart pull request - restartReq, err := message.NewRequest(h.id, true, true, "rand", h.voucher, h.baseCid, h.stor) + restartReq, err := message.NewRequest(h.id, true, true, &datatransfer.TypedVoucher{Voucher: h.voucher.Voucher, Type: "rand"}, h.baseCid, h.stor) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(chid, restartReq) - require.EqualError(t, err, fmt.Sprintf("restart request for channel %s failed validation: failed to decode request voucher: unknown voucher type: rand", chid)) + require.EqualError(t, err, fmt.Sprintf("restart request for channel %s failed validation: channel and request voucher types do not match", chid)) }, }, "restart request fails if voucher does not match": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.Accept, + datatransfer.NewVoucherResult, + }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(testutil.NewFakeDTType()) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull @@ -828,20 +881,24 @@ func TestDataTransferRestartResponding(t *testing.T) { _, err := h.transport.EventHandler.OnRequestReceived(chid, h.pullRequest) require.NoError(t, err) require.Len(t, h.sv.ValidationsReceived, 1) - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 0) // receive restart pull request - v := testutil.NewFakeDTType() - v.Data = "rand" - restartReq, err := message.NewRequest(h.id, true, true, h.voucher.Type(), v, h.baseCid, h.stor) + v := testutil.NewTestTypedVoucherWith("rand") + restartReq, err := message.NewRequest(h.id, true, true, &v, h.baseCid, h.stor) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(chid, restartReq) require.EqualError(t, err, fmt.Sprintf("restart request for channel %s failed validation: channel and request vouchers do not match", chid)) }, }, "ReceiveRestartExistingChannelRequest: Reopen Pull Channel": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.DataReceivedProgress, datatransfer.DataReceived, datatransfer.DataReceivedProgress, datatransfer.DataReceived}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.TransferInitiated, + datatransfer.DataReceivedProgress, + datatransfer.DataReceived, + datatransfer.DataReceivedProgress, + datatransfer.DataReceived, + }, configureValidator: func(sv *testutil.StubbedValidator) { }, verify: func(t *testing.T, h *receiverHarness) { @@ -851,32 +908,27 @@ func 
TestDataTransferRestartResponding(t *testing.T) { require.NotEmpty(t, channelID) // some cids should already be received - testCids := testutil.GenerateCids(2) - ev, ok := h.dt.(datatransfer.EventsHandler) - require.True(t, ok) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[0]}, 12345, 1, true)) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[1]}, 12345, 2, true)) + // some cids are received + h.transport.EventHandler.OnTransportEvent(channelID, datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(channelID, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(1)}) + h.transport.EventHandler.OnTransportEvent(channelID, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(2)}) // send a request to restart the same pull channel - restartReq := message.RestartExistingChannelRequest(channelID) - h.network.Delegate.ReceiveRestartExistingChannelRequest(ctx, h.peers[1], restartReq) + h.transport.EventHandler.OnTransportEvent(channelID, datatransfer.TransportReceivedRestartExistingChannelRequest{}) - require.Len(t, h.transport.OpenedChannels, 2) - require.Len(t, h.network.SentMessages, 0) + require.Len(t, h.transport.OpenedChannels, 1) + require.Len(t, h.transport.RestartedChannels, 1) // assert correct channel was created in response to this - require.Len(t, h.transport.OpenedChannels, 2) - openChannel := h.transport.OpenedChannels[1] - require.Equal(t, openChannel.ChannelID, channelID) - require.Equal(t, openChannel.DataSender, h.peers[1]) - require.Equal(t, openChannel.Root, cidlink.Link{Cid: h.baseCid}) - require.Equal(t, openChannel.Selector, h.stor) - require.True(t, openChannel.Message.IsRequest()) - require.EqualValues(t, len(testCids), openChannel.Channel.ReceivedCidsTotal()) + restartedChannel := h.transport.RestartedChannels[0] + require.Equal(t, restartedChannel.Channel.ChannelID(), channelID) + require.Equal(t, restartedChannel.Channel.Sender(), h.peers[1]) + require.Equal(t, restartedChannel.Channel.BaseCID(), h.baseCid) + require.Equal(t, restartedChannel.Channel.Selector(), h.stor) + require.EqualValues(t, basicnode.NewInt(2), restartedChannel.Channel.ReceivedIndex()) // assert a restart request is in the channel - request, ok := openChannel.Message.(datatransfer.Request) - require.True(t, ok) + request := restartedChannel.Message require.True(t, request.IsRestart()) require.Equal(t, request.TransferID(), channelID.ID) require.Equal(t, request.BaseCid(), h.baseCid) @@ -888,11 +940,13 @@ func TestDataTransferRestartResponding(t *testing.T) { receivedSelector, err := request.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, request, h.voucher) + testutil.AssertTestVoucher(t, request, h.voucher) }, }, "ReceiveRestartExistingChannelRequest: Resend Push Request": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open}, + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + }, configureValidator: func(sv *testutil.StubbedValidator) { }, verify: func(t *testing.T, h *receiverHarness) { @@ -902,20 +956,15 @@ func TestDataTransferRestartResponding(t *testing.T) { require.NotEmpty(t, channelID) // send a request to restart the same push request - restartReq := message.RestartExistingChannelRequest(channelID) - h.network.Delegate.ReceiveRestartExistingChannelRequest(ctx, h.peers[1], restartReq) + h.transport.EventHandler.OnTransportEvent(channelID, 
datatransfer.TransportReceivedRestartExistingChannelRequest{}) + require.NoError(t, err) - // assert correct message was sent in response to this - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 2) + require.Len(t, h.transport.OpenedChannels, 1) + require.Len(t, h.transport.RestartedChannels, 1) // assert restart request is well formed - messageReceived := h.network.SentMessages[1] - require.Equal(t, messageReceived.PeerID, h.peers[1]) - received := messageReceived.Message - require.True(t, received.IsRequest()) - receivedRequest, ok := received.(datatransfer.Request) - require.True(t, ok) + restartedChannel := h.transport.RestartedChannels[0] + receivedRequest := restartedChannel.Message require.Equal(t, receivedRequest.TransferID(), channelID.ID) require.Equal(t, receivedRequest.BaseCid(), h.baseCid) require.False(t, receivedRequest.IsCancel()) @@ -927,44 +976,7 @@ func TestDataTransferRestartResponding(t *testing.T) { receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher) - }, - }, - "ReceiveRestartExistingChannelRequest: errors if peer is not the initiator": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Accept}, - configureValidator: func(sv *testutil.StubbedValidator) { - }, - verify: func(t *testing.T, h *receiverHarness) { - // create an incoming push first - h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - require.Len(t, h.sv.ValidationsReceived, 1) - - // restart req does not anything as we are not the initiator - chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pushRequest.TransferID()} - restartReq := message.RestartExistingChannelRequest(chid) - h.network.Delegate.ReceiveRestartExistingChannelRequest(ctx, h.peers[1], restartReq) - - require.Len(t, h.transport.OpenedChannels, 1) - require.Len(t, h.network.SentMessages, 0) - }, - }, - "ReceiveRestartExistingChannelRequest: errors if sending peer is not the counter-party on the channel": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open}, - configureValidator: func(sv *testutil.StubbedValidator) { - }, - verify: func(t *testing.T, h *receiverHarness) { - // create an outgoing push request first - p := testutil.GeneratePeers(1)[0] - channelID, err := h.dt.OpenPushDataChannel(h.ctx, p, h.voucher, h.baseCid, h.stor) - require.NoError(t, err) - require.NotEmpty(t, channelID) - - // sending peer is not the counter-party on the channel - restartReq := message.RestartExistingChannelRequest(channelID) - h.network.Delegate.ReceiveRestartExistingChannelRequest(ctx, h.peers[1], restartReq) - - require.Len(t, h.transport.OpenedChannels, 0) - require.Len(t, h.network.SentMessages, 1) + testutil.AssertTestVoucher(t, receivedRequest, h.voucher) }, }, } @@ -975,10 +987,9 @@ func TestDataTransferRestartResponding(t *testing.T) { defer cancel() h.ctx = ctx h.peers = testutil.GeneratePeers(2) - h.network = testutil.NewFakeNetwork(h.peers[0]) h.transport = testutil.NewFakeTransport() h.ds = dss.MutexWrap(datastore.NewMapDatastore()) - dt, err := NewDataTransfer(h.ds, h.network, h.transport) + dt, err := NewDataTransfer(h.ds, h.peers[0], h.transport) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt) h.dt = dt @@ -987,20 +998,20 @@ func TestDataTransferRestartResponding(t *testing.T) { events: make(chan datatransfer.EventCode, len(verify.expectedEvents)), } ev.setup(t, dt) - 
h.stor = testutil.AllSelector() - h.voucher = testutil.NewFakeDTType() + h.stor = selectorparse.CommonSelector_ExploreAllRecursively + h.voucher = testutil.NewTestTypedVoucher() h.baseCid = testutil.GenerateCids(1)[0] h.id = datatransfer.TransferID(rand.Int31()) - h.pullRequest, err = message.NewRequest(h.id, false, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pullRequest, err = message.NewRequest(h.id, false, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) - h.pushRequest, err = message.NewRequest(h.id, false, false, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pushRequest, err = message.NewRequest(h.id, false, false, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) h.sv = testutil.NewStubbedValidator() if verify.configureValidator != nil { verify.configureValidator(h.sv) } - require.NoError(t, h.dt.RegisterVoucherType(h.voucher, h.sv)) + require.NoError(t, h.dt.RegisterVoucherType(h.voucher.Type, h.sv)) verify.verify(t, h) h.sv.VerifyExpectations(t) @@ -1019,14 +1030,12 @@ type receiverHarness struct { cancelUpdate datatransfer.Request ctx context.Context peers []peer.ID - network *testutil.FakeNetwork transport *testutil.FakeTransport sv *testutil.StubbedValidator - srv *testutil.StubbedRevalidator ds datastore.Batching dt datatransfer.Manager - stor ipld.Node - voucher *testutil.FakeDTType + stor datamodel.Node + voucher datatransfer.TypedVoucher baseCid cid.Cid } diff --git a/impl/restart.go b/impl/restart.go index 697985cd..9381aba4 100644 --- a/impl/restart.go +++ b/impl/restart.go @@ -1,143 +1,67 @@ package impl import ( - "bytes" "context" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/ipld/go-ipld-prime" "github.com/libp2p/go-libp2p-core/peer" "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channels" - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/message" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/channels" + "github.com/filecoin-project/go-data-transfer/v2/message" ) -// ChannelDataTransferType identifies the type of a data transfer channel for the purposes of a restart -type ChannelDataTransferType int +func (m *manager) restartManagerPeerReceive(ctx context.Context, channel datatransfer.ChannelState) error { -const ( - // ManagerPeerCreatePull is the type of a channel wherein the manager peer created a Pull Data Transfer - ManagerPeerCreatePull ChannelDataTransferType = iota - - // ManagerPeerCreatePush is the type of a channel wherein the manager peer created a Push Data Transfer - ManagerPeerCreatePush - - // ManagerPeerReceivePull is the type of a channel wherein the manager peer received a Pull Data Transfer Request - ManagerPeerReceivePull - - // ManagerPeerReceivePush is the type of a channel wherein the manager peer received a Push Data Transfer Request - ManagerPeerReceivePush -) - -func (m *manager) restartManagerPeerReceivePush(ctx context.Context, channel datatransfer.ChannelState) error { - if err := m.validateRestartVoucher(channel, false); err != nil { - return xerrors.Errorf("failed to restart channel, validation error: %w", err) + if !m.transport.Capabilities().Restartable { + return datatransfer.ErrUnsupported } - // send a libp2p message to the other peer asking to send a "restart push request" - req := message.RestartExistingChannelRequest(channel.ChannelID()) - - if err := 
m.dataTransferNetwork.SendMessage(ctx, channel.OtherPeer(), req); err != nil { - return xerrors.Errorf("unable to send restart request: %w", err) + result, err := m.validateRestart(channel) + if err != nil { + return xerrors.Errorf("failed to restart channel, validation error: %w", err) } - return nil -} - -func (m *manager) restartManagerPeerReceivePull(ctx context.Context, channel datatransfer.ChannelState) error { - if err := m.validateRestartVoucher(channel, true); err != nil { - return xerrors.Errorf("failed to restart channel, validation error: %w", err) + if !result.Accepted { + return datatransfer.ErrRejected } + // send a libp2p message to the other peer asking to send a "restart push request" req := message.RestartExistingChannelRequest(channel.ChannelID()) - // send a libp2p message to the other peer asking to send a "restart pull request" - if err := m.dataTransferNetwork.SendMessage(ctx, channel.OtherPeer(), req); err != nil { + if err := m.transport.SendMessage(ctx, channel.ChannelID(), req); err != nil { return xerrors.Errorf("unable to send restart request: %w", err) } - - return nil -} - -func (m *manager) validateRestartVoucher(channel datatransfer.ChannelState, isPull bool) error { - // re-validate the original voucher received for safety - chid := channel.ChannelID() - - // recreate the request that would have led to this pull channel being created for validation - req, err := message.NewRequest(chid.ID, false, isPull, channel.Voucher().Type(), channel.Voucher(), - channel.BaseCID(), channel.Selector()) - if err != nil { - return err - } - - // revalidate the voucher by reconstructing the request that would have led to the creation of this channel - if _, _, err := m.validateVoucher(true, chid, channel.OtherPeer(), req, isPull, channel.BaseCID(), channel.Selector()); err != nil { - return err - } - return nil } -func (m *manager) openPushRestartChannel(ctx context.Context, channel datatransfer.ChannelState) error { +func (m *manager) openRestartChannel(ctx context.Context, channel datatransfer.ChannelState) error { selector := channel.Selector() voucher := channel.Voucher() baseCid := channel.BaseCID() requestTo := channel.OtherPeer() chid := channel.ChannelID() - req, err := message.NewRequest(chid.ID, true, false, voucher.Type(), voucher, baseCid, selector) + if !m.transport.Capabilities().Restartable { + return datatransfer.ErrUnsupported + } + req, err := message.NewRequest(chid.ID, true, channel.IsPull(), &voucher, baseCid, selector) if err != nil { return err } - processor, has := m.transportConfigurers.Processor(voucher.Type()) + processor, has := m.transportConfigurers.Processor(voucher.Type) if has { transportConfigurer := processor.(datatransfer.TransportConfigurer) transportConfigurer(chid, voucher, m.transport) } - m.dataTransferNetwork.Protect(requestTo, chid.String()) // Monitor the state of the connection for the channel - monitoredChan := m.channelMonitor.AddPushChannel(chid) + monitoredChan := m.channelMonitor.AddChannel(chid, channel.IsPull()) log.Infof("sending push restart channel to %s for channel %s", requestTo, chid) - if err := m.dataTransferNetwork.SendMessage(ctx, requestTo, req); err != nil { - // If push channel monitoring is enabled, shutdown the monitor as it - // wasn't possible to start the data transfer - if monitoredChan != nil { - monitoredChan.Shutdown() - } - - return xerrors.Errorf("Unable to send restart request: %w", err) - } - - return nil -} - -func (m *manager) openPullRestartChannel(ctx context.Context, channel 
datatransfer.ChannelState) error { - selector := channel.Selector() - voucher := channel.Voucher() - baseCid := channel.BaseCID() - requestTo := channel.OtherPeer() - chid := channel.ChannelID() - - req, err := message.NewRequest(chid.ID, true, true, voucher.Type(), voucher, baseCid, selector) + err = m.transport.RestartChannel(ctx, channel, req) if err != nil { - return err - } - - processor, has := m.transportConfigurers.Processor(voucher.Type()) - if has { - transportConfigurer := processor.(datatransfer.TransportConfigurer) - transportConfigurer(chid, voucher, m.transport) - } - m.dataTransferNetwork.Protect(requestTo, chid.String()) - - // Monitor the state of the connection for the channel - monitoredChan := m.channelMonitor.AddPullChannel(chid) - log.Infof("sending open channel to %s to restart channel %s", requestTo, chid) - if err := m.transport.OpenChannel(ctx, requestTo, chid, cidlink.Link{Cid: baseCid}, selector, channel, req); err != nil { // If pull channel monitoring is enabled, shutdown the monitor as it // wasn't possible to start the data transfer if monitoredChan != nil { @@ -150,12 +74,7 @@ func (m *manager) openPullRestartChannel(ctx context.Context, channel datatransf return nil } -func (m *manager) validateRestartRequest(ctx context.Context, otherPeer peer.ID, chid datatransfer.ChannelID, req datatransfer.Request) error { - // channel should exist - channel, err := m.channels.GetByID(ctx, chid) - if err != nil { - return err - } +func (m *manager) validateRestartRequest(ctx context.Context, otherPeer peer.ID, channel datatransfer.ChannelState, req datatransfer.Request) error { // channel is not terminated if channels.IsChannelTerminated(channel.Status()) { @@ -173,24 +92,16 @@ func (m *manager) validateRestartRequest(ctx context.Context, otherPeer peer.ID, } // vouchers should match - reqVoucher, err := m.decodeVoucher(req, m.validatedTypes) + reqVoucher, err := req.Voucher() if err != nil { - return xerrors.Errorf("failed to decode request voucher: %w", err) + return xerrors.Errorf("failed to fetch request voucher: %w", err) } - if reqVoucher.Type() != channel.Voucher().Type() { + channelVoucher := channel.Voucher() + if req.VoucherType() != channelVoucher.Type { return xerrors.New("channel and request voucher types do not match") } - reqBz, err := encoding.Encode(reqVoucher) - if err != nil { - return xerrors.New("failed to encode request voucher") - } - channelBz, err := encoding.Encode(channel.Voucher()) - if err != nil { - return xerrors.New("failed to encode channel voucher") - } - - if !bytes.Equal(reqBz, channelBz) { + if !ipld.DeepEqual(reqVoucher, channelVoucher.Voucher) { return xerrors.New("channel and request vouchers do not match") } diff --git a/impl/utils.go b/impl/utils.go index 0b4c4e6c..4a74d838 100644 --- a/impl/utils.go +++ b/impl/utils.go @@ -4,13 +4,11 @@ import ( "context" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" - "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/message" - "github.com/filecoin-project/go-data-transfer/registry" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message" ) type statusList []datatransfer.Status @@ -31,37 +29,17 @@ var resumeTransportStatesResponder = statusList{ } // newRequest encapsulates message creation -func (m *manager) newRequest(ctx context.Context, 
selector ipld.Node, isPull bool, voucher datatransfer.Voucher, baseCid cid.Cid, to peer.ID) (datatransfer.Request, error) { +func (m *manager) newRequest(ctx context.Context, selector datamodel.Node, isPull bool, voucher datatransfer.TypedVoucher, baseCid cid.Cid, to peer.ID) (datatransfer.Request, error) { // Generate a new transfer ID for the request tid := datatransfer.TransferID(m.transferIDGen.next()) - return message.NewRequest(tid, false, isPull, voucher.Type(), voucher, baseCid, selector) + return message.NewRequest(tid, false, isPull, &voucher, baseCid, selector) } -func (m *manager) response(isRestart bool, isNew bool, err error, tid datatransfer.TransferID, voucherResult datatransfer.VoucherResult) (datatransfer.Response, error) { - isAccepted := err == nil || err == datatransfer.ErrPause || err == datatransfer.ErrResume - isPaused := err == datatransfer.ErrPause - resultType := datatransfer.EmptyTypeIdentifier - if voucherResult != nil { - resultType = voucherResult.Type() - } - if isRestart { - return message.RestartResponse(tid, isAccepted, isPaused, resultType, voucherResult) - } - - if isNew { - return message.NewResponse(tid, isAccepted, isPaused, resultType, voucherResult) - } - return message.VoucherResultResponse(tid, isAccepted, isPaused, resultType, voucherResult) -} - -func (m *manager) completeResponse(err error, tid datatransfer.TransferID, voucherResult datatransfer.VoucherResult) (datatransfer.Response, error) { - isAccepted := err == nil || err == datatransfer.ErrPause || err == datatransfer.ErrResume - isPaused := err == datatransfer.ErrPause - resultType := datatransfer.EmptyTypeIdentifier - if voucherResult != nil { - resultType = voucherResult.Type() +func (m *manager) otherPeer(chid datatransfer.ChannelID) peer.ID { + if chid.Initiator == m.peerID { + return chid.Responder } - return message.CompleteResponse(tid, isAccepted, isPaused, resultType, voucherResult) + return chid.Initiator } func (m *manager) resume(chid datatransfer.ChannelID) error { @@ -112,29 +90,3 @@ func (m *manager) cancelMessage(chid datatransfer.ChannelID) datatransfer.Messag } return message.CancelResponse(chid.ID) } - -func (m *manager) decodeVoucherResult(response datatransfer.Response) (datatransfer.VoucherResult, error) { - vtypStr := datatransfer.TypeIdentifier(response.VoucherResultType()) - decoder, has := m.resultTypes.Decoder(vtypStr) - if !has { - return nil, xerrors.Errorf("unknown voucher result type: %s", vtypStr) - } - encodable, err := response.VoucherResult(decoder) - if err != nil { - return nil, err - } - return encodable.(datatransfer.Registerable), nil -} - -func (m *manager) decodeVoucher(request datatransfer.Request, registry *registry.Registry) (datatransfer.Voucher, error) { - vtypStr := datatransfer.TypeIdentifier(request.VoucherType()) - decoder, has := registry.Decoder(vtypStr) - if !has { - return nil, xerrors.Errorf("unknown voucher type: %s", vtypStr) - } - encodable, err := request.Voucher(decoder) - if err != nil { - return nil, err - } - return encodable.(datatransfer.Registerable), nil -} diff --git a/testutil/fixtures/lorem.txt b/itest/fixtures/lorem.txt similarity index 100% rename from testutil/fixtures/lorem.txt rename to itest/fixtures/lorem.txt diff --git a/testutil/fixtures/lorem_large.txt b/itest/fixtures/lorem_large.txt similarity index 100% rename from testutil/fixtures/lorem_large.txt rename to itest/fixtures/lorem_large.txt diff --git a/testutil/gstestdata.go b/itest/gstestdata.go similarity index 89% rename from testutil/gstestdata.go 
rename to itest/gstestdata.go index fef23447..2b1b808a 100644 --- a/testutil/gstestdata.go +++ b/itest/gstestdata.go @@ -1,4 +1,4 @@ -package testutil +package itest import ( "bytes" @@ -30,35 +30,29 @@ import ( ihelper "github.com/ipfs/go-unixfs/importer/helpers" "github.com/ipld/go-ipld-prime" cidlink "github.com/ipld/go-ipld-prime/linking/cid" - basicnode "github.com/ipld/go-ipld-prime/node/basic" - "github.com/ipld/go-ipld-prime/traversal/selector" - "github.com/ipld/go-ipld-prime/traversal/selector/builder" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/protocol" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/network" - gstransport "github.com/filecoin-project/go-data-transfer/transport/graphsync" - "github.com/filecoin-project/go-data-transfer/transport/graphsync/extension" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + gstransport "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" + "github.com/filecoin-project/go-data-transfer/v2/transport/helpers/network" ) -var allSelector ipld.Node - const loremFile = "lorem.txt" +const loremFileTransferBytes = 20439 -func init() { - ssb := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any) - allSelector = ssb.ExploreRecursive(selector.RecursionLimitNone(), - ssb.ExploreAll(ssb.ExploreRecursiveEdge())).Node() -} +const loremLargeFile = "lorem_large.txt" +const loremLargeFileTransferBytes = 217452 const unixfsChunkSize uint64 = 1 << 10 const unixfsLinksPerLevel = 1024 var extsForProtocol = map[protocol.ID]graphsync.ExtensionName{ - datatransfer.ProtocolDataTransfer1_2: extension.ExtensionDataTransfer1_1, + network.ProtocolDataTransfer1_2: extension.ExtensionDataTransfer1_1, + network.ProtocolFilDataTransfer1_2: extension.ExtensionDataTransfer1_1, } // GraphsyncTestingData is a test harness for testing data transfer on top of @@ -82,7 +76,8 @@ type GraphsyncTestingData struct { GsNet2 gsnet.GraphSyncNetwork DtNet1 network.DataTransferNetwork DtNet2 network.DataTransferNetwork - AllSelector ipld.Node + Gs1 graphsync.GraphExchange + Gs2 graphsync.GraphExchange OrigBytes []byte TempDir1 string TempDir2 string @@ -151,7 +146,6 @@ func NewGraphsyncTestingData(ctx context.Context, t *testing.T, host1Protocols [ require.NoError(t, err) gsData.TempDir2 = tempdir // create a selector for the whole UnixFS dag - gsData.AllSelector = allSelector gsData.host1Protocols = host1Protocols gsData.host2Protocols = host2Protocols return gsData @@ -159,13 +153,17 @@ func NewGraphsyncTestingData(ctx context.Context, t *testing.T, host1Protocols [ // SetupGraphsyncHost1 sets up a new, real graphsync instance on top of the first host func (gsData *GraphsyncTestingData) SetupGraphsyncHost1() graphsync.GraphExchange { + if gsData.Gs1 != nil { + return gsData.Gs1 + } // setup graphsync if gsData.gs1Cancel != nil { gsData.gs1Cancel() } gsCtx, gsCancel := context.WithCancel(gsData.Ctx) gsData.gs1Cancel = gsCancel - return gsimpl.New(gsCtx, gsData.GsNet1, gsData.LinkSystem1) + gsData.Gs1 = gsimpl.New(gsCtx, gsData.GsNet1, gsData.LinkSystem1) + return gsData.Gs1 } // SetupGSTransportHost1 sets up a new grapshync transport over real graphsync on the first host @@ -180,18 +178,22 @@ func (gsData *GraphsyncTestingData) SetupGSTransportHost1(opts ...gstransport.Op opts = 
append(opts, gstransport.SupportedExtensions(supportedExtensions)) } - return gstransport.NewTransport(gsData.Host1.ID(), gs, opts...) + return gstransport.NewTransport(gs, gsData.DtNet1, opts...) } // SetupGraphsyncHost2 sets up a new, real graphsync instance on top of the second host func (gsData *GraphsyncTestingData) SetupGraphsyncHost2() graphsync.GraphExchange { + if gsData.Gs2 != nil { + return gsData.Gs2 + } // setup graphsync if gsData.gs2Cancel != nil { gsData.gs2Cancel() } gsCtx, gsCancel := context.WithCancel(gsData.Ctx) gsData.gs2Cancel = gsCancel - return gsimpl.New(gsCtx, gsData.GsNet2, gsData.LinkSystem2) + gsData.Gs2 = gsimpl.New(gsCtx, gsData.GsNet2, gsData.LinkSystem2) + return gsData.Gs2 } // SetupGSTransportHost2 sets up a new grapshync transport over real graphsync on the second host @@ -205,7 +207,7 @@ func (gsData *GraphsyncTestingData) SetupGSTransportHost2(opts ...gstransport.Op } opts = append(opts, gstransport.SupportedExtensions(supportedExtensions)) } - return gstransport.NewTransport(gsData.Host2.ID(), gs, opts...) + return gstransport.NewTransport(gs, gsData.DtNet2, opts...) } // LoadUnixFSFile loads a fixtures file we can test dag transfer with diff --git a/impl/integration_test.go b/itest/integration_test.go similarity index 75% rename from impl/integration_test.go rename to itest/integration_test.go index 6142c7df..a98a67d8 100644 --- a/impl/integration_test.go +++ b/itest/integration_test.go @@ -1,4 +1,4 @@ -package impl_test +package itest import ( "bytes" @@ -9,6 +9,7 @@ import ( "time" "github.com/ipfs/go-blockservice" + "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" "github.com/ipfs/go-datastore/namespace" dss "github.com/ipfs/go-datastore/sync" @@ -26,30 +27,25 @@ import ( "github.com/ipfs/go-unixfs/importer/balanced" ihelper "github.com/ipfs/go-unixfs/importer/helpers" "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" cidlink "github.com/ipld/go-ipld-prime/linking/cid" + selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channelmonitor" - "github.com/filecoin-project/go-data-transfer/encoding" - . "github.com/filecoin-project/go-data-transfer/impl" - "github.com/filecoin-project/go-data-transfer/message" - "github.com/filecoin-project/go-data-transfer/network" - "github.com/filecoin-project/go-data-transfer/testutil" - tp "github.com/filecoin-project/go-data-transfer/transport/graphsync" - "github.com/filecoin-project/go-data-transfer/transport/graphsync/extension" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/channelmonitor" + . 
"github.com/filecoin-project/go-data-transfer/v2/impl" + "github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/testutil" + tp "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" + "github.com/filecoin-project/go-data-transfer/v2/transport/helpers/network" ) -const loremFile = "lorem.txt" -const loremFileTransferBytes = 20439 - -const loremLargeFile = "lorem_large.txt" -const loremLargeFileTransferBytes = 217452 - // nil means use the default protocols // tests data transfer for the following protocol combinations: // default protocol -> default protocols @@ -59,7 +55,10 @@ var protocolsForTest = map[string]struct { host1Protocols []protocol.ID host2Protocols []protocol.ID }{ - "(v1.2 -> v1.2)": {nil, nil}, + "(wrapped v1.2 -> wrapped v1.2)": {nil, nil}, + "(v1.2 -> wrapped v1.2)": {[]protocol.ID{network.ProtocolFilDataTransfer1_2}, nil}, + "(wrapped v1.2 -> v1.2)": {nil, []protocol.ID{network.ProtocolFilDataTransfer1_2}}, + "(v1.2 -> v1.2)": {[]protocol.ID{network.ProtocolFilDataTransfer1_2}, []protocol.ID{network.ProtocolFilDataTransfer1_2}}, } // tests data transfer for the protocol combinations that support restart messages @@ -67,7 +66,10 @@ var protocolsForRestartTest = map[string]struct { host1Protocols []protocol.ID host2Protocols []protocol.ID }{ - "(v1.2 -> v1.2)": {nil, nil}, + "(wrapped v1.2 -> wrapped v1.2)": {nil, nil}, + "(v1.2 -> wrapped v1.2)": {[]protocol.ID{network.ProtocolFilDataTransfer1_2}, nil}, + "(wrapped v1.2 -> v1.2)": {nil, []protocol.ID{network.ProtocolFilDataTransfer1_2}}, + "(v1.2 -> v1.2)": {[]protocol.ID{network.ProtocolFilDataTransfer1_2}, []protocol.ID{network.ProtocolFilDataTransfer1_2}}, } func TestRoundTrip(t *testing.T) { @@ -136,17 +138,17 @@ func TestRoundTrip(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, ps.host1Protocols, ps.host2Protocols) + gsData := NewGraphsyncTestingData(ctx, t, ps.host1Protocols, ps.host2Protocols) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 // data recipient tp1 := gsData.SetupGSTransportHost1() tp2 := gsData.SetupGSTransportHost2() - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) @@ -180,8 +182,9 @@ func TestRoundTrip(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") sv := testutil.NewStubbedValidator() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) var sourceDagService ipldformat.DAGService if data.customSourceStore { @@ -189,9 +192,8 @@ func TestRoundTrip(t *testing.T) { bs := bstore.NewBlockstore(namespace.Wrap(ds, datastore.NewKey("blockstore"))) lsys := storeutil.LinkSystemForBlockstore(bs) sourceDagService = merkledag.NewDAGService(blockservice.New(bs, offline.Exchange(bs))) - err := dt1.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) 
{ - fv, ok := testVoucher.(*testutil.FakeDTType) - if ok && fv.Data == voucher.Data { + err := dt1.RegisterTransportConfigurer(testutil.TestVoucherType, func(channelID datatransfer.ChannelID, testVoucher datatransfer.TypedVoucher, transport datatransfer.Transport) { + if testVoucher.Equals(voucher) { gsTransport, ok := transport.(*tp.Transport) if ok { err := gsTransport.UseStore(channelID, lsys) @@ -203,7 +205,7 @@ func TestRoundTrip(t *testing.T) { } else { sourceDagService = gsData.DagService1 } - root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, loremFile) + root, origBytes := LoadUnixFSFile(ctx, t, sourceDagService, loremFile) rootCid := root.(cidlink.Link).Cid var destDagService ipldformat.DAGService @@ -212,9 +214,8 @@ func TestRoundTrip(t *testing.T) { bs := bstore.NewBlockstore(namespace.Wrap(ds, datastore.NewKey("blockstore"))) lsys := storeutil.LinkSystemForBlockstore(bs) destDagService = merkledag.NewDAGService(blockservice.New(bs, offline.Exchange(bs))) - err := dt2.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { - fv, ok := testVoucher.(*testutil.FakeDTType) - if ok && fv.Data == voucher.Data { + err := dt2.RegisterTransportConfigurer(testutil.TestVoucherType, func(channelID datatransfer.ChannelID, testVoucher datatransfer.TypedVoucher, transport datatransfer.Transport) { + if testVoucher.Equals(voucher) { gsTransport, ok := transport.(*tp.Transport) if ok { err := gsTransport.UseStore(channelID, lsys) @@ -230,12 +231,12 @@ func TestRoundTrip(t *testing.T) { var chid datatransfer.ChannelID if data.isPull { sv.ExpectSuccessPull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } else { sv.ExpectSuccessPush() - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } require.NoError(t, err) opens := 0 @@ -259,7 +260,7 @@ func TestRoundTrip(t *testing.T) { } } require.Equal(t, sentIncrements, receivedIncrements) - testutil.VerifyHasFile(ctx, t, destDagService, root, origBytes) + VerifyHasFile(ctx, t, destDagService, root, origBytes) if data.isPull { assert.Equal(t, chid.Initiator, host2.ID()) } else { @@ -293,17 +294,17 @@ func TestMultipleRoundTripMultipleStores(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 // data recipient tp1 := gsData.SetupGSTransportHost1() tp2 := gsData.SetupGSTransportHost2() - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt2, err := NewDataTransfer(gsData.DtDs2, 
gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) @@ -323,13 +324,14 @@ func TestMultipleRoundTripMultipleStores(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - vouchers := make([]datatransfer.Voucher, 0, data.requestCount) + vouchers := make([]datatransfer.TypedVoucher, 0, data.requestCount) for i := 0; i < data.requestCount; i++ { - vouchers = append(vouchers, testutil.NewFakeDTType()) + vouchers = append(vouchers, testutil.NewTestTypedVoucher()) } sv := testutil.NewStubbedValidator() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) - root, origBytes := testutil.LoadUnixFSFile(ctx, t, gsData.DagService1, loremFile) + root, origBytes := LoadUnixFSFile(ctx, t, gsData.DagService1, loremFile) rootCid := root.(cidlink.Link).Cid destDagServices := make([]ipldformat.DAGService, 0, data.requestCount) @@ -344,16 +346,13 @@ func TestMultipleRoundTripMultipleStores(t *testing.T) { linkSystems = append(linkSystems, lsys) } - err = dt2.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { - fv, ok := testVoucher.(*testutil.FakeDTType) - if ok { - for i, voucher := range vouchers { - if fv.Data == voucher.(*testutil.FakeDTType).Data { - gsTransport, ok := transport.(*tp.Transport) - if ok { - err := gsTransport.UseStore(channelID, linkSystems[i]) - require.NoError(t, err) - } + err = dt2.RegisterTransportConfigurer(testutil.TestVoucherType, func(channelID datatransfer.ChannelID, testVoucher datatransfer.TypedVoucher, transport datatransfer.Transport) { + for i, voucher := range vouchers { + if testVoucher.Equals(voucher) { + gsTransport, ok := transport.(*tp.Transport) + if ok { + err := gsTransport.UseStore(channelID, linkSystems[i]) + require.NoError(t, err) } } } @@ -362,16 +361,16 @@ func TestMultipleRoundTripMultipleStores(t *testing.T) { if data.isPull { sv.ExpectSuccessPull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) for i := 0; i < data.requestCount; i++ { - _, err = dt2.OpenPullDataChannel(ctx, host1.ID(), vouchers[i], rootCid, gsData.AllSelector) + _, err = dt2.OpenPullDataChannel(ctx, host1.ID(), vouchers[i], rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) } } else { sv.ExpectSuccessPush() - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) for i := 0; i < data.requestCount; i++ { - _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), vouchers[i], rootCid, gsData.AllSelector) + _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), vouchers[i], rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) } } @@ -390,7 +389,7 @@ func TestMultipleRoundTripMultipleStores(t *testing.T) { } } for _, destDagService := range destDagServices { - testutil.VerifyHasFile(ctx, t, destDagService, root, origBytes) + VerifyHasFile(ctx, t, destDagService, root, origBytes) } }) } @@ -415,11 +414,11 @@ func TestManyReceiversAtOnce(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender tp1 := gsData.SetupGSTransportHost1() - dt1, err := 
NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) @@ -441,19 +440,18 @@ func TestManyReceiversAtOnce(t *testing.T) { destDagService := merkledag.NewDAGService(blockservice.New(altBs, offline.Exchange(altBs))) gs := gsimpl.New(gsData.Ctx, gsnet, lsys) - gsTransport := tp.NewTransport(host.ID(), gs) + gsTransport := tp.NewTransport(gs, dtnet) dtDs := namespace.Wrap(ds, datastore.NewKey("datatransfer")) - receiver, err := NewDataTransfer(dtDs, dtnet, gsTransport) + receiver, err := NewDataTransfer(dtDs, host.ID(), gsTransport) require.NoError(t, err) err = receiver.Start(gsData.Ctx) require.NoError(t, err) - err = receiver.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { - _, isFv := testVoucher.(*testutil.FakeDTType) + err = receiver.RegisterTransportConfigurer(testutil.TestVoucherType, func(channelID datatransfer.ChannelID, testVoucher datatransfer.TypedVoucher, transport datatransfer.Transport) { gsTransport, isGs := transport.(*tp.Transport) - if isFv && isGs { + if isGs { err := gsTransport.UseStore(channelID, altLinkSystem) require.NoError(t, err) } @@ -485,27 +483,28 @@ func TestManyReceiversAtOnce(t *testing.T) { for _, receiver := range receivers { receiver.SubscribeToEvents(subscriber) } - vouchers := make([]datatransfer.Voucher, 0, data.receiverCount) + vouchers := make([]datatransfer.TypedVoucher, 0, data.receiverCount) for i := 0; i < data.receiverCount; i++ { - vouchers = append(vouchers, testutil.NewFakeDTType()) + vouchers = append(vouchers, testutil.NewTestTypedVoucher()) } sv := testutil.NewStubbedValidator() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) - root, origBytes := testutil.LoadUnixFSFile(ctx, t, gsData.DagService1, loremFile) + root, origBytes := LoadUnixFSFile(ctx, t, gsData.DagService1, loremFile) rootCid := root.(cidlink.Link).Cid if data.isPull { sv.ExpectSuccessPull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) for i, receiver := range receivers { - _, err = receiver.OpenPullDataChannel(ctx, host1.ID(), vouchers[i], rootCid, gsData.AllSelector) + _, err = receiver.OpenPullDataChannel(ctx, host1.ID(), vouchers[i], rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) } } else { sv.ExpectSuccessPush() for i, receiver := range receivers { - require.NoError(t, receiver.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - _, err = dt1.OpenPushDataChannel(ctx, hosts[i].ID(), vouchers[i], rootCid, gsData.AllSelector) + require.NoError(t, receiver.RegisterVoucherType(testutil.TestVoucherType, sv)) + _, err = dt1.OpenPushDataChannel(ctx, hosts[i].ID(), vouchers[i], rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) } } @@ -524,7 +523,7 @@ func TestManyReceiversAtOnce(t *testing.T) { } } for _, destDagService := range destDagServices { - testutil.VerifyHasFile(ctx, t, destDagService, root, origBytes) + VerifyHasFile(ctx, t, destDagService, root, origBytes) } }) } @@ -556,66 +555,6 @@ func (dc *disconnectCoordinator) onDisconnect() { close(dc.disconnected) } -type restartRevalidator struct { - *testutil.StubbedRevalidator - pullDataSent map[datatransfer.ChannelID][]uint64 - pushDataRcvd map[datatransfer.ChannelID][]uint64 -} - -func 
newRestartRevalidator() *restartRevalidator { - return &restartRevalidator{ - StubbedRevalidator: testutil.NewStubbedRevalidator(), - pullDataSent: make(map[datatransfer.ChannelID][]uint64), - pushDataRcvd: make(map[datatransfer.ChannelID][]uint64), - } -} - -func (r *restartRevalidator) OnPullDataSent(chid datatransfer.ChannelID, additionalBytesSent uint64) (bool, datatransfer.VoucherResult, error) { - chSent, ok := r.pullDataSent[chid] - if !ok { - chSent = []uint64{} - } - chSent = append(chSent, additionalBytesSent) - r.pullDataSent[chid] = chSent - - return true, nil, nil -} - -func (r *restartRevalidator) pullDataSum(chid datatransfer.ChannelID) uint64 { - pullDataSent, ok := r.pullDataSent[chid] - var total uint64 - if !ok { - return total - } - for _, sent := range pullDataSent { - total += sent - } - return total -} - -func (r *restartRevalidator) OnPushDataReceived(chid datatransfer.ChannelID, additionalBytesReceived uint64) (bool, datatransfer.VoucherResult, error) { - chRcvd, ok := r.pushDataRcvd[chid] - if !ok { - chRcvd = []uint64{} - } - chRcvd = append(chRcvd, additionalBytesReceived) - r.pushDataRcvd[chid] = chRcvd - - return true, nil, nil -} - -func (r *restartRevalidator) pushDataSum(chid datatransfer.ChannelID) uint64 { - pushDataRcvd, ok := r.pushDataRcvd[chid] - var total uint64 - if !ok { - return total - } - for _, rcvd := range pushDataRcvd { - total += rcvd - } - return total -} - // TestAutoRestart tests that if the connection for a push or pull request // goes down, it will automatically restart (given the right config options) func TestAutoRestart(t *testing.T) { @@ -753,7 +692,7 @@ func TestAutoRestart(t *testing.T) { // The retry config for the network layer: make 5 attempts, backing off by 1s each time netRetry := network.RetryParameters(time.Second, time.Second, 5, 1) - gsData := testutil.NewGraphsyncTestingData(ctx, t, ps.host1Protocols, ps.host2Protocols) + gsData := NewGraphsyncTestingData(ctx, t, ps.host1Protocols, ps.host2Protocols) gsData.DtNet1 = network.NewFromLibp2pHost(gsData.Host1, netRetry) initiatorHost := gsData.Host1 // initiator, data sender responderHost := gsData.Host2 // data recipient @@ -769,12 +708,12 @@ func TestAutoRestart(t *testing.T) { MaxConsecutiveRestarts: 10, CompleteTimeout: 100 * time.Millisecond, }) - initiator, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, initiatorGSTspt, restartConf) + initiator, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), initiatorGSTspt, restartConf) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, initiator) defer initiator.Stop(ctx) - responder, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, responderGSTspt) + responder, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), responderGSTspt) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, responder) defer responder.Stop(ctx) @@ -792,8 +731,10 @@ func TestAutoRestart(t *testing.T) { } initiator.SubscribeToEvents(subscriber) responder.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") sv := testutil.NewStubbedValidator() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) + sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) var sourceDagService, destDagService ipldformat.DAGService if tc.isPush { @@ -804,15 +745,11 @@ func TestAutoRestart(t *testing.T) { destDagService = gsData.DagService1 } - root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, loremFile) + root, 
origBytes := LoadUnixFSFile(ctx, t, sourceDagService, loremFile) rootCid := root.(cidlink.Link).Cid - require.NoError(t, initiator.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - require.NoError(t, responder.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - - // Register a revalidator that records calls to OnPullDataSent and OnPushDataReceived - srv := newRestartRevalidator() - require.NoError(t, responder.RegisterRevalidator(testutil.NewFakeDTType(), srv)) + require.NoError(t, initiator.RegisterVoucherType(testutil.TestVoucherType, sv)) + require.NoError(t, responder.RegisterVoucherType(testutil.TestVoucherType, sv)) // If the test case needs to subscribe to response events, provide // the test case with the responder @@ -833,10 +770,10 @@ func TestAutoRestart(t *testing.T) { var chid datatransfer.ChannelID if tc.isPush { // Open a push channel - chid, err = initiator.OpenPushDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector) + chid, err = initiator.OpenPushDataChannel(ctx, responderHost.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } else { // Open a pull channel - chid, err = initiator.OpenPullDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector) + chid, err = initiator.OpenPullDataChannel(ctx, responderHost.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } require.NoError(t, err) @@ -895,16 +832,17 @@ func TestAutoRestart(t *testing.T) { } })() - // Verify that the total amount of data sent / received that was - // reported to the revalidator is correct + chst, err := responder.ChannelState(ctx, chid) + require.NoError(t, err) + // Verify that the total amount of data sent / received was correct if tc.isPush { - require.EqualValues(t, loremFileTransferBytes, srv.pushDataSum(chid)) + require.EqualValues(t, uint64(loremFileTransferBytes), chst.Received()) } else { - require.EqualValues(t, loremFileTransferBytes, srv.pullDataSum(chid)) + require.EqualValues(t, uint64(loremFileTransferBytes), chst.Sent()) } // Verify that the file was transferred to the destination node - testutil.VerifyHasFile(ctx, t, destDagService, root, origBytes) + VerifyHasFile(ctx, t, destDagService, root, origBytes) }) } } @@ -928,7 +866,7 @@ func TestAutoRestartAfterBouncingInitiator(t *testing.T) { // The retry config for the network layer: make 5 attempts, backing off by 1s each time netRetry := network.RetryParameters(time.Second, time.Second, 5, 1) - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) gsData.DtNet1 = network.NewFromLibp2pHost(gsData.Host1, netRetry) initiatorHost := gsData.Host1 // initiator, data sender responderHost := gsData.Host2 // data recipient @@ -944,12 +882,12 @@ func TestAutoRestartAfterBouncingInitiator(t *testing.T) { MaxConsecutiveRestarts: 10, CompleteTimeout: 100 * time.Millisecond, }) - initiator, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, initiatorGSTspt, restartConf) + initiator, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), initiatorGSTspt, restartConf) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, initiator) defer initiator.Stop(ctx) - responder, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, responderGSTspt) + responder, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), responderGSTspt) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, responder) defer responder.Stop(ctx) @@ -986,8 +924,10 @@ func TestAutoRestartAfterBouncingInitiator(t 
*testing.T) { } dataReceived := onDataReceivedChan(dataReceiver) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") sv := testutil.NewStubbedValidator() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) + sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) var sourceDagService, destDagService ipldformat.DAGService if isPush { @@ -998,23 +938,19 @@ func TestAutoRestartAfterBouncingInitiator(t *testing.T) { destDagService = gsData.DagService1 } - root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, loremLargeFile) + root, origBytes := LoadUnixFSFile(ctx, t, sourceDagService, loremLargeFile) rootCid := root.(cidlink.Link).Cid - require.NoError(t, initiator.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - require.NoError(t, responder.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - - // Register a revalidator that records calls to OnPullDataSent and OnPushDataReceived - srv := newRestartRevalidator() - require.NoError(t, responder.RegisterRevalidator(testutil.NewFakeDTType(), srv)) + require.NoError(t, initiator.RegisterVoucherType(testutil.TestVoucherType, sv)) + require.NoError(t, responder.RegisterVoucherType(testutil.TestVoucherType, sv)) var chid datatransfer.ChannelID if isPush { // Open a push channel - chid, err = initiator.OpenPushDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector) + chid, err = initiator.OpenPushDataChannel(ctx, responderHost.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } else { // Open a pull channel - chid, err = initiator.OpenPullDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector) + chid, err = initiator.OpenPullDataChannel(ctx, responderHost.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } require.NoError(t, err) @@ -1044,9 +980,9 @@ func TestAutoRestartAfterBouncingInitiator(t *testing.T) { // 2. 
Create a new initiator initiator2GSTspt := gsData.SetupGSTransportHost1() - initiator2, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, initiator2GSTspt, restartConf) + initiator2, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), initiator2GSTspt, restartConf) require.NoError(t, err) - require.NoError(t, initiator2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, initiator2.RegisterVoucherType(testutil.TestVoucherType, sv)) initiator2.SubscribeToEvents(completeSubscriber) testutil.StartAndWaitForReady(ctx, t, initiator2) @@ -1112,16 +1048,17 @@ func TestAutoRestartAfterBouncingInitiator(t *testing.T) { } })() - // Verify that the total amount of data sent / received that was - // reported to the revalidator is correct + chst, err := responder.ChannelState(ctx, chid) + require.NoError(t, err) + // Verify that the total amount of data sent / received was correct if isPush { - require.EqualValues(t, loremLargeFileTransferBytes, srv.pushDataSum(chid)) + require.EqualValues(t, uint64(loremFileTransferBytes), chst.Received()) } else { - require.EqualValues(t, loremLargeFileTransferBytes, srv.pullDataSum(chid)) + require.EqualValues(t, uint64(loremFileTransferBytes), chst.Sent()) } // Verify that the file was transferred to the destination node - testutil.VerifyHasFile(ctx, t, destDagService, root, origBytes) + VerifyHasFile(ctx, t, destDagService, root, origBytes) } t.Run("push", func(t *testing.T) { @@ -1147,17 +1084,17 @@ func TestRoundTripCancelledRequest(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 tp1 := gsData.SetupGSTransportHost1() tp2 := gsData.SetupGSTransportHost2() - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) @@ -1185,20 +1122,22 @@ func TestRoundTripCancelledRequest(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") sv := testutil.NewStubbedValidator() - root, _ := testutil.LoadUnixFSFile(ctx, t, gsData.DagService1, loremFile) + root, _ := LoadUnixFSFile(ctx, t, gsData.DagService1, loremFile) rootCid := root.(cidlink.Link).Cid var chid datatransfer.ChannelID if data.isPull { - sv.ExpectPausePull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + sv.ExpectSuccessPull() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, ForcePause: true}) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } else { - sv.ExpectPausePush() - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) + sv.ExpectSuccessPush() + 
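
// Sketch (editorial example, not part of the diff): the hunks that follow replace the old
// revalidator callbacks with fields on datatransfer.ValidationResult — a request is
// accepted, optionally left paused (ForcePause, as in the cancelled-request test below),
// capped at a byte limit (DataLimit, the new form of the pause points), and optionally
// held open for a closing voucher exchange (RequiresFinalization). Helper name and
// parameters are illustrative.
func pausePointResult(limit uint64, finalize bool) datatransfer.ValidationResult {
	return datatransfer.ValidationResult{
		Accepted:             true,     // let the transfer proceed
		DataLimit:            limit,    // channel pauses (DataLimitExceeded) once this many bytes are queued/received
		RequiresFinalization: finalize, // keep the channel open until a closing voucher arrives
	}
}
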
sv.StubResult(datatransfer.ValidationResult{Accepted: true, ForcePause: true}) + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } require.NoError(t, err) opens := 0 @@ -1239,35 +1178,66 @@ func TestRoundTripCancelledRequest(t *testing.T) { } type retrievalRevalidator struct { - *testutil.StubbedRevalidator - dataSoFar uint64 - providerPausePoint int - pausePoints []uint64 - finalVoucher datatransfer.VoucherResult - revalVouchers []datatransfer.VoucherResult + *testutil.StubbedValidator + providerPausePoint int + pausePoints []uint64 + leavePausedInitially bool + initialVoucherResult *datatransfer.TypedVoucher + requiresFinalization bool } -func (r *retrievalRevalidator) OnPullDataSent(chid datatransfer.ChannelID, additionalBytesSent uint64) (bool, datatransfer.VoucherResult, error) { - r.dataSoFar += additionalBytesSent - if r.providerPausePoint < len(r.pausePoints) && - r.dataSoFar >= r.pausePoints[r.providerPausePoint] { - var v datatransfer.VoucherResult = testutil.NewFakeDTType() - if len(r.revalVouchers) > r.providerPausePoint { - v = r.revalVouchers[r.providerPausePoint] - } +func (r *retrievalRevalidator) ValidatePush( + chid datatransfer.ChannelID, + sender peer.ID, + voucher datamodel.Node, + baseCid cid.Cid, + selector datamodel.Node) (datatransfer.ValidationResult, error) { + vr := datatransfer.ValidationResult{ + Accepted: true, + RequiresFinalization: r.requiresFinalization, + ForcePause: r.leavePausedInitially, + } + if r.initialVoucherResult != nil { + vr.VoucherResult = r.initialVoucherResult + } + if len(r.pausePoints) > r.providerPausePoint { + vr.DataLimit = r.pausePoints[r.providerPausePoint] r.providerPausePoint++ - return true, v, datatransfer.ErrPause } - return true, nil, nil + r.StubbedValidator.StubResult(vr) + return r.StubbedValidator.ValidatePush(chid, sender, voucher, baseCid, selector) } -func (r *retrievalRevalidator) OnPushDataReceived(chid datatransfer.ChannelID, additionalBytesReceived uint64) (bool, datatransfer.VoucherResult, error) { - return false, nil, nil -} -func (r *retrievalRevalidator) OnComplete(chid datatransfer.ChannelID) (bool, datatransfer.VoucherResult, error) { - return true, r.finalVoucher, datatransfer.ErrPause +func (r *retrievalRevalidator) ValidatePull( + chid datatransfer.ChannelID, + sender peer.ID, + voucher datamodel.Node, + baseCid cid.Cid, + selector datamodel.Node) (datatransfer.ValidationResult, error) { + vr := datatransfer.ValidationResult{ + Accepted: true, + RequiresFinalization: r.requiresFinalization, + ForcePause: r.leavePausedInitially, + } + if r.initialVoucherResult != nil { + vr.VoucherResult = r.initialVoucherResult + } + if len(r.pausePoints) > r.providerPausePoint { + vr.DataLimit = r.pausePoints[r.providerPausePoint] + r.providerPausePoint++ + } + r.StubbedValidator.StubResult(vr) + return r.StubbedValidator.ValidatePull(chid, sender, voucher, baseCid, selector) } +func (r *retrievalRevalidator) nextStatus() datatransfer.ValidationResult { + vr := datatransfer.ValidationResult{Accepted: true, RequiresFinalization: r.requiresFinalization} + if len(r.pausePoints) > r.providerPausePoint { + vr.DataLimit = r.pausePoints[r.providerPausePoint] + r.providerPausePoint++ + } + return vr +} func TestSimulatedRetrievalFlow(t *testing.T) { ctx := context.Background() testCases := map[string]struct { @@ -1322,7 +1292,7 @@ func TestSimulatedRetrievalFlow(t *testing.T) { // 
responder: send message that we sent all data along with final voucher request "transfer(1)->sendMessage(0)", // responder: receive final voucher and send acceptance message - "transfer(1)->receiveRequest(5)->sendMessage(0)", + "transfer(1)->receiveRequest(5)", }, }, "fast unseal, payment channel not ready": { @@ -1340,7 +1310,7 @@ func TestSimulatedRetrievalFlow(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 4*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender root := gsData.LoadUnixFSFile(t, false) @@ -1348,18 +1318,17 @@ func TestSimulatedRetrievalFlow(t *testing.T) { tp1 := gsData.SetupGSTransportHost1() tp2 := gsData.SetupGSTransportHost2() - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) var chid datatransfer.ChannelID errChan := make(chan struct{}, 2) clientPausePoint := 0 clientFinished := make(chan struct{}, 1) - finalVoucherResult := testutil.NewFakeDTType() - encodedFVR, err := encoding.Encode(finalVoucherResult) + finalVoucherResult := testutil.NewTestTypedVoucher() require.NoError(t, err) var clientSubscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.Error { @@ -1367,17 +1336,15 @@ func TestSimulatedRetrievalFlow(t *testing.T) { } if event.Code == datatransfer.NewVoucherResult { lastVoucherResult := channelState.LastVoucherResult() - encodedLVR, err := encoding.Encode(lastVoucherResult) - require.NoError(t, err) - if bytes.Equal(encodedLVR, encodedFVR) { - _ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType()) + if lastVoucherResult.Equals(finalVoucherResult) { + _ = dt2.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher()) } } if event.Code == datatransfer.DataReceived && clientPausePoint < len(config.pausePoints) && channelState.Received() > config.pausePoints[clientPausePoint] { - _ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType()) + _ = dt2.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher()) clientPausePoint++ } if channelState.Status() == datatransfer.Completed { @@ -1385,18 +1352,31 @@ func TestSimulatedRetrievalFlow(t *testing.T) { } } dt2.SubscribeToEvents(clientSubscriber) + + sv := &retrievalRevalidator{ + StubbedValidator: testutil.NewStubbedValidator(), + pausePoints: config.pausePoints, + requiresFinalization: true, + leavePausedInitially: true, + } providerFinished := make(chan struct{}, 1) - providerAccepted := false var providerSubscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.PauseResponder { - if !providerAccepted { - providerAccepted = true - timer := time.NewTimer(config.unpauseResponderDelay) - go func() { - <-timer.C - _ = dt1.ResumeDataTransferChannel(ctx, chid) - }() - } + timer := time.NewTimer(config.unpauseResponderDelay) + go func() { + <-timer.C + _ = dt1.ResumeDataTransferChannel(ctx, chid) + }() + } + if event.Code == datatransfer.NewVoucher && channelState.Queued() > 0 { + dt1.UpdateValidationStatus(ctx, chid, sv.nextStatus()) + } + if event.Code == 
datatransfer.DataLimitExceeded { + dt1.SendVoucherResult(ctx, chid, testutil.NewTestTypedVoucher()) + } + if event.Code == datatransfer.BeginFinalizing { + sv.requiresFinalization = false + dt1.SendVoucherResult(ctx, chid, finalVoucherResult) } if event.Code == datatransfer.Error { errChan <- struct{}{} @@ -1406,19 +1386,11 @@ func TestSimulatedRetrievalFlow(t *testing.T) { } } dt1.SubscribeToEvents(providerSubscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} - sv := testutil.NewStubbedValidator() - sv.ExpectPausePull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + voucher := testutil.NewTestTypedVoucherWith("applesauce") - srv := &retrievalRevalidator{ - testutil.NewStubbedRevalidator(), 0, 0, config.pausePoints, finalVoucherResult, []datatransfer.VoucherResult{}, - } - srv.ExpectSuccessErrResume() - require.NoError(t, dt1.RegisterRevalidator(testutil.NewFakeDTType(), srv)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) - require.NoError(t, dt2.RegisterVoucherResultType(testutil.NewFakeDTType())) - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) for providerFinished != nil || clientFinished != nil { @@ -1434,9 +1406,8 @@ func TestSimulatedRetrievalFlow(t *testing.T) { } } sv.VerifyExpectations(t) - srv.VerifyExpectations(t) gsData.VerifyFileTransferred(t, root, true) - require.Equal(t, srv.providerPausePoint, len(config.pausePoints)) + require.Equal(t, sv.providerPausePoint, len(config.pausePoints)) require.Equal(t, clientPausePoint, len(config.pausePoints)) traces := collectTracing(t).TracesToStrings(3) for _, expectedTrace := range config.expectedTraces { @@ -1457,7 +1428,7 @@ func TestPauseAndResume(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 // data recipient @@ -1466,12 +1437,13 @@ func TestPauseAndResume(t *testing.T) { tp1 := gsData.SetupGSTransportHost1() tp2 := gsData.SetupGSTransportHost2() - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) + finished := make(chan struct{}, 2) errChan := make(chan struct{}, 2) opened := make(chan struct{}, 2) @@ -1519,18 +1491,39 @@ func TestPauseAndResume(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") sv := testutil.NewStubbedValidator() - + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) + sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) var chid datatransfer.ChannelID + + gsData.Gs1.RegisterOutgoingBlockHook(func(p peer.ID, r graphsync.RequestData, block graphsync.BlockData, ha graphsync.OutgoingBlockHookActions) { + if block.Index() == 5 && block.BlockSizeOnWire() > 0 { + require.NoError(t, dt1.PauseDataTransferChannel(ctx, chid)) + go 
func() { + time.Sleep(100 * time.Millisecond) + require.NoError(t, dt1.ResumeDataTransferChannel(ctx, chid)) + }() + } + }) + gsData.Gs2.RegisterIncomingBlockHook(func(p peer.ID, r graphsync.ResponseData, block graphsync.BlockData, ha graphsync.IncomingBlockHookActions) { + if block.Index() == 5 { + require.NoError(t, dt2.PauseDataTransferChannel(ctx, chid)) + go func() { + time.Sleep(50 * time.Millisecond) + require.NoError(t, dt2.ResumeDataTransferChannel(ctx, chid)) + }() + } + }) + if isPull { sv.ExpectSuccessPull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } else { sv.ExpectSuccessPush() - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } require.NoError(t, err) opens := 0 @@ -1560,18 +1553,8 @@ func TestPauseAndResume(t *testing.T) { resumeResponders++ case sentIncrement := <-sent: sentIncrements = append(sentIncrements, sentIncrement) - if len(sentIncrements) == 5 { - require.NoError(t, dt1.PauseDataTransferChannel(ctx, chid)) - time.Sleep(100 * time.Millisecond) - require.NoError(t, dt1.ResumeDataTransferChannel(ctx, chid)) - } case receivedIncrement := <-received: receivedIncrements = append(receivedIncrements, receivedIncrement) - if len(receivedIncrements) == 10 { - require.NoError(t, dt2.PauseDataTransferChannel(ctx, chid)) - time.Sleep(100 * time.Millisecond) - require.NoError(t, dt2.ResumeDataTransferChannel(ctx, chid)) - } case <-errChan: t.Fatal("received error on data transfer") } @@ -1598,17 +1581,17 @@ func TestUnrecognizedVoucherRoundTrip(t *testing.T) { // ctx, cancel := context.WithTimeout(ctx, 5*time.Second) // defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 // data recipient tp1 := gsData.SetupGSTransportHost1() tp2 := gsData.SetupGSTransportHost2() - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) @@ -1628,15 +1611,15 @@ func TestUnrecognizedVoucherRoundTrip(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") - root, _ := testutil.LoadUnixFSFile(ctx, t, gsData.DagService1, loremFile) + root, _ := LoadUnixFSFile(ctx, t, gsData.DagService1, loremFile) rootCid := root.(cidlink.Link).Cid if isPull { - _, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + _, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, 
selectorparse.CommonSelector_ExploreAllRecursively) } else { - _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) + _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) } require.NoError(t, err) opens := 0 @@ -1667,7 +1650,7 @@ func TestDataTransferSubscribing(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host2 := gsData.Host2 tp1 := gsData.SetupGSTransportHost1() @@ -1675,14 +1658,14 @@ func TestDataTransferSubscribing(t *testing.T) { sv := testutil.NewStubbedValidator() sv.StubErrorPull() sv.StubErrorPush() - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - voucher := testutil.FakeDTType{Data: "applesauce"} + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) + voucher := testutil.NewTestTypedVoucherWith("applesauce") baseCid := testutil.GenerateCids(1)[0] - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) subscribe1Calls := make(chan struct{}, 1) @@ -1699,7 +1682,7 @@ func TestDataTransferSubscribing(t *testing.T) { } unsub1 := dt1.SubscribeToEvents(subscribe1) unsub2 := dt1.SubscribeToEvents(subscribe2) - _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, baseCid, gsData.AllSelector) + _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, baseCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) select { case <-ctx.Done(): @@ -1728,7 +1711,7 @@ func TestDataTransferSubscribing(t *testing.T) { } unsub3 := dt1.SubscribeToEvents(subscribe3) unsub4 := dt1.SubscribeToEvents(subscribe4) - _, err = dt1.OpenPullDataChannel(ctx, host2.ID(), &voucher, baseCid, gsData.AllSelector) + _, err = dt1.OpenPullDataChannel(ctx, host2.ID(), voucher, baseCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) select { case <-ctx.Done(): @@ -1795,10 +1778,10 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator and data sender host2 := gsData.Host2 // data recipient, makes graphsync request for data - voucher := testutil.NewFakeDTType() + voucher := testutil.NewTestTypedVoucher() link := gsData.LoadUnixFSFile(t, false) // setup receiving peer to just record message coming in @@ -1806,7 +1789,7 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { r := &receiver{ messageReceived: make(chan receivedMessage), } - dtnet2.SetDelegate(r) + dtnet2.SetDelegate(datatransfer.LegacyTransportID, []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) gsr := &fakeGraphSyncReceiver{ receivedMessages: make(chan receivedGraphSyncMessage), @@ -1814,15 +1797,14 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { gsData.GsNet2.SetDelegate(gsr) tp1 := gsData.SetupGSTransportHost1() - dt1, err := 
NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - voucherResult := testutil.NewFakeDTType() - err = dt1.RegisterVoucherResultType(voucherResult) + voucherResult := testutil.NewTestTypedVoucher() require.NoError(t, err) t.Run("when request is initiated", func(t *testing.T) { - _, err := dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, link.(cidlink.Link).Cid, gsData.AllSelector) + _, err := dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, link.(cidlink.Link).Cid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) var messageReceived receivedMessage @@ -1833,11 +1815,10 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { } requestReceived := messageReceived.message.(datatransfer.Request) - response, err := message.NewResponse(requestReceived.TransferID(), true, false, voucherResult.Type(), voucherResult) - require.NoError(t, err) - nd, err := response.ToIPLD() + response := message.NewResponse(requestReceived.TransferID(), true, false, &voucherResult) require.NoError(t, err) - request := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, gsData.AllSelector, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ + nd := response.ToIPLD() + request := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, selectorparse.CommonSelector_ExploreAllRecursively, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ Name: extension.ExtensionDataTransfer1_1, Data: nd, }) @@ -1852,11 +1833,10 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { }) t.Run("when no request is initiated", func(t *testing.T) { - response, err := message.NewResponse(datatransfer.TransferID(rand.Uint32()), true, false, voucher.Type(), voucher) - require.NoError(t, err) - nd, err := response.ToIPLD() + response := message.NewResponse(datatransfer.TransferID(rand.Uint32()), true, false, &voucher) require.NoError(t, err) - request := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, gsData.AllSelector, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ + nd := response.ToIPLD() + request := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, selectorparse.CommonSelector_ExploreAllRecursively, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ Name: extension.ExtensionDataTransfer1_1, Data: nd, }) @@ -1876,10 +1856,10 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator and data sender host2 := gsData.Host2 // data recipient, makes graphsync request for data - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") link := gsData.LoadUnixFSFile(t, false) // setup receiving peer to just record message coming in @@ -1887,7 +1867,7 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { r := &receiver{ messageReceived: make(chan receivedMessage), } - dtnet2.SetDelegate(r) + dtnet2.SetDelegate(datatransfer.LegacyTransportID, []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) gsr := &fakeGraphSyncReceiver{ receivedMessages: make(chan receivedGraphSyncMessage), @@ -1895,8 +1875,8 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { 
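
// Sketch (editorial example, not part of the diff): how these graphsync-level tests build
// an outgoing request carrying a data-transfer response, using the v2 message API shown
// in this hunk — NewResponse takes a *TypedVoucher and no longer returns an error, and
// ToIPLD yields the extension node directly. Helper and parameter names are illustrative.
func gsRequestWithDTResponse(id datatransfer.TransferID, result datatransfer.TypedVoucher, root cid.Cid) gsmsg.GraphSyncRequest {
	resp := message.NewResponse(id, true, false, &result) // same flags the test passes above (accepted, not paused)
	return gsmsg.NewRequest(
		graphsync.NewRequestID(),
		root,
		selectorparse.CommonSelector_ExploreAllRecursively,
		graphsync.Priority(0),
		graphsync.ExtensionData{
			Name: extension.ExtensionDataTransfer1_1, // same extension name for both protocol versions
			Data: resp.ToIPLD(),
		},
	)
}
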
gsData.GsNet2.SetDelegate(gsr) gs1 := gsData.SetupGraphsyncHost1() - tp1 := tp.NewTransport(host1.ID(), gs1) - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + tp1 := tp.NewTransport(gs1, gsData.DtNet1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) t.Run("when it's not our extension, does not error and does not validate", func(t *testing.T) { @@ -1907,7 +1887,7 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { } gs1.RegisterIncomingRequestHook(validateHook) - _, err := dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, link.(cidlink.Link).Cid, gsData.AllSelector) + _, err := dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, link.(cidlink.Link).Cid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) select { @@ -1916,7 +1896,7 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { case <-r.messageReceived: } - request := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, gsData.AllSelector, graphsync.Priority(rand.Int31())) + request := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, selectorparse.CommonSelector_ExploreAllRecursively, graphsync.Priority(rand.Int31())) builder := gsmsg.NewBuilder() builder.AddRequest(request) gsmessage, err := builder.Build() @@ -1932,23 +1912,24 @@ func TestRespondingToPullGraphsyncRequests(t *testing.T) { //create network ctx := context.Background() testCases := map[string]struct { - test func(*testing.T, *testutil.GraphsyncTestingData, datatransfer.Transport, ipld.Link, datatransfer.TransferID, *fakeGraphSyncReceiver) + test func(*testing.T, *GraphsyncTestingData, datatransfer.Transport, ipld.Link, datatransfer.TransferID, *fakeGraphSyncReceiver) }{ "When a pull request is initiated and validated": { - test: func(t *testing.T, gsData *testutil.GraphsyncTestingData, tp2 datatransfer.Transport, link ipld.Link, id datatransfer.TransferID, gsr *fakeGraphSyncReceiver) { + test: func(t *testing.T, gsData *GraphsyncTestingData, tp2 datatransfer.Transport, link ipld.Link, id datatransfer.TransferID, gsr *fakeGraphSyncReceiver) { sv := testutil.NewStubbedValidator() sv.ExpectSuccessPull() + sv.StubResult(datatransfer.ValidationResult{Accepted: true}) - dt1, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt1, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) - voucher := testutil.NewFakeDTType() - request, err := message.NewRequest(id, false, true, voucher.Type(), voucher, testutil.GenerateCids(1)[0], gsData.AllSelector) + voucher := testutil.NewTestTypedVoucher() + request, err := message.NewRequest(id, false, true, &voucher, testutil.GenerateCids(1)[0], selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) - nd, err := request.ToIPLD() - gsRequest := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, gsData.AllSelector, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ + nd := request.ToIPLD() + gsRequest := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, selectorparse.CommonSelector_ExploreAllRecursively, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ Name: extension.ExtensionDataTransfer1_1, Data: nd, }) @@ -1964,20 +1945,19 @@ func 
TestRespondingToPullGraphsyncRequests(t *testing.T) { }, }, "When request is initiated, but fails validation": { - test: func(t *testing.T, gsData *testutil.GraphsyncTestingData, tp2 datatransfer.Transport, link ipld.Link, id datatransfer.TransferID, gsr *fakeGraphSyncReceiver) { + test: func(t *testing.T, gsData *GraphsyncTestingData, tp2 datatransfer.Transport, link ipld.Link, id datatransfer.TransferID, gsr *fakeGraphSyncReceiver) { sv := testutil.NewStubbedValidator() sv.ExpectErrorPull() - dt1, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt1, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - voucher := testutil.NewFakeDTType() - dtRequest, err := message.NewRequest(id, false, true, voucher.Type(), voucher, testutil.GenerateCids(1)[0], gsData.AllSelector) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) + voucher := testutil.NewTestTypedVoucher() + dtRequest, err := message.NewRequest(id, false, true, &voucher, testutil.GenerateCids(1)[0], selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) - nd, err := dtRequest.ToIPLD() - require.NoError(t, err) - request := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, gsData.AllSelector, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ + nd := dtRequest.ToIPLD() + request := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, selectorparse.CommonSelector_ExploreAllRecursively, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ Name: extension.ExtensionDataTransfer1_1, Data: nd, }) @@ -2000,7 +1980,7 @@ func TestRespondingToPullGraphsyncRequests(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) // setup receiving peer to just record message coming in gsr := &fakeGraphSyncReceiver{ @@ -2027,20 +2007,19 @@ func TestMultipleMessagesInExtension(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender - root, origBytes := LoadRandomData(ctx, t, gsData.DagService1, 256000) - gsData.OrigBytes = origBytes + root := gsData.LoadUnixFSFile(t, false) rootCid := root.(cidlink.Link).Cid tp1 := gsData.SetupGSTransportHost1() tp2 := gsData.SetupGSTransportHost2() - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) @@ -2055,24 +2034,20 @@ func TestMultipleMessagesInExtension(t *testing.T) { // In this retrieval flow we expect 2 voucher results: // The first one is sent as a response from the initial request telling the client // the provider has accepted the request and is starting to send blocks - respVoucher := testutil.NewFakeDTType() - encodedRVR, err := encoding.Encode(respVoucher) - require.NoError(t, err) + respVoucher := testutil.NewTestTypedVoucher() // voucher results are sent by the 
providers to request payment while pausing until a voucher is sent // to revalidate - voucherResults := []datatransfer.VoucherResult{ - &testutil.FakeDTType{Data: "one"}, - &testutil.FakeDTType{Data: "two"}, - &testutil.FakeDTType{Data: "thr"}, - &testutil.FakeDTType{Data: "for"}, - &testutil.FakeDTType{Data: "fiv"}, + voucherResults := []datatransfer.TypedVoucher{ + testutil.NewTestTypedVoucherWith("one"), + testutil.NewTestTypedVoucherWith("two"), + testutil.NewTestTypedVoucherWith("thr"), + testutil.NewTestTypedVoucherWith("for"), + testutil.NewTestTypedVoucherWith("fiv"), } // The final voucher result is sent by the provider to request a last payment voucher - finalVoucherResult := testutil.NewFakeDTType() - encodedFVR, err := encoding.Encode(finalVoucherResult) - require.NoError(t, err) + finalVoucherResult := testutil.NewTestTypedVoucher() dt2.SubscribeToEvents(func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.Error { @@ -2081,12 +2056,11 @@ func TestMultipleMessagesInExtension(t *testing.T) { // Here we verify reception of voucherResults by the client if event.Code == datatransfer.NewVoucherResult { voucherResult := channelState.LastVoucherResult() - encodedVR, err := encoding.Encode(voucherResult) require.NoError(t, err) // If this voucher result is the response voucher no action is needed // we just know that the provider has accepted the transfer and is sending blocks - if bytes.Equal(encodedVR, encodedRVR) { + if voucherResult.Equals(respVoucher) { // The test will fail if no response voucher is received clientGotResponse <- struct{}{} } @@ -2094,18 +2068,16 @@ func TestMultipleMessagesInExtension(t *testing.T) { // If this voucher is a revalidation request we need to send a new voucher // to revalidate and unpause the transfer if clientPausePoint < 5 { - encodedExpected, err := encoding.Encode(voucherResults[clientPausePoint]) - require.NoError(t, err) - if bytes.Equal(encodedVR, encodedExpected) { - _ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType()) + if voucherResult.Equals(voucherResults[clientPausePoint]) { + _ = dt2.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher()) clientPausePoint++ } } // If this voucher result is the final voucher result we need // to send a new voucher to unpause the provider and complete the transfer - if bytes.Equal(encodedVR, encodedFVR) { - _ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType()) + if voucherResult.Equals(finalVoucherResult) { + _ = dt2.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher()) } } @@ -2115,6 +2087,13 @@ func TestMultipleMessagesInExtension(t *testing.T) { }) providerFinished := make(chan struct{}, 1) + nextVoucherResult := 0 + sv := &retrievalRevalidator{ + StubbedValidator: testutil.NewStubbedValidator(), + pausePoints: pausePoints, + requiresFinalization: true, + initialVoucherResult: &respVoucher, + } dt1.SubscribeToEvents(func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.Error { errChan <- struct{}{} @@ -2122,27 +2101,25 @@ func TestMultipleMessagesInExtension(t *testing.T) { if channelState.Status() == datatransfer.Completed { providerFinished <- struct{}{} } + if event.Code == datatransfer.NewVoucher && channelState.Queued() > 0 { + vs := sv.nextStatus() + dt1.UpdateValidationStatus(ctx, chid, vs) + } + if event.Code == datatransfer.DataLimitExceeded { + if nextVoucherResult < len(pausePoints) { + dt1.SendVoucherResult(ctx, chid, voucherResults[nextVoucherResult]) + nextVoucherResult++ + } + } + 
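
// Sketch (editorial example, not part of the diff): the client-side pattern used in the
// subscriber above — voucher results are TypedVouchers now, so the client matches them
// with Equals instead of encoding both sides and comparing bytes, and it answers the
// final voucher result by sending one more voucher. Assumes client, chid and finalResult
// are the manager, channel ID and final voucher result already in scope in the test;
// names are illustrative.
func onVoucherResult(ctx context.Context, client datatransfer.Manager, chid datatransfer.ChannelID,
	finalResult datatransfer.TypedVoucher) datatransfer.Subscriber {
	return func(event datatransfer.Event, st datatransfer.ChannelState) {
		if event.Code != datatransfer.NewVoucherResult {
			return
		}
		if st.LastVoucherResult().Equals(finalResult) {
			// the provider asked for a closing voucher; sending one lets the transfer complete
			_ = client.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher())
		}
	}
}
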
if event.Code == datatransfer.BeginFinalizing { + sv.requiresFinalization = false + dt1.SendVoucherResult(ctx, chid, finalVoucherResult) + } }) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) - sv := testutil.NewStubbedValidator() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - // Stub in the validator so it returns that exact voucher when calling ValidatePull - // this validator will not pause transfer when accepting a transfer and will start - // sending blocks immediately - sv.StubResult(respVoucher) - - srv := &retrievalRevalidator{ - testutil.NewStubbedRevalidator(), 0, 0, pausePoints, finalVoucherResult, voucherResults, - } - // The stubbed revalidator will authorize Revalidate and return ErrResume to finisht the transfer - srv.ExpectSuccessErrResume() - require.NoError(t, dt1.RegisterRevalidator(testutil.NewFakeDTType(), srv)) - - // Register our response voucher with the client - require.NoError(t, dt2.RegisterVoucherResultType(respVoucher)) - - voucher := testutil.FakeDTType{Data: "applesauce"} - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + voucher := testutil.NewTestTypedVoucherWith("applesauce") + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) // Expect the client to receive a response voucher, the provider to complete the transfer and @@ -2162,19 +2139,9 @@ func TestMultipleMessagesInExtension(t *testing.T) { } } sv.VerifyExpectations(t) - srv.VerifyExpectations(t) gsData.VerifyFileTransferred(t, root, true) } -// completeRevalidator does not pause when sending the last voucher to confirm the deal is completed -type completeRevalidator struct { - *retrievalRevalidator -} - -func (r *completeRevalidator) OnComplete(chid datatransfer.ChannelID) (bool, datatransfer.VoucherResult, error) { - return true, r.finalVoucher, nil -} - func TestMultipleParallelTransfers(t *testing.T) { // Add more sizes here to trigger more transfers. 
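The hunks above replace the stubbed revalidator with plain event handling on the responder: an intermediate voucher result is sent when a data limit is hit, and the limit is raised again via UpdateValidationStatus when the next payment voucher arrives. The sketch below restates that responder-side pattern outside the test harness. It is a minimal illustration, not code from this change; `subscribePaidResponder`, `chunkSize`, `paymentRequest`, and `finalSettlement` are hypothetical names, and it assumes a single active channel just as the test does.

```go
package example

import (
	"context"

	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
)

// subscribePaidResponder drives a paid transfer from the responder side using
// only events plus SendVoucherResult / UpdateValidationStatus (errors are
// ignored for brevity, as in the test above).
func subscribePaidResponder(
	ctx context.Context,
	provider datatransfer.Manager,
	chid datatransfer.ChannelID,
	chunkSize uint64, // hypothetical: bytes granted per payment voucher
	paymentRequest datatransfer.TypedVoucher, // result asking for the next payment
	finalSettlement datatransfer.TypedVoucher, // result asking for the final payment
) {
	provider.SubscribeToEvents(func(event datatransfer.Event, st datatransfer.ChannelState) {
		switch event.Code {
		case datatransfer.DataLimitExceeded:
			// The channel pauses once the current DataLimit is reached; ask
			// the initiator for another payment voucher.
			provider.SendVoucherResult(ctx, chid, paymentRequest)
		case datatransfer.NewVoucher:
			// Vouchers on the opening request are handled by the validator;
			// only vouchers sent after data starts flowing are payments
			// (the same Queued() guard the test uses).
			if st.Queued() == 0 {
				return
			}
			if st.Status().InFinalization() {
				// Settlement voucher received: clear the finalization
				// requirement so the channel can complete.
				provider.UpdateValidationStatus(ctx, chid, datatransfer.ValidationResult{Accepted: true})
				return
			}
			// Payment voucher received: grant another chunk beyond what has
			// already been queued so the transfer resumes.
			provider.UpdateValidationStatus(ctx, chid, datatransfer.ValidationResult{
				Accepted:             true,
				DataLimit:            st.Queued() + chunkSize,
				RequiresFinalization: true,
			})
		case datatransfer.BeginFinalizing:
			// All blocks are queued; request one last settlement voucher.
			provider.SendVoucherResult(ctx, chid, finalSettlement)
		}
	})
}
```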
@@ -2182,51 +2149,35 @@ func TestMultipleParallelTransfers(t *testing.T) { ctx := context.Background() - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender tp1 := gsData.SetupGSTransportHost1() tp2 := gsData.SetupGSTransportHost2() - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) // In this retrieval flow we expect 2 voucher results: // The first one is sent as a response from the initial request telling the client // the provider has accepted the request and is starting to send blocks - respVoucher := testutil.NewFakeDTType() - encodedRVR, err := encoding.Encode(respVoucher) + respVoucher := testutil.NewTestTypedVoucher() require.NoError(t, err) // The final voucher result is sent by the provider to let the client know the deal is completed - finalVoucherResult := testutil.NewFakeDTType() - encodedFVR, err := encoding.Encode(finalVoucherResult) + finalVoucherResult := testutil.NewTestTypedVoucher() require.NoError(t, err) - sv := testutil.NewStubbedValidator() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - // Stub in the validator so it returns that exact voucher when calling ValidatePull - // this validator will not pause transfer when accepting a transfer and will start - // sending blocks immediately - sv.StubResult(respVoucher) - - // no need for intermediary voucher results - voucherResults := []datatransfer.VoucherResult{} - - pausePoints := []uint64{} - srv := &retrievalRevalidator{ - testutil.NewStubbedRevalidator(), 0, 0, pausePoints, finalVoucherResult, voucherResults, + sv := &retrievalRevalidator{ + StubbedValidator: testutil.NewStubbedValidator(), + initialVoucherResult: &respVoucher, } - srv.ExpectSuccessErrResume() - require.NoError(t, dt1.RegisterRevalidator(testutil.NewFakeDTType(), srv)) - - // Register our response voucher with the client - require.NoError(t, dt2.RegisterVoucherResultType(respVoucher)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) // for each size we create a new random DAG of the given size and try to retrieve it for _, size := range sizes { @@ -2255,20 +2206,20 @@ func TestMultipleParallelTransfers(t *testing.T) { // Here we verify reception of voucherResults by the client if event.Code == datatransfer.NewVoucherResult { voucherResult := channelState.LastVoucherResult() - encodedVR, err := encoding.Encode(voucherResult) + require.NoError(t, err) require.NoError(t, err) // If this voucher result is the response voucher no action is needed // we just know that the provider has accepted the transfer and is sending blocks - if bytes.Equal(encodedVR, encodedRVR) { + if voucherResult.Equals(respVoucher) { // The test will fail if no response voucher is received clientGotResponse <- struct{}{} } // If this voucher result is the final voucher result we need // to send a new voucher to unpause the provider and complete the transfer - if bytes.Equal(encodedVR, encodedFVR) { - _ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType()) + if voucherResult.Equals(finalVoucherResult) { + _ = dt2.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher()) } } @@ 
-2284,19 +2235,21 @@ func TestMultipleParallelTransfers(t *testing.T) { return } if event.Code == datatransfer.Error { - fmt.Println(event.Message) errChan <- struct{}{} } if channelState.Status() == datatransfer.Completed { providerFinished <- struct{}{} } + if event.Code == datatransfer.BeginFinalizing { + dt1.SendVoucherResult(ctx, chid, finalVoucherResult) + } }) root, origBytes := LoadRandomData(ctx, t, gsData.DagService1, size) rootCid := root.(cidlink.Link).Cid - voucher := testutil.NewFakeDTType() - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, gsData.AllSelector) + voucher := testutil.NewTestTypedVoucher() + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(t, err) close(chidReceived) // Expect the client to receive a response voucher, the provider to complete the transfer and @@ -2325,8 +2278,7 @@ func TestMultipleParallelTransfers(t *testing.T) { } } sv.VerifyExpectations(t) - srv.VerifyExpectations(t) - testutil.VerifyHasFile(gsData.Ctx, t, gsData.DagService2, root, origBytes) + VerifyHasFile(gsData.Ctx, t, gsData.DagService2, root, origBytes) }) } } diff --git a/impl/restart_integration_test.go b/itest/restart_integration_test.go similarity index 88% rename from impl/restart_integration_test.go rename to itest/restart_integration_test.go index c3f4e389..68e02263 100644 --- a/impl/restart_integration_test.go +++ b/itest/restart_integration_test.go @@ -1,4 +1,4 @@ -package impl_test +package itest import ( "context" @@ -10,14 +10,15 @@ import ( ipldformat "github.com/ipfs/go-ipld-format" "github.com/ipld/go-ipld-prime" cidlink "github.com/ipld/go-ipld-prime/linking/cid" + selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" "go.uber.org/atomic" "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - . "github.com/filecoin-project/go-data-transfer/impl" - "github.com/filecoin-project/go-data-transfer/testutil" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + . 
"github.com/filecoin-project/go-data-transfer/v2/impl" + "github.com/filecoin-project/go-data-transfer/v2/testutil" ) const totalIncrements = 204 @@ -41,8 +42,8 @@ func TestRestartPush(t *testing.T) { "Restart peer create push": { stopAt: 20, openPushF: func(rh *restartHarness) datatransfer.ChannelID { - voucher := testutil.FakeDTType{Data: "applesauce"} - chid, err := rh.dt1.OpenPushDataChannel(rh.testCtx, rh.peer2, &voucher, rh.rootCid, rh.gsData.AllSelector) + voucher := testutil.NewTestTypedVoucherWith("applesauce") + chid, err := rh.dt1.OpenPushDataChannel(rh.testCtx, rh.peer2, voucher, rh.rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(rh.t, err) return chid }, @@ -51,9 +52,9 @@ func TestRestartPush(t *testing.T) { require.NoError(t, rh.dt1.Stop(rh.testCtx)) time.Sleep(100 * time.Millisecond) tp1 := rh.gsData.SetupGSTransportHost1() - rh.dt1, err = NewDataTransfer(rh.gsData.DtDs1, rh.gsData.DtNet1, tp1) + rh.dt1, err = NewDataTransfer(rh.gsData.DtDs1, rh.gsData.Host1.ID(), tp1) require.NoError(rh.t, err) - require.NoError(rh.t, rh.dt1.RegisterVoucherType(&testutil.FakeDTType{}, rh.sv)) + require.NoError(rh.t, rh.dt1.RegisterVoucherType(testutil.TestVoucherType, rh.sv)) testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt1) rh.dt1.SubscribeToEvents(subscriber) require.NoError(rh.t, rh.dt1.RestartDataTransferChannel(rh.testCtx, chId)) @@ -83,8 +84,8 @@ func TestRestartPush(t *testing.T) { "Restart peer receive push": { stopAt: 20, openPushF: func(rh *restartHarness) datatransfer.ChannelID { - voucher := testutil.FakeDTType{Data: "applesauce"} - chid, err := rh.dt1.OpenPushDataChannel(rh.testCtx, rh.peer2, &voucher, rh.rootCid, rh.gsData.AllSelector) + voucher := testutil.NewTestTypedVoucherWith("applesauce") + chid, err := rh.dt1.OpenPushDataChannel(rh.testCtx, rh.peer2, voucher, rh.rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(rh.t, err) return chid }, @@ -93,9 +94,9 @@ func TestRestartPush(t *testing.T) { require.NoError(t, rh.dt2.Stop(rh.testCtx)) time.Sleep(100 * time.Millisecond) tp2 := rh.gsData.SetupGSTransportHost2() - rh.dt2, err = NewDataTransfer(rh.gsData.DtDs2, rh.gsData.DtNet2, tp2) + rh.dt2, err = NewDataTransfer(rh.gsData.DtDs2, rh.gsData.Host2.ID(), tp2) require.NoError(rh.t, err) - require.NoError(rh.t, rh.dt2.RegisterVoucherType(&testutil.FakeDTType{}, rh.sv)) + require.NoError(rh.t, rh.dt2.RegisterVoucherType(testutil.TestVoucherType, rh.sv)) testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt2) rh.dt2.SubscribeToEvents(subscriber) require.NoError(rh.t, rh.dt2.RestartDataTransferChannel(rh.testCtx, chId)) @@ -108,7 +109,8 @@ func TestRestartPush(t *testing.T) { // initiator: abort GS response "transfer(0)->response(0)->abortRequest(0)", // initiator: receive restart request and send restart channel message - "transfer(0)->receiveRequest(0)->sendMessage(0)", + "transfer(0)->receiveRequest(0)", + "transfer(0)->sendMessage(1)", // initiator: receive second GS request in response to restart channel message // and execute GS response "transfer(0)->response(1)->executeTask(0)", @@ -134,6 +136,10 @@ func TestRestartPush(t *testing.T) { // START DATA TRANSFER INSTANCES rh.sv.ExpectSuccessPush() + rh.sv.StubResult(datatransfer.ValidationResult{Accepted: true}) + rh.sv.ExpectSuccessValidateRestart() + rh.sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) + testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt1) testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt2) @@ -254,7 +260,7 @@ func TestRestartPush(t 
*testing.T) { // verify all cids are present on the receiver - testutil.VerifyHasFile(rh.testCtx, t, rh.destDagService, rh.root, rh.origBytes) + VerifyHasFile(rh.testCtx, t, rh.destDagService, rh.root, rh.origBytes) rh.sv.VerifyExpectations(t) // we should ONLY see two opens and two completes @@ -290,8 +296,8 @@ func TestRestartPull(t *testing.T) { "Restart peer create pull": { stopAt: 40, openPullF: func(rh *restartHarness) datatransfer.ChannelID { - voucher := testutil.FakeDTType{Data: "applesauce"} - chid, err := rh.dt2.OpenPullDataChannel(rh.testCtx, rh.peer1, &voucher, rh.rootCid, rh.gsData.AllSelector) + voucher := testutil.NewTestTypedVoucherWith("applesauce") + chid, err := rh.dt2.OpenPullDataChannel(rh.testCtx, rh.peer1, voucher, rh.rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(rh.t, err) return chid }, @@ -300,9 +306,9 @@ func TestRestartPull(t *testing.T) { require.NoError(t, rh.dt2.Stop(rh.testCtx)) time.Sleep(100 * time.Millisecond) tp2 := rh.gsData.SetupGSTransportHost2() - rh.dt2, err = NewDataTransfer(rh.gsData.DtDs2, rh.gsData.DtNet2, tp2) + rh.dt2, err = NewDataTransfer(rh.gsData.DtDs2, rh.gsData.Host2.ID(), tp2) require.NoError(rh.t, err) - require.NoError(rh.t, rh.dt2.RegisterVoucherType(&testutil.FakeDTType{}, rh.sv)) + require.NoError(rh.t, rh.dt2.RegisterVoucherType(testutil.TestVoucherType, rh.sv)) testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt2) rh.dt2.SubscribeToEvents(subscriber) require.NoError(rh.t, rh.dt2.RestartDataTransferChannel(rh.testCtx, chId)) @@ -329,8 +335,8 @@ func TestRestartPull(t *testing.T) { "Restart peer receive pull": { stopAt: 40, openPullF: func(rh *restartHarness) datatransfer.ChannelID { - voucher := testutil.FakeDTType{Data: "applesauce"} - chid, err := rh.dt2.OpenPullDataChannel(rh.testCtx, rh.peer1, &voucher, rh.rootCid, rh.gsData.AllSelector) + voucher := testutil.NewTestTypedVoucherWith("applesauce") + chid, err := rh.dt2.OpenPullDataChannel(rh.testCtx, rh.peer1, voucher, rh.rootCid, selectorparse.CommonSelector_ExploreAllRecursively) require.NoError(rh.t, err) return chid }, @@ -339,9 +345,9 @@ func TestRestartPull(t *testing.T) { require.NoError(t, rh.dt1.Stop(rh.testCtx)) time.Sleep(100 * time.Millisecond) tp1 := rh.gsData.SetupGSTransportHost1() - rh.dt1, err = NewDataTransfer(rh.gsData.DtDs1, rh.gsData.DtNet1, tp1) + rh.dt1, err = NewDataTransfer(rh.gsData.DtDs1, rh.gsData.Host1.ID(), tp1) require.NoError(rh.t, err) - require.NoError(rh.t, rh.dt1.RegisterVoucherType(&testutil.FakeDTType{}, rh.sv)) + require.NoError(rh.t, rh.dt1.RegisterVoucherType(testutil.TestVoucherType, rh.sv)) testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt1) rh.dt1.SubscribeToEvents(subscriber) require.NoError(rh.t, rh.dt1.RestartDataTransferChannel(rh.testCtx, chId)) @@ -352,7 +358,8 @@ func TestRestartPull(t *testing.T) { // initiator: initial outgoing gs request terminates "transfer(0)->request(0)->terminateRequest(0)", // initiator: respond to restart request and send second GS request - "transfer(0)->receiveRequest(0)->request(0)", + "transfer(0)->receiveRequest(0)", + "transfer(0)->request(1)->executeTask(0)", // initiator: receive completion message from responder that they sent all the data "transfer(0)->receiveResponse(0)", // responder: receive GS request and execute response @@ -378,6 +385,10 @@ func TestRestartPull(t *testing.T) { // START DATA TRANSFER INSTANCES rh.sv.ExpectSuccessPull() + rh.sv.StubResult(datatransfer.ValidationResult{Accepted: true}) + rh.sv.ExpectSuccessValidateRestart() + 
rh.sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) + testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt1) testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt2) @@ -492,7 +503,7 @@ func TestRestartPull(t *testing.T) { _, _, err = waitF(10*time.Second, 2) require.NoError(t, err) - testutil.VerifyHasFile(rh.testCtx, t, rh.destDagService, rh.root, rh.origBytes) + VerifyHasFile(rh.testCtx, t, rh.destDagService, rh.root, rh.origBytes) rh.sv.VerifyExpectations(t) // we should ONLY see two opens and two completes @@ -526,7 +537,7 @@ type restartHarness struct { peer1 peer.ID peer2 peer.ID - gsData *testutil.GraphsyncTestingData + gsData *GraphsyncTestingData dt1 datatransfer.Manager dt2 datatransfer.Manager sv *testutil.StubbedValidator @@ -543,7 +554,7 @@ func newRestartHarness(t *testing.T) *restartHarness { ctx, cancel := context.WithTimeout(ctx, 120*time.Second) // Setup host - gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) + gsData := NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 // data recipient peer1 := host1.ID() @@ -555,18 +566,18 @@ func newRestartHarness(t *testing.T) *restartHarness { tp1 := gsData.SetupGSTransportHost1() tp2 := gsData.SetupGSTransportHost2() - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) sv := testutil.NewStubbedValidator() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) sourceDagService := gsData.DagService1 - root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, largeFile) + root, origBytes := LoadUnixFSFile(ctx, t, sourceDagService, largeFile) rootCid := root.(cidlink.Link).Cid destDagService := gsData.DagService2 diff --git a/manager.go b/manager.go index 9d540abe..b7a6f217 100644 --- a/manager.go +++ b/manager.go @@ -4,65 +4,96 @@ import ( "context" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" ) +// ValidationResult describes the result of validating a voucher +type ValidationResult struct { + // Accepted indicates whether the request was accepted. If a request is not + // accepted, the request fails. This is true for revalidation as well + Accepted bool + // VoucherResult provides information to the other party about what happened + // with the voucher + VoucherResult *TypedVoucher + // ForcePause indicates whether the request should be paused, regardless + // of data limit and finalization status + ForcePause bool + // DataLimit specifies how much data this voucher is good for. When the amount + // of data specified is reached (or shortly after), the request will pause + // pending revalidation. 0 indicates no limit. 
+ DataLimit uint64 + // RequiresFinalization indicates at the end of the transfer, the channel should + // be left open for a final settlement + RequiresFinalization bool +} + +// Equals checks the deep equality of two ValidationResult values +func (vr ValidationResult) Equals(vr2 ValidationResult) bool { + return vr.Accepted == vr2.Accepted && + vr.ForcePause == vr2.ForcePause && + vr.DataLimit == vr2.DataLimit && + vr.RequiresFinalization == vr2.RequiresFinalization && + (vr.VoucherResult == nil) == (vr2.VoucherResult == nil) && + (vr.VoucherResult == nil || vr.VoucherResult.Equals(*vr2.VoucherResult)) +} + +// LeaveRequestPaused indicates whether all conditions are met to resume a request +func (vr ValidationResult) LeaveRequestPaused(chst ChannelState) bool { + if chst == nil { + return false + } + if vr.ForcePause { + return true + } + if vr.RequiresFinalization && chst.Status().InFinalization() { + return true + } + var limitFactor uint64 + if chst.IsPull() { + limitFactor = chst.Queued() + } else { + limitFactor = chst.Received() + } + return vr.DataLimit != 0 && limitFactor >= vr.DataLimit +} + // RequestValidator is an interface implemented by the client of the // data transfer module to validate requests type RequestValidator interface { // ValidatePush validates a push request received from the peer that will send data + // -- All information about the validation operation is contained in ValidationResult, + // including if it was rejected. Information about why a rejection occurred is embedded + // in the VoucherResult. + // -- error indicates something went wrong with the actual process of trying to validate ValidatePush( - isRestart bool, chid ChannelID, sender peer.ID, - voucher Voucher, + voucher datamodel.Node, baseCid cid.Cid, - selector ipld.Node) (VoucherResult, error) + selector datamodel.Node) (ValidationResult, error) // ValidatePull validates a pull request received from the peer that will receive data + // -- All information about the validation operation is contained in ValidationResult, + // including if it was rejected. Information about why a rejection occurred should be embedded + // in the VoucherResult. + // -- error indicates something went wrong with the actual process of trying to validate ValidatePull( - isRestart bool, chid ChannelID, receiver peer.ID, - voucher Voucher, + voucher datamodel.Node, baseCid cid.Cid, - selector ipld.Node) (VoucherResult, error) -} - -// Revalidator is a request validator revalidates in progress requests -// by requesting request additional vouchers, and resuming when it receives them -type Revalidator interface { - // Revalidate revalidates a request with a new voucher - Revalidate(channelID ChannelID, voucher Voucher) (VoucherResult, error) - // OnPullDataSent is called on the responder side when more bytes are sent - // for a given pull request. The first value indicates whether the request was - // recognized by this revalidator and should be considered 'handled'. If true, - // the remaining two values are interpreted. If 'false' the request is passed on - // to the next revalidators. - // It should return a VoucherResult + ErrPause to - // request revalidation or nil to continue uninterrupted, - // other errors will terminate the request. - OnPullDataSent(chid ChannelID, additionalBytesSent uint64) (bool, VoucherResult, error) - // OnPushDataReceived is called on the responder side when more bytes are received - // for a given push request. 
The first value indicates whether the request was - // recognized by this revalidator and should be considered 'handled'. If true, - // the remaining two values are interpreted. If 'false' the request is passed on - // to the next revalidators. It should return a VoucherResult + ErrPause to - // request revalidation or nil to continue uninterrupted, - // other errors will terminate the request - OnPushDataReceived(chid ChannelID, additionalBytesReceived uint64) (bool, VoucherResult, error) - // OnComplete is called to make a final request for revalidation -- often for the - // purpose of settlement. The first value indicates whether the request was - // recognized by this revalidator and should be considered 'handled'. If true, - // the remaining two values are interpreted. If 'false' the request is passed on - // to the next revalidators. - // if VoucherResult is non nil, the request will enter a settlement phase awaiting - // a final update - OnComplete(chid ChannelID) (bool, VoucherResult, error) + selector datamodel.Node) (ValidationResult, error) + + // ValidateRestart validates restarting a request + // -- All information about the validation operation is contained in ValidationResult, + // including if it was rejected. Information about why a rejection occurred should be embedded + // in the VoucherResult. + // -- error indicates something went wrong with the actual process of trying to validate + ValidateRestart(channelID ChannelID, channel ChannelState) (ValidationResult, error) } // TransportConfigurer provides a mechanism to provide transport specific configuration for a given voucher type -type TransportConfigurer func(chid ChannelID, voucher Voucher, transport Transport) +type TransportConfigurer func(chid ChannelID, voucher TypedVoucher, transport Transport) // ReadyFunc is function that gets called once when the data transfer module is ready type ReadyFunc func(error) @@ -83,33 +114,29 @@ type Manager interface { // RegisterVoucherType registers a validator for the given voucher type // will error if voucher type does not implement voucher // or if there is a voucher type registered with an identical identifier - RegisterVoucherType(voucherType Voucher, validator RequestValidator) error - - // RegisterRevalidator registers a revalidator for the given voucher type - // Note: this is the voucher type used to revalidate. It can share a name - // with the initial validator type and CAN be the same type, or a different type. - // The revalidator can simply be the sampe as the original request validator, - // or a different validator that satisfies the revalidator interface. 
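With revalidators removed, a single RequestValidator now expresses every decision through ValidationResult: acceptance, an optional DataLimit (compared against Queued() on pulls and Received() on pushes), an optional settlement phase via RequiresFinalization, and ForcePause for unconditional pausing. Below is a minimal sketch of an implementation against the new interface, assuming a hypothetical `paidValidator` type that grants data in fixed-size chunks and does not inspect voucher contents.

```go
package example

import (
	"github.com/ipfs/go-cid"
	"github.com/ipld/go-ipld-prime/datamodel"
	"github.com/libp2p/go-libp2p-core/peer"

	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
)

// paidValidator accepts pushes freely and meters pulls in fixed-size chunks.
type paidValidator struct {
	chunkSize uint64 // bytes granted before the channel pauses for payment
}

var _ datatransfer.RequestValidator = (*paidValidator)(nil)

func (v *paidValidator) ValidatePush(
	chid datatransfer.ChannelID,
	sender peer.ID,
	voucher datamodel.Node,
	baseCid cid.Cid,
	selector datamodel.Node,
) (datatransfer.ValidationResult, error) {
	// Receiving data costs nothing, so accept with no limit and no settlement.
	return datatransfer.ValidationResult{Accepted: true}, nil
}

func (v *paidValidator) ValidatePull(
	chid datatransfer.ChannelID,
	receiver peer.ID,
	voucher datamodel.Node,
	baseCid cid.Cid,
	selector datamodel.Node,
) (datatransfer.ValidationResult, error) {
	// Grant an initial window of data and keep the channel open at the end
	// for a final settlement exchange.
	return datatransfer.ValidationResult{
		Accepted:             true,
		DataLimit:            v.chunkSize,
		RequiresFinalization: true,
	}, nil
}

func (v *paidValidator) ValidateRestart(
	chid datatransfer.ChannelID,
	channel datatransfer.ChannelState,
) (datatransfer.ValidationResult, error) {
	// Resume where the channel left off by granting one more chunk past
	// whatever was already queued.
	return datatransfer.ValidationResult{
		Accepted:             true,
		DataLimit:            channel.Queued() + v.chunkSize,
		RequiresFinalization: true,
	}, nil
}
```

A validator like this is registered once per voucher type with RegisterVoucherType, and later adjustments (new limits, clearing finalization) go through UpdateValidationStatus rather than a separate revalidator.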
- RegisterRevalidator(voucherType Voucher, revalidator Revalidator) error - - // RegisterVoucherResultType allows deserialization of a voucher result, - // so that a listener can read the metadata - RegisterVoucherResultType(resultType VoucherResult) error + RegisterVoucherType(voucherType TypeIdentifier, validator RequestValidator) error // RegisterTransportConfigurer registers the given transport configurer to be run on requests with the given voucher // type - RegisterTransportConfigurer(voucherType Voucher, configurer TransportConfigurer) error + RegisterTransportConfigurer(voucherType TypeIdentifier, configurer TransportConfigurer) error // open a data transfer that will send data to the recipient peer and // transfer parts of the piece that match the selector - OpenPushDataChannel(ctx context.Context, to peer.ID, voucher Voucher, baseCid cid.Cid, selector ipld.Node) (ChannelID, error) + OpenPushDataChannel(ctx context.Context, to peer.ID, voucher TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (ChannelID, error) // open a data transfer that will request data from the sending peer and // transfer parts of the piece that match the selector - OpenPullDataChannel(ctx context.Context, to peer.ID, voucher Voucher, baseCid cid.Cid, selector ipld.Node) (ChannelID, error) + OpenPullDataChannel(ctx context.Context, to peer.ID, voucher TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (ChannelID, error) // send an intermediate voucher as needed when the receiver sends a request for revalidation - SendVoucher(ctx context.Context, chid ChannelID, voucher Voucher) error + SendVoucher(ctx context.Context, chid ChannelID, voucher TypedVoucher) error + + // send information from the responder to update the initiator on the state of their voucher + SendVoucherResult(ctx context.Context, chid ChannelID, voucherResult TypedVoucher) error + + // Update the validation status for a given channel, to change data limits, finalization, accepted status, and pause state + // and send new voucher results as needed + UpdateValidationStatus(ctx context.Context, chid ChannelID, validationResult ValidationResult) error // close an open channel (effectively a cancel) CloseDataTransferChannel(ctx context.Context, chid ChannelID) error diff --git a/message.go b/message.go index 73c8767c..b44b43e2 100644 --- a/message.go +++ b/message.go @@ -1,20 +1,51 @@ package datatransfer import ( + "errors" + "fmt" "io" + "strconv" + "strings" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/datamodel" - "github.com/libp2p/go-libp2p-core/protocol" - - "github.com/filecoin-project/go-data-transfer/encoding" ) +type Version struct { + Major uint64 + Minor uint64 + Patch uint64 +} + +func (mv Version) String() string { + return fmt.Sprintf("%d.%d.%d", mv.Major, mv.Minor, mv.Patch) +} + +// MessageVersionFromString parses a string into a message version +func MessageVersionFromString(versionString string) (Version, error) { + versions := strings.Split(versionString, ".") + if len(versions) != 3 { + return Version{}, errors.New("not a version string") + } + major, err := strconv.ParseUint(versions[0], 10, 0) + if err != nil { + return Version{}, errors.New("unable to parse major version") + } + minor, err := strconv.ParseUint(versions[1], 10, 0) + if err != nil { + return Version{}, errors.New("unable to parse minor version") + } + patch, err := strconv.ParseUint(versions[2], 10, 0) + if err != nil { + return Version{}, errors.New("unable to parse patch version") + } + return
Version{Major: major, Minor: minor, Patch: patch}, nil +} + var ( - // ProtocolDataTransfer1_2 is the protocol identifier for the latest - // version of data-transfer (supports do-not-send-first-blocks extension) - ProtocolDataTransfer1_2 protocol.ID = "/fil/datatransfer/1.2.0" + // DataTransfer1_2 is the identifier for the current + // supported version of data-transfer + DataTransfer1_2 Version = Version{1, 2, 0} ) // Message is a message for the data transfer protocol @@ -28,8 +59,17 @@ type Message interface { IsCancel() bool TransferID() TransferID ToNet(w io.Writer) error - ToIPLD() (datamodel.Node, error) - MessageForProtocol(targetProtocol protocol.ID) (newMsg Message, err error) + ToIPLD() datamodel.Node + MessageForVersion(targetProtocol Version) (newMsg Message, err error) + Version() Version + WrappedForTransport(transportID TransportID, transportVersion Version) TransportedMessage +} + +// TransportedMessage is a message that can also report how it was transported +type TransportedMessage interface { + Message + TransportID() TransportID + TransportVersion() Version } // Request is a response message for the data transfer protocol @@ -38,9 +78,10 @@ type Request interface { IsPull() bool IsVoucher() bool VoucherType() TypeIdentifier - Voucher(decoder encoding.Decoder) (encoding.Encodable, error) + Voucher() (datamodel.Node, error) + TypedVoucher() (TypedVoucher, error) BaseCid() cid.Cid - Selector() (ipld.Node, error) + Selector() (datamodel.Node, error) IsRestartExistingChannelRequest() bool RestartChannelId() (ChannelID, error) } @@ -48,10 +89,10 @@ type Request interface { // Response is a response message for the data transfer protocol type Response interface { Message - IsVoucherResult() bool + IsValidationResult() bool IsComplete() bool Accepted() bool VoucherResultType() TypeIdentifier - VoucherResult(decoder encoding.Decoder) (encoding.Encodable, error) + VoucherResult() (datamodel.Node, error) EmptyVoucherResult() bool } diff --git a/message/message.go b/message/message.go index ff4a988b..639046a4 100644 --- a/message/message.go +++ b/message/message.go @@ -1,19 +1,28 @@ package message import ( - message1_1 "github.com/filecoin-project/go-data-transfer/message/message1_1prime" + message1_1 "github.com/filecoin-project/go-data-transfer/v2/message/message1_1prime" ) var NewRequest = message1_1.NewRequest var RestartExistingChannelRequest = message1_1.RestartExistingChannelRequest var UpdateRequest = message1_1.UpdateRequest var VoucherRequest = message1_1.VoucherRequest + +// DEPRECATED: Use ValidationResultResponse var RestartResponse = message1_1.RestartResponse + +var ValidationResultResponse = message1_1.ValidationResultResponse + +// DEPRECATED: Use ValidationResultResponse var NewResponse = message1_1.NewResponse + +// DEPRECATED: Use ValidationResultResponse var VoucherResultResponse = message1_1.VoucherResultResponse var CancelResponse = message1_1.CancelResponse var UpdateResponse = message1_1.UpdateResponse var FromNet = message1_1.FromNet var FromIPLD = message1_1.FromIPLD +var FromNetWrapped = message1_1.FromNetWrapped var CompleteResponse = message1_1.CompleteResponse var CancelRequest = message1_1.CancelRequest diff --git a/message/message1_1/message.go b/message/message1_1/message.go deleted file mode 100644 index c07544c3..00000000 --- a/message/message1_1/message.go +++ /dev/null @@ -1,196 +0,0 @@ -package message1_1 - -import ( - "bytes" - "io" - - "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" - 
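Because the wire version is now a plain comparable value rather than a libp2p protocol ID, callers can parse and check it directly. A small usage sketch, assuming only the Version type, DataTransfer1_2, and MessageVersionFromString introduced above (`checkRemoteVersion` is a hypothetical helper):

```go
package example

import (
	"fmt"

	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
)

// checkRemoteVersion parses a peer-reported version string and verifies it
// matches the version of the protocol this module speaks.
func checkRemoteVersion(remote string) error {
	v, err := datatransfer.MessageVersionFromString(remote)
	if err != nil {
		return fmt.Errorf("bad version string %q: %w", remote, err)
	}
	// Version is a comparable struct, so equality and field-wise checks both work.
	if v != datatransfer.DataTransfer1_2 {
		return fmt.Errorf("unsupported data-transfer version %s, want %s", v, datatransfer.DataTransfer1_2)
	}
	return nil
}
```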
"github.com/ipld/go-ipld-prime/codec/dagcbor" - "github.com/ipld/go-ipld-prime/datamodel" - cborgen "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" - - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/message/types" -) - -// NewRequest generates a new request for the data transfer protocol -func NewRequest(id datatransfer.TransferID, isRestart bool, isPull bool, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable, baseCid cid.Cid, selector ipld.Node) (datatransfer.Request, error) { - vbytes, err := encoding.Encode(voucher) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - if baseCid == cid.Undef { - return nil, xerrors.Errorf("base CID must be defined") - } - selBytes, err := encoding.Encode(selector) - if err != nil { - return nil, xerrors.Errorf("Error encoding selector") - } - - var typ uint64 - if isRestart { - typ = uint64(types.RestartMessage) - } else { - typ = uint64(types.NewMessage) - } - - return &TransferRequest1_1{ - Type: typ, - Pull: isPull, - Vouch: &cborgen.Deferred{Raw: vbytes}, - Stor: &cborgen.Deferred{Raw: selBytes}, - BCid: &baseCid, - VTyp: vtype, - XferID: uint64(id), - }, nil -} - -// RestartExistingChannelRequest creates a request to ask the other side to restart an existing channel -func RestartExistingChannelRequest(channelId datatransfer.ChannelID) datatransfer.Request { - - return &TransferRequest1_1{Type: uint64(types.RestartExistingChannelRequestMessage), - RestartChannel: channelId} -} - -// CancelRequest request generates a request to cancel an in progress request -func CancelRequest(id datatransfer.TransferID) datatransfer.Request { - return &TransferRequest1_1{ - Type: uint64(types.CancelMessage), - XferID: uint64(id), - } -} - -// UpdateRequest generates a new request update -func UpdateRequest(id datatransfer.TransferID, isPaused bool) datatransfer.Request { - return &TransferRequest1_1{ - Type: uint64(types.UpdateMessage), - Paus: isPaused, - XferID: uint64(id), - } -} - -// VoucherRequest generates a new request for the data transfer protocol -func VoucherRequest(id datatransfer.TransferID, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable) (datatransfer.Request, error) { - vbytes, err := encoding.Encode(voucher) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - return &TransferRequest1_1{ - Type: uint64(types.VoucherMessage), - Vouch: &cborgen.Deferred{Raw: vbytes}, - VTyp: vtype, - XferID: uint64(id), - }, nil -} - -// RestartResponse builds a new Data Transfer response -func RestartResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vbytes, err := encoding.Encode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - return &TransferResponse1_1{ - Acpt: accepted, - Type: uint64(types.RestartMessage), - Paus: isPaused, - XferID: uint64(id), - VTyp: voucherResultType, - VRes: &cborgen.Deferred{Raw: vbytes}, - }, nil -} - -// NewResponse builds a new Data Transfer response -func NewResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vbytes, err := encoding.Encode(voucherResult) - if err != nil { - return nil, 
xerrors.Errorf("Creating request: %w", err) - } - return &TransferResponse1_1{ - Acpt: accepted, - Type: uint64(types.NewMessage), - Paus: isPaused, - XferID: uint64(id), - VTyp: voucherResultType, - VRes: &cborgen.Deferred{Raw: vbytes}, - }, nil -} - -// VoucherResultResponse builds a new response for a voucher result -func VoucherResultResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vbytes, err := encoding.Encode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - return &TransferResponse1_1{ - Acpt: accepted, - Type: uint64(types.VoucherResultMessage), - Paus: isPaused, - XferID: uint64(id), - VTyp: voucherResultType, - VRes: &cborgen.Deferred{Raw: vbytes}, - }, nil -} - -// UpdateResponse returns a new update response -func UpdateResponse(id datatransfer.TransferID, isPaused bool) datatransfer.Response { - return &TransferResponse1_1{ - Type: uint64(types.UpdateMessage), - Paus: isPaused, - XferID: uint64(id), - } -} - -// CancelResponse makes a new cancel response message -func CancelResponse(id datatransfer.TransferID) datatransfer.Response { - return &TransferResponse1_1{ - Type: uint64(types.CancelMessage), - XferID: uint64(id), - } -} - -// CompleteResponse returns a new complete response message -func CompleteResponse(id datatransfer.TransferID, isAccepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vbytes, err := encoding.Encode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - return &TransferResponse1_1{ - Type: uint64(types.CompleteMessage), - Acpt: isAccepted, - Paus: isPaused, - VTyp: voucherResultType, - VRes: &cborgen.Deferred{Raw: vbytes}, - XferID: uint64(id), - }, nil -} - -// FromNet can read a network stream to deserialize a GraphSyncMessage -func FromNet(r io.Reader) (datatransfer.Message, error) { - tresp := TransferMessage1_1{} - err := tresp.UnmarshalCBOR(r) - if err != nil { - return nil, err - } - - if (tresp.IsRequest() && tresp.Request == nil) || (!tresp.IsRequest() && tresp.Response == nil) { - return nil, xerrors.Errorf("invalid/malformed message") - } - - if tresp.IsRequest() { - return tresp.Request, nil - } - return tresp.Response, nil -} - -// FromNet can read a network stream to deserialize a GraphSyncMessage -func FromIPLD(nd datamodel.Node) (datatransfer.Message, error) { - buf := new(bytes.Buffer) - err := dagcbor.Encode(nd, buf) - if err != nil { - return nil, err - } - return FromNet(buf) -} diff --git a/message/message1_1/message_test.go b/message/message1_1/message_test.go deleted file mode 100644 index 7cb2ac24..00000000 --- a/message/message1_1/message_test.go +++ /dev/null @@ -1,549 +0,0 @@ -package message1_1_test - -import ( - "bytes" - "encoding/hex" - "fmt" - "math/rand" - "testing" - - "github.com/ipfs/go-cid" - basicnode "github.com/ipld/go-ipld-prime/node/basic" - "github.com/ipld/go-ipld-prime/traversal/selector/builder" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/message/message1_1" - "github.com/filecoin-project/go-data-transfer/testutil" -) - -func TestNewRequest(t *testing.T) { - baseCid := testutil.GenerateCids(1)[0] - selector := 
builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() - isPull := true - id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) - require.NoError(t, err) - assert.Equal(t, id, request.TransferID()) - assert.False(t, request.IsCancel()) - assert.False(t, request.IsUpdate()) - assert.True(t, request.IsPull()) - assert.True(t, request.IsRequest()) - assert.Equal(t, baseCid.String(), request.BaseCid().String()) - testutil.AssertFakeDTVoucher(t, request, voucher) - receivedSelector, err := request.Selector() - require.NoError(t, err) - require.Equal(t, selector, receivedSelector) - // Sanity check to make sure we can cast to datatransfer.Message - msg, ok := request.(datatransfer.Message) - require.True(t, ok) - - assert.True(t, msg.IsRequest()) - assert.Equal(t, request.TransferID(), msg.TransferID()) - assert.False(t, msg.IsRestart()) - assert.True(t, msg.IsNew()) -} - -func TestRestartRequest(t *testing.T) { - baseCid := testutil.GenerateCids(1)[0] - selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() - isPull := true - id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, true, isPull, voucher.Type(), voucher, baseCid, selector) - require.NoError(t, err) - assert.Equal(t, id, request.TransferID()) - assert.False(t, request.IsCancel()) - assert.False(t, request.IsUpdate()) - assert.True(t, request.IsPull()) - assert.True(t, request.IsRequest()) - assert.Equal(t, baseCid.String(), request.BaseCid().String()) - testutil.AssertFakeDTVoucher(t, request, voucher) - receivedSelector, err := request.Selector() - require.NoError(t, err) - require.Equal(t, selector, receivedSelector) - // Sanity check to make sure we can cast to datatransfer.Message - msg, ok := request.(datatransfer.Message) - require.True(t, ok) - - assert.True(t, msg.IsRequest()) - assert.Equal(t, request.TransferID(), msg.TransferID()) - assert.True(t, msg.IsRestart()) - assert.False(t, msg.IsNew()) -} - -func TestRestartExistingChannelRequest(t *testing.T) { - t.Run("round-trip", func(t *testing.T) { - peers := testutil.GeneratePeers(2) - tid := uint64(1) - chid := datatransfer.ChannelID{Initiator: peers[0], - Responder: peers[1], ID: datatransfer.TransferID(tid)} - req := message1_1.RestartExistingChannelRequest(chid) - - wbuf := new(bytes.Buffer) - require.NoError(t, req.ToNet(wbuf)) - - desMsg, err := message1_1.FromNet(wbuf) - require.NoError(t, err) - req, ok := (desMsg).(datatransfer.Request) - require.True(t, ok) - require.True(t, req.IsRestartExistingChannelRequest()) - achid, err := req.RestartChannelId() - require.NoError(t, err) - require.Equal(t, chid, achid) - }) - t.Run("ipld-prime compat", func(t *testing.T) { - msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964f66454797065076450617573f46450617274f46450756c6cf46453746f72f665566f756368f664565479706066586665724944006e526573746172744368616e6e656c83613161320168526573706f6e7365f6") - desMsg, err := message1_1.FromNet(bytes.NewReader(msg)) - require.NoError(t, err) - req, ok := (desMsg).(datatransfer.Request) - require.True(t, ok) - require.True(t, req.IsRestartExistingChannelRequest()) - achid, err := req.RestartChannelId() - require.NoError(t, err) - tid := uint64(1) - chid := datatransfer.ChannelID{Initiator: peer.ID("1"), - Responder: peer.ID("2"), ID: datatransfer.TransferID(tid)} - 
require.Equal(t, chid, achid) - }) -} - -func TestTransferRequest_MarshalCBOR(t *testing.T) { - // sanity check MarshalCBOR does its thing w/o error - req, err := NewTestTransferRequest() - require.NoError(t, err) - wbuf := new(bytes.Buffer) - require.NoError(t, req.MarshalCBOR(wbuf)) - assert.Greater(t, wbuf.Len(), 0) -} -func TestTransferRequest_UnmarshalCBOR(t *testing.T) { - t.Run("round-trip", func(t *testing.T) { - req, err := NewTestTransferRequest() - require.NoError(t, err) - wbuf := new(bytes.Buffer) - // use ToNet / message1_1.FromNet - require.NoError(t, req.ToNet(wbuf)) - - desMsg, err := message1_1.FromNet(wbuf) - require.NoError(t, err) - - // Verify round-trip - assert.Equal(t, req.TransferID(), desMsg.TransferID()) - assert.Equal(t, req.IsRequest(), desMsg.IsRequest()) - - desReq := desMsg.(datatransfer.Request) - assert.Equal(t, req.IsPull(), desReq.IsPull()) - assert.Equal(t, req.IsCancel(), desReq.IsCancel()) - assert.Equal(t, req.BaseCid(), desReq.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, &req, desReq) - testutil.AssertEqualSelector(t, &req, desReq) - }) - t.Run("ipld-prime compat", func(t *testing.T) { - req, err := NewTestTransferRequest() - require.NoError(t, err) - - msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964d82a58230012204bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a6450617274f46450617573f46450756c6cf46453746f72a1612ea064547970650064565479706a46616b6544545479706565566f756368817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e35665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") - desMsg, err := message1_1.FromNet(bytes.NewReader(msg)) - require.NoError(t, err) - - // Verify round-trip - assert.Equal(t, datatransfer.TransferID(1298498081), desMsg.TransferID()) - assert.Equal(t, req.IsRequest(), desMsg.IsRequest()) - - desReq := desMsg.(datatransfer.Request) - assert.Equal(t, req.IsPull(), desReq.IsPull()) - assert.Equal(t, req.IsCancel(), desReq.IsCancel()) - c, _ := cid.Parse("QmTTA2daxGqo5denp6SwLzzkLJm3fuisYEi9CoWsuHpzfb") - assert.Equal(t, c, desReq.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, &req, desReq) - testutil.AssertEqualSelector(t, &req, desReq) - }) -} - -func TestResponses(t *testing.T) { - id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, false, true, voucherResult.Type(), voucherResult) // not accepted - require.NoError(t, err) - assert.Equal(t, response.TransferID(), id) - assert.False(t, response.Accepted()) - assert.True(t, response.IsNew()) - assert.False(t, response.IsUpdate()) - assert.True(t, response.IsPaused()) - assert.False(t, response.IsRequest()) - testutil.AssertFakeDTVoucherResult(t, response, voucherResult) - // Sanity check to make sure we can cast to datatransfer.Message - msg, ok := response.(datatransfer.Message) - require.True(t, ok) - - assert.False(t, msg.IsRequest()) - assert.True(t, msg.IsNew()) - assert.False(t, msg.IsUpdate()) - assert.True(t, msg.IsPaused()) - assert.Equal(t, response.TransferID(), msg.TransferID()) -} - -func TestTransferResponse_MarshalCBOR(t *testing.T) { - id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, true, false, voucherResult.Type(), voucherResult) // accepted - require.NoError(t, err) - - 
// sanity check that we can marshal data - wbuf := new(bytes.Buffer) - require.NoError(t, response.ToNet(wbuf)) - assert.Greater(t, wbuf.Len(), 0) -} - -func TestTransferResponse_UnmarshalCBOR(t *testing.T) { - t.Run("round-trip", func(t *testing.T) { - id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, true, false, voucherResult.Type(), voucherResult) // accepted - require.NoError(t, err) - - wbuf := new(bytes.Buffer) - require.NoError(t, response.ToNet(wbuf)) - - // verify round trip - desMsg, err := message1_1.FromNet(wbuf) - require.NoError(t, err) - assert.False(t, desMsg.IsRequest()) - assert.True(t, desMsg.IsNew()) - assert.False(t, desMsg.IsUpdate()) - assert.False(t, desMsg.IsPaused()) - assert.Equal(t, id, desMsg.TransferID()) - - desResp, ok := desMsg.(datatransfer.Response) - require.True(t, ok) - assert.True(t, desResp.Accepted()) - assert.True(t, desResp.IsNew()) - assert.False(t, desResp.IsUpdate()) - assert.False(t, desMsg.IsPaused()) - testutil.AssertFakeDTVoucherResult(t, desResp, voucherResult) - }) - t.Run("ipld-prime compat", func(t *testing.T) { - voucherResult := testutil.NewFakeDTType() - voucherResult.Data = "\xf5_\xf8\xf1%\b\xb6>\xf2\xbf\xec\xa7Uz\xe9\r\xf61\x1a^\xc1c\x1bJ\x1f\xa8C1\v\xd9ç\x10\xea\xac塽\xd7*п\xe0Iw\x1c\x11\xe7V3\x8b\xd98e\xe6E\xf1\xad웜\x99\xef@\u007f\xbdOƅ\x9ey\x04ŭ}ɽ\x10\xa5\xcc\x16\x97=[(\xec\x1am\xd4=\x9f\x82\xf9\xf1\x8c=\x03A\x8e5" - - msg, _ := hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66441637074f56450617573f46454797065006456526573817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706a46616b65445454797065665866657249441a4d658221") - desMsg, err := message1_1.FromNet(bytes.NewReader(msg)) - require.NoError(t, err) - assert.False(t, desMsg.IsRequest()) - assert.True(t, desMsg.IsNew()) - assert.False(t, desMsg.IsUpdate()) - assert.False(t, desMsg.IsPaused()) - assert.Equal(t, datatransfer.TransferID(1298498081), desMsg.TransferID()) - - desResp, ok := desMsg.(datatransfer.Response) - require.True(t, ok) - assert.True(t, desResp.Accepted()) - assert.True(t, desResp.IsNew()) - assert.False(t, desResp.IsUpdate()) - assert.False(t, desMsg.IsPaused()) - testutil.AssertFakeDTVoucherResult(t, desResp, voucherResult) - }) -} - -func TestRequestCancel(t *testing.T) { - t.Run("round-trip", func(t *testing.T) { - id := datatransfer.TransferID(rand.Int31()) - req := message1_1.CancelRequest(id) - require.Equal(t, req.TransferID(), id) - require.True(t, req.IsRequest()) - require.True(t, req.IsCancel()) - require.False(t, req.IsUpdate()) - - wbuf := new(bytes.Buffer) - require.NoError(t, req.ToNet(wbuf)) - - deserialized, err := message1_1.FromNet(wbuf) - require.NoError(t, err) - - deserializedRequest, ok := deserialized.(datatransfer.Request) - require.True(t, ok) - require.Equal(t, deserializedRequest.TransferID(), req.TransferID()) - require.Equal(t, deserializedRequest.IsCancel(), req.IsCancel()) - require.Equal(t, deserializedRequest.IsRequest(), req.IsRequest()) - require.Equal(t, deserializedRequest.IsUpdate(), req.IsUpdate()) - }) - t.Run("ipld-prime compat", func(t *testing.T) { - id := datatransfer.TransferID(1298498081) - req := message1_1.CancelRequest(id) - require.Equal(t, req.TransferID(), id) - require.True(t, req.IsRequest()) - require.True(t, req.IsCancel()) - 
require.False(t, req.IsUpdate()) - - msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964f66450617274f46450617573f46450756c6cf46453746f72f664547970650264565479706065566f756368f6665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") - deserialized, err := message1_1.FromNet(bytes.NewReader(msg)) - require.NoError(t, err) - - deserializedRequest, ok := deserialized.(datatransfer.Request) - require.True(t, ok) - require.Equal(t, deserializedRequest.TransferID(), req.TransferID()) - require.Equal(t, deserializedRequest.IsCancel(), req.IsCancel()) - require.Equal(t, deserializedRequest.IsRequest(), req.IsRequest()) - require.Equal(t, deserializedRequest.IsUpdate(), req.IsUpdate()) - }) -} - -func TestRequestUpdate(t *testing.T) { - t.Run("round-trip", func(t *testing.T) { - id := datatransfer.TransferID(rand.Int31()) - req := message1_1.UpdateRequest(id, true) - require.Equal(t, req.TransferID(), id) - require.True(t, req.IsRequest()) - require.False(t, req.IsCancel()) - require.True(t, req.IsUpdate()) - require.True(t, req.IsPaused()) - - wbuf := new(bytes.Buffer) - require.NoError(t, req.ToNet(wbuf)) - - deserialized, err := message1_1.FromNet(wbuf) - require.NoError(t, err) - - deserializedRequest, ok := deserialized.(datatransfer.Request) - require.True(t, ok) - require.Equal(t, deserializedRequest.TransferID(), req.TransferID()) - require.Equal(t, deserializedRequest.IsCancel(), req.IsCancel()) - require.Equal(t, deserializedRequest.IsRequest(), req.IsRequest()) - require.Equal(t, deserializedRequest.IsUpdate(), req.IsUpdate()) - require.Equal(t, deserializedRequest.IsPaused(), req.IsPaused()) - }) - t.Run("ipld-prime compat", func(t *testing.T) { - id := datatransfer.TransferID(1298498081) - req := message1_1.UpdateRequest(id, true) - - msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964f66450617274f46450617573f56450756c6cf46453746f72f664547970650164565479706065566f756368f6665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") - deserialized, err := message1_1.FromNet(bytes.NewReader(msg)) - require.NoError(t, err) - - deserializedRequest, ok := deserialized.(datatransfer.Request) - require.True(t, ok) - require.Equal(t, deserializedRequest.TransferID(), req.TransferID()) - require.Equal(t, deserializedRequest.IsCancel(), req.IsCancel()) - require.Equal(t, deserializedRequest.IsRequest(), req.IsRequest()) - require.Equal(t, deserializedRequest.IsUpdate(), req.IsUpdate()) - require.Equal(t, deserializedRequest.IsPaused(), req.IsPaused()) - }) -} - -func TestUpdateResponse(t *testing.T) { - id := datatransfer.TransferID(rand.Int31()) - response := message1_1.UpdateResponse(id, true) // not accepted - assert.Equal(t, response.TransferID(), id) - assert.False(t, response.Accepted()) - assert.False(t, response.IsNew()) - assert.True(t, response.IsUpdate()) - assert.True(t, response.IsPaused()) - assert.False(t, response.IsRequest()) - - // Sanity check to make sure we can cast to datatransfer.Message - msg, ok := response.(datatransfer.Message) - require.True(t, ok) - - assert.False(t, msg.IsRequest()) - assert.False(t, msg.IsNew()) - assert.True(t, msg.IsUpdate()) - assert.True(t, msg.IsPaused()) - assert.Equal(t, response.TransferID(), msg.TransferID()) -} - -func TestCancelResponse(t *testing.T) { - id := datatransfer.TransferID(rand.Int31()) - response := message1_1.CancelResponse(id) - assert.Equal(t, response.TransferID(), id) - assert.False(t, response.IsNew()) - assert.False(t, 
response.IsUpdate()) - assert.True(t, response.IsCancel()) - assert.False(t, response.IsRequest()) - // Sanity check to make sure we can cast to datatransfer.Message - msg, ok := response.(datatransfer.Message) - require.True(t, ok) - - assert.False(t, msg.IsRequest()) - assert.False(t, msg.IsNew()) - assert.False(t, msg.IsUpdate()) - assert.True(t, msg.IsCancel()) - assert.Equal(t, response.TransferID(), msg.TransferID()) -} - -func TestCompleteResponse(t *testing.T) { - id := datatransfer.TransferID(rand.Int31()) - response, err := message1_1.CompleteResponse(id, true, true, datatransfer.EmptyTypeIdentifier, nil) - require.NoError(t, err) - assert.Equal(t, response.TransferID(), id) - assert.False(t, response.IsNew()) - assert.False(t, response.IsUpdate()) - assert.True(t, response.IsPaused()) - assert.True(t, response.IsVoucherResult()) - assert.True(t, response.EmptyVoucherResult()) - assert.True(t, response.IsComplete()) - assert.False(t, response.IsRequest()) - // Sanity check to make sure we can cast to datatransfer.Message - msg, ok := response.(datatransfer.Message) - require.True(t, ok) - - assert.False(t, msg.IsRequest()) - assert.False(t, msg.IsNew()) - assert.False(t, msg.IsUpdate()) - assert.Equal(t, response.TransferID(), msg.TransferID()) -} -func TestToNetFromNetEquivalency(t *testing.T) { - t.Run("round-trip", func(t *testing.T) { - baseCid := testutil.GenerateCids(1)[0] - selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() - isPull := false - id := datatransfer.TransferID(rand.Int31()) - accepted := false - voucher := testutil.NewFakeDTType() - voucherResult := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) - require.NoError(t, err) - buf := new(bytes.Buffer) - err = request.ToNet(buf) - require.NoError(t, err) - require.Greater(t, buf.Len(), 0) - deserialized, err := message1_1.FromNet(buf) - require.NoError(t, err) - - deserializedRequest, ok := deserialized.(datatransfer.Request) - require.True(t, ok) - - require.Equal(t, deserializedRequest.TransferID(), request.TransferID()) - require.Equal(t, deserializedRequest.IsCancel(), request.IsCancel()) - require.Equal(t, deserializedRequest.IsPull(), request.IsPull()) - require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) - require.Equal(t, deserializedRequest.BaseCid(), request.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, request, deserializedRequest) - testutil.AssertEqualSelector(t, request, deserializedRequest) - - response, err := message1_1.NewResponse(id, accepted, false, voucherResult.Type(), voucherResult) - require.NoError(t, err) - err = response.ToNet(buf) - require.NoError(t, err) - deserialized, err = message1_1.FromNet(buf) - require.NoError(t, err) - - deserializedResponse, ok := deserialized.(datatransfer.Response) - require.True(t, ok) - - require.Equal(t, deserializedResponse.TransferID(), response.TransferID()) - require.Equal(t, deserializedResponse.Accepted(), response.Accepted()) - require.Equal(t, deserializedResponse.IsRequest(), response.IsRequest()) - require.Equal(t, deserializedResponse.IsUpdate(), response.IsUpdate()) - require.Equal(t, deserializedResponse.IsPaused(), response.IsPaused()) - testutil.AssertEqualFakeDTVoucherResult(t, response, deserializedResponse) - - request = message1_1.CancelRequest(id) - err = request.ToNet(buf) - require.NoError(t, err) - deserialized, err = message1_1.FromNet(buf) - require.NoError(t, err) - - deserializedRequest, ok = 
deserialized.(datatransfer.Request) - require.True(t, ok) - - require.Equal(t, deserializedRequest.TransferID(), request.TransferID()) - require.Equal(t, deserializedRequest.IsCancel(), request.IsCancel()) - require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) - }) - t.Run("ipld-prime compat", func(t *testing.T) { - baseCid := testutil.GenerateCids(1)[0] - selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() - isPull := false - id := datatransfer.TransferID(1298498081) - accepted := false - voucher := testutil.NewFakeDTType() - voucherResult := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) - require.NoError(t, err) - buf := new(bytes.Buffer) - err = request.ToNet(buf) - require.NoError(t, err) - require.Greater(t, buf.Len(), 0) - msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964d82a58230012204bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a6450617274f46450617573f46450756c6cf46453746f72a1612ea064547970650064565479706a46616b6544545479706565566f756368817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e35665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") - deserialized, err := message1_1.FromNet(bytes.NewReader(msg)) - require.NoError(t, err) - - deserializedRequest, ok := deserialized.(datatransfer.Request) - require.True(t, ok) - - require.Equal(t, deserializedRequest.TransferID(), request.TransferID()) - require.Equal(t, deserializedRequest.IsCancel(), request.IsCancel()) - require.Equal(t, deserializedRequest.IsPull(), request.IsPull()) - require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) - c, _ := cid.Parse("QmTTA2daxGqo5denp6SwLzzkLJm3fuisYEi9CoWsuHpzfb") - assert.Equal(t, c, deserializedRequest.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, request, deserializedRequest) - testutil.AssertEqualSelector(t, request, deserializedRequest) - - response, err := message1_1.NewResponse(id, accepted, false, voucherResult.Type(), voucherResult) - require.NoError(t, err) - err = response.ToNet(buf) - require.NoError(t, err) - msg, _ = hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66441637074f46450617573f464547970650064565265738178644204cb9a1e34c5f08e9b20aa76090e70020bb56c0ca3d3af7296cd1058a5112890fed218488f084d8df9e4835fb54ad045ffd936e3bf7261b0426c51352a097816ed74482bb9084b4a7ed8adc517f3371e0e0434b511625cd1a41792243dccdcfe88094b64565479706a46616b65445454797065665866657249441a4d658221") - deserialized, err = message1_1.FromNet(bytes.NewReader(msg)) - require.NoError(t, err) - - deserializedResponse, ok := deserialized.(datatransfer.Response) - require.True(t, ok) - - require.Equal(t, deserializedResponse.TransferID(), response.TransferID()) - require.Equal(t, deserializedResponse.Accepted(), response.Accepted()) - require.Equal(t, deserializedResponse.IsRequest(), response.IsRequest()) - require.Equal(t, deserializedResponse.IsUpdate(), response.IsUpdate()) - require.Equal(t, deserializedResponse.IsPaused(), response.IsPaused()) - testutil.AssertEqualFakeDTVoucherResult(t, response, deserializedResponse) - - request = message1_1.CancelRequest(id) - err = request.ToNet(buf) - require.NoError(t, err) - msg, _ = 
hex.DecodeString("a36449735271f56752657175657374aa6442436964f66450617274f46450617573f46450756c6cf46453746f72f664547970650264565479706065566f756368f6665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") - deserialized, err = message1_1.FromNet(bytes.NewReader(msg)) - require.NoError(t, err) - - deserializedRequest, ok = deserialized.(datatransfer.Request) - require.True(t, ok) - - require.Equal(t, deserializedRequest.TransferID(), request.TransferID()) - require.Equal(t, deserializedRequest.IsCancel(), request.IsCancel()) - require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) - }) -} - -func TestFromNetMessageValidation(t *testing.T) { - // craft request message with nil request struct - buf := []byte{0x83, 0xf5, 0xf6, 0xf6} - msg, err := message1_1.FromNet(bytes.NewBuffer(buf)) - assert.Error(t, err) - assert.Nil(t, msg) - - // craft response message with nil response struct - buf = []byte{0x83, 0xf4, 0xf6, 0xf6} - msg, err = message1_1.FromNet(bytes.NewBuffer(buf)) - assert.Error(t, err) - assert.Nil(t, msg) -} - -func NewTestTransferRequest() (message1_1.TransferRequest1_1, error) { - bcid := testutil.GenerateCids(1)[0] - selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() - isPull := false - id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - req, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, bcid, selector) - if err != nil { - return message1_1.TransferRequest1_1{}, err - } - tr, ok := req.(*message1_1.TransferRequest1_1) - if !ok { - return message1_1.TransferRequest1_1{}, fmt.Errorf("expected *TransferRequest1_1") - } - return *tr, nil -} diff --git a/message/message1_1/transfer_message.go b/message/message1_1/transfer_message.go deleted file mode 100644 index f8866fd6..00000000 --- a/message/message1_1/transfer_message.go +++ /dev/null @@ -1,59 +0,0 @@ -package message1_1 - -import ( - "bytes" - "io" - - "github.com/ipld/go-ipld-prime/codec/dagcbor" - "github.com/ipld/go-ipld-prime/datamodel" - basicnode "github.com/ipld/go-ipld-prime/node/basic" - - datatransfer "github.com/filecoin-project/go-data-transfer" -) - -//go:generate cbor-gen-for --map-encoding TransferMessage1_1 - -// transferMessage1_1 is the transfer message for the 1.1 Data Transfer Protocol. -type TransferMessage1_1 struct { - IsRq bool - - Request *TransferRequest1_1 - Response *TransferResponse1_1 -} - -// ========= datatransfer.Message interface - -// IsRequest returns true if this message is a data request -func (tm *TransferMessage1_1) IsRequest() bool { - return tm.IsRq -} - -// TransferID returns the TransferID of this message -func (tm *TransferMessage1_1) TransferID() datatransfer.TransferID { - if tm.IsRequest() { - return tm.Request.TransferID() - } - return tm.Response.TransferID() -} - -// ToNet serializes a transfer message type. It is simply a wrapper for MarshalCBOR, to provide -// symmetry with FromNet -func (tm *TransferMessage1_1) ToIPLD() (datamodel.Node, error) { - buf := new(bytes.Buffer) - err := tm.ToNet(buf) - if err != nil { - return nil, err - } - nb := basicnode.Prototype.Any.NewBuilder() - err = dagcbor.Decode(nb, buf) - if err != nil { - return nil, err - } - return nb.Build(), nil -} - -// ToNet serializes a transfer message type. 
It is simply a wrapper for MarshalCBOR, to provide -// symmetry with FromNet -func (tm *TransferMessage1_1) ToNet(w io.Writer) error { - return tm.MarshalCBOR(w) -} diff --git a/message/message1_1/transfer_message_cbor_gen.go b/message/message1_1/transfer_message_cbor_gen.go deleted file mode 100644 index b7b25e8e..00000000 --- a/message/message1_1/transfer_message_cbor_gen.go +++ /dev/null @@ -1,179 +0,0 @@ -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. - -package message1_1 - -import ( - "fmt" - "io" - "sort" - - cid "github.com/ipfs/go-cid" - cbg "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" -) - -var _ = xerrors.Errorf -var _ = cid.Undef -var _ = sort.Sort - -func (t *TransferMessage1_1) MarshalCBOR(w io.Writer) error { - if t == nil { - _, err := w.Write(cbg.CborNull) - return err - } - if _, err := w.Write([]byte{163}); err != nil { - return err - } - - scratch := make([]byte, 9) - - // t.IsRq (bool) (bool) - if len("IsRq") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"IsRq\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("IsRq"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("IsRq")); err != nil { - return err - } - - if err := cbg.WriteBool(w, t.IsRq); err != nil { - return err - } - - // t.Request (message1_1.TransferRequest1_1) (struct) - if len("Request") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Request\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Request"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Request")); err != nil { - return err - } - - if err := t.Request.MarshalCBOR(w); err != nil { - return err - } - - // t.Response (message1_1.TransferResponse1_1) (struct) - if len("Response") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Response\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Response"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Response")); err != nil { - return err - } - - if err := t.Response.MarshalCBOR(w); err != nil { - return err - } - return nil -} - -func (t *TransferMessage1_1) UnmarshalCBOR(r io.Reader) error { - *t = TransferMessage1_1{} - - br := cbg.GetPeeker(r) - scratch := make([]byte, 8) - - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajMap { - return fmt.Errorf("cbor input should be of type map") - } - - if extra > cbg.MaxLength { - return fmt.Errorf("TransferMessage1_1: map struct too large (%d)", extra) - } - - var name string - n := extra - - for i := uint64(0); i < n; i++ { - - { - sval, err := cbg.ReadStringBuf(br, scratch) - if err != nil { - return err - } - - name = string(sval) - } - - switch name { - // t.IsRq (bool) (bool) - case "IsRq": - - maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajOther { - return fmt.Errorf("booleans must be major type 7") - } - switch extra { - case 20: - t.IsRq = false - case 21: - t.IsRq = true - default: - return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) - } - // t.Request (message1_1.TransferRequest1_1) (struct) - case "Request": - - { - - b, err := br.ReadByte() - if err != nil { - return err - } - if b != cbg.CborNull[0] { - if err := br.UnreadByte(); err != nil { - return err - } - t.Request = 
new(TransferRequest1_1) - if err := t.Request.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("unmarshaling t.Request pointer: %w", err) - } - } - - } - // t.Response (message1_1.TransferResponse1_1) (struct) - case "Response": - - { - - b, err := br.ReadByte() - if err != nil { - return err - } - if b != cbg.CborNull[0] { - if err := br.UnreadByte(); err != nil { - return err - } - t.Response = new(TransferResponse1_1) - if err := t.Response.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("unmarshaling t.Response pointer: %w", err) - } - } - - } - - default: - // Field doesn't exist on this type, so ignore it - cbg.ScanForLinks(r, func(cid.Cid) {}) - } - } - - return nil -} diff --git a/message/message1_1/transfer_request.go b/message/message1_1/transfer_request.go deleted file mode 100644 index a2aac438..00000000 --- a/message/message1_1/transfer_request.go +++ /dev/null @@ -1,166 +0,0 @@ -package message1_1 - -import ( - "bytes" - "io" - - "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/codec/dagcbor" - "github.com/ipld/go-ipld-prime/datamodel" - basicnode "github.com/ipld/go-ipld-prime/node/basic" - "github.com/libp2p/go-libp2p-core/protocol" - cbg "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" - - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/message/types" -) - -//go:generate cbor-gen-for --map-encoding TransferRequest1_1 - -// TransferRequest1_1 is a struct for the 1.1 Data Transfer Protocol that fulfills the datatransfer.Request interface. -// its members are exported to be used by cbor-gen -type TransferRequest1_1 struct { - BCid *cid.Cid - Type uint64 - Paus bool - Part bool - Pull bool - Stor *cbg.Deferred - Vouch *cbg.Deferred - VTyp datatransfer.TypeIdentifier - XferID uint64 - - RestartChannel datatransfer.ChannelID -} - -func (trq *TransferRequest1_1) MessageForProtocol(targetProtocol protocol.ID) (datatransfer.Message, error) { - switch targetProtocol { - case datatransfer.ProtocolDataTransfer1_2: - return trq, nil - default: - return nil, xerrors.Errorf("protocol not supported") - } -} - -// IsRequest always returns true in this case because this is a transfer request -func (trq *TransferRequest1_1) IsRequest() bool { - return true -} - -func (trq *TransferRequest1_1) IsRestart() bool { - return trq.Type == uint64(types.RestartMessage) -} - -func (trq *TransferRequest1_1) IsRestartExistingChannelRequest() bool { - return trq.Type == uint64(types.RestartExistingChannelRequestMessage) -} - -func (trq *TransferRequest1_1) RestartChannelId() (datatransfer.ChannelID, error) { - if !trq.IsRestartExistingChannelRequest() { - return datatransfer.ChannelID{}, xerrors.New("not a restart request") - } - return trq.RestartChannel, nil -} - -func (trq *TransferRequest1_1) IsNew() bool { - return trq.Type == uint64(types.NewMessage) -} - -func (trq *TransferRequest1_1) IsUpdate() bool { - return trq.Type == uint64(types.UpdateMessage) -} - -func (trq *TransferRequest1_1) IsVoucher() bool { - return trq.Type == uint64(types.VoucherMessage) || trq.Type == uint64(types.NewMessage) -} - -func (trq *TransferRequest1_1) IsPaused() bool { - return trq.Paus -} - -func (trq *TransferRequest1_1) TransferID() datatransfer.TransferID { - return datatransfer.TransferID(trq.XferID) -} - -// ========= datatransfer.Request interface -// IsPull returns true if this is a data pull request -func (trq 
*TransferRequest1_1) IsPull() bool { - return trq.Pull -} - -// VoucherType returns the Voucher ID -func (trq *TransferRequest1_1) VoucherType() datatransfer.TypeIdentifier { - return trq.VTyp -} - -// Voucher returns the Voucher bytes -func (trq *TransferRequest1_1) Voucher(decoder encoding.Decoder) (encoding.Encodable, error) { - if trq.Vouch == nil { - return nil, xerrors.New("No voucher present to read") - } - return decoder.DecodeFromCbor(trq.Vouch.Raw) -} - -func (trq *TransferRequest1_1) EmptyVoucher() bool { - return trq.VTyp == datatransfer.EmptyTypeIdentifier -} - -// BaseCid returns the Base CID -func (trq *TransferRequest1_1) BaseCid() cid.Cid { - if trq.BCid == nil { - return cid.Undef - } - return *trq.BCid -} - -// Selector returns the message Selector bytes -func (trq *TransferRequest1_1) Selector() (ipld.Node, error) { - if trq.Stor == nil { - return nil, xerrors.New("No selector present to read") - } - builder := basicnode.Prototype.Any.NewBuilder() - reader := bytes.NewReader(trq.Stor.Raw) - err := dagcbor.Decode(builder, reader) - if err != nil { - return nil, xerrors.Errorf("Error decoding selector: %w", err) - } - return builder.Build(), nil -} - -// IsCancel returns true if this is a cancel request -func (trq *TransferRequest1_1) IsCancel() bool { - return trq.Type == uint64(types.CancelMessage) -} - -// IsPartial returns true if this is a partial request -func (trq *TransferRequest1_1) IsPartial() bool { - return trq.Part -} - -func (trq *TransferRequest1_1) ToIPLD() (datamodel.Node, error) { - buf := new(bytes.Buffer) - err := trq.ToNet(buf) - if err != nil { - return nil, err - } - nb := basicnode.Prototype.Any.NewBuilder() - err = dagcbor.Decode(nb, buf) - if err != nil { - return nil, err - } - return nb.Build(), nil -} - -// ToNet serializes a transfer request. It's a wrapper for MarshalCBOR to provide -// symmetry with FromNet -func (trq *TransferRequest1_1) ToNet(w io.Writer) error { - msg := TransferMessage1_1{ - IsRq: true, - Request: trq, - Response: nil, - } - return msg.MarshalCBOR(w) -} diff --git a/message/message1_1/transfer_request_cbor_gen.go b/message/message1_1/transfer_request_cbor_gen.go deleted file mode 100644 index 16873836..00000000 --- a/message/message1_1/transfer_request_cbor_gen.go +++ /dev/null @@ -1,397 +0,0 @@ -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. 
- -package message1_1 - -import ( - "fmt" - "io" - "sort" - - datatransfer "github.com/filecoin-project/go-data-transfer" - cid "github.com/ipfs/go-cid" - cbg "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" -) - -var _ = xerrors.Errorf -var _ = cid.Undef -var _ = sort.Sort - -func (t *TransferRequest1_1) MarshalCBOR(w io.Writer) error { - if t == nil { - _, err := w.Write(cbg.CborNull) - return err - } - if _, err := w.Write([]byte{170}); err != nil { - return err - } - - scratch := make([]byte, 9) - - // t.BCid (cid.Cid) (struct) - if len("BCid") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"BCid\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("BCid"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("BCid")); err != nil { - return err - } - - if t.BCid == nil { - if _, err := w.Write(cbg.CborNull); err != nil { - return err - } - } else { - if err := cbg.WriteCidBuf(scratch, w, *t.BCid); err != nil { - return xerrors.Errorf("failed to write cid field t.BCid: %w", err) - } - } - - // t.Type (uint64) (uint64) - if len("Type") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Type\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Type"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Type")); err != nil { - return err - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Type)); err != nil { - return err - } - - // t.Paus (bool) (bool) - if len("Paus") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Paus\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Paus"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Paus")); err != nil { - return err - } - - if err := cbg.WriteBool(w, t.Paus); err != nil { - return err - } - - // t.Part (bool) (bool) - if len("Part") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Part\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Part"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Part")); err != nil { - return err - } - - if err := cbg.WriteBool(w, t.Part); err != nil { - return err - } - - // t.Pull (bool) (bool) - if len("Pull") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Pull\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Pull"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Pull")); err != nil { - return err - } - - if err := cbg.WriteBool(w, t.Pull); err != nil { - return err - } - - // t.Stor (typegen.Deferred) (struct) - if len("Stor") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Stor\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Stor"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Stor")); err != nil { - return err - } - - if err := t.Stor.MarshalCBOR(w); err != nil { - return err - } - - // t.Vouch (typegen.Deferred) (struct) - if len("Vouch") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Vouch\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Vouch"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Vouch")); err != nil { - return err 
- } - - if err := t.Vouch.MarshalCBOR(w); err != nil { - return err - } - - // t.VTyp (datatransfer.TypeIdentifier) (string) - if len("VTyp") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"VTyp\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("VTyp"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("VTyp")); err != nil { - return err - } - - if len(t.VTyp) > cbg.MaxLength { - return xerrors.Errorf("Value in field t.VTyp was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.VTyp))); err != nil { - return err - } - if _, err := io.WriteString(w, string(t.VTyp)); err != nil { - return err - } - - // t.XferID (uint64) (uint64) - if len("XferID") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"XferID\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("XferID"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("XferID")); err != nil { - return err - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.XferID)); err != nil { - return err - } - - // t.RestartChannel (datatransfer.ChannelID) (struct) - if len("RestartChannel") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"RestartChannel\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("RestartChannel"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("RestartChannel")); err != nil { - return err - } - - if err := t.RestartChannel.MarshalCBOR(w); err != nil { - return err - } - return nil -} - -func (t *TransferRequest1_1) UnmarshalCBOR(r io.Reader) error { - *t = TransferRequest1_1{} - - br := cbg.GetPeeker(r) - scratch := make([]byte, 8) - - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajMap { - return fmt.Errorf("cbor input should be of type map") - } - - if extra > cbg.MaxLength { - return fmt.Errorf("TransferRequest1_1: map struct too large (%d)", extra) - } - - var name string - n := extra - - for i := uint64(0); i < n; i++ { - - { - sval, err := cbg.ReadStringBuf(br, scratch) - if err != nil { - return err - } - - name = string(sval) - } - - switch name { - // t.BCid (cid.Cid) (struct) - case "BCid": - - { - - b, err := br.ReadByte() - if err != nil { - return err - } - if b != cbg.CborNull[0] { - if err := br.UnreadByte(); err != nil { - return err - } - - c, err := cbg.ReadCid(br) - if err != nil { - return xerrors.Errorf("failed to read cid field t.BCid: %w", err) - } - - t.BCid = &c - } - - } - // t.Type (uint64) (uint64) - case "Type": - - { - - maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") - } - t.Type = uint64(extra) - - } - // t.Paus (bool) (bool) - case "Paus": - - maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajOther { - return fmt.Errorf("booleans must be major type 7") - } - switch extra { - case 20: - t.Paus = false - case 21: - t.Paus = true - default: - return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) - } - // t.Part (bool) (bool) - case "Part": - - maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajOther { - return fmt.Errorf("booleans must 
be major type 7") - } - switch extra { - case 20: - t.Part = false - case 21: - t.Part = true - default: - return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) - } - // t.Pull (bool) (bool) - case "Pull": - - maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajOther { - return fmt.Errorf("booleans must be major type 7") - } - switch extra { - case 20: - t.Pull = false - case 21: - t.Pull = true - default: - return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) - } - // t.Stor (typegen.Deferred) (struct) - case "Stor": - - { - - t.Stor = new(cbg.Deferred) - - if err := t.Stor.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("failed to read deferred field: %w", err) - } - } - // t.Vouch (typegen.Deferred) (struct) - case "Vouch": - - { - - t.Vouch = new(cbg.Deferred) - - if err := t.Vouch.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("failed to read deferred field: %w", err) - } - } - // t.VTyp (datatransfer.TypeIdentifier) (string) - case "VTyp": - - { - sval, err := cbg.ReadStringBuf(br, scratch) - if err != nil { - return err - } - - t.VTyp = datatransfer.TypeIdentifier(sval) - } - // t.XferID (uint64) (uint64) - case "XferID": - - { - - maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") - } - t.XferID = uint64(extra) - - } - // t.RestartChannel (datatransfer.ChannelID) (struct) - case "RestartChannel": - - { - - if err := t.RestartChannel.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("unmarshaling t.RestartChannel: %w", err) - } - - } - - default: - // Field doesn't exist on this type, so ignore it - cbg.ScanForLinks(r, func(cid.Cid) {}) - } - } - - return nil -} diff --git a/message/message1_1/transfer_request_test.go b/message/message1_1/transfer_request_test.go deleted file mode 100644 index da0bccac..00000000 --- a/message/message1_1/transfer_request_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package message1_1_test - -import ( - "math/rand" - "testing" - - basicnode "github.com/ipld/go-ipld-prime/node/basic" - "github.com/ipld/go-ipld-prime/traversal/selector/builder" - "github.com/stretchr/testify/require" - - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/message/message1_1" - "github.com/filecoin-project/go-data-transfer/testutil" -) - -func TestRequestMessageForProtocol(t *testing.T) { - baseCid := testutil.GenerateCids(1)[0] - selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() - isPull := true - id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - - // for the new protocols - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) - require.NoError(t, err) - - out12, err := request.MessageForProtocol(datatransfer.ProtocolDataTransfer1_2) - require.NoError(t, err) - require.Equal(t, request, out12) - - req, ok := out12.(datatransfer.Request) - require.True(t, ok) - require.False(t, req.IsRestart()) - require.False(t, req.IsRestartExistingChannelRequest()) - require.Equal(t, baseCid, req.BaseCid()) - require.True(t, req.IsPull()) - n, err := req.Selector() - require.NoError(t, err) - require.Equal(t, selector, n) - require.Equal(t, voucher.Type(), req.VoucherType()) -} diff --git a/message/message1_1/transfer_response.go 
b/message/message1_1/transfer_response.go deleted file mode 100644 index 83128feb..00000000 --- a/message/message1_1/transfer_response.go +++ /dev/null @@ -1,127 +0,0 @@ -package message1_1 - -import ( - "bytes" - "io" - - "github.com/ipld/go-ipld-prime/codec/dagcbor" - "github.com/ipld/go-ipld-prime/datamodel" - basicnode "github.com/ipld/go-ipld-prime/node/basic" - "github.com/libp2p/go-libp2p-core/protocol" - cbg "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" - - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/message/types" -) - -//go:generate cbor-gen-for --map-encoding TransferResponse1_1 - -// TransferResponse1_1 is a private struct that satisfies the datatransfer.Response interface -// It is the response message for the Data Transfer 1.1 and 1.2 Protocol. -type TransferResponse1_1 struct { - Type uint64 - Acpt bool - Paus bool - XferID uint64 - VRes *cbg.Deferred - VTyp datatransfer.TypeIdentifier -} - -func (trsp *TransferResponse1_1) TransferID() datatransfer.TransferID { - return datatransfer.TransferID(trsp.XferID) -} - -// IsRequest always returns false in this case because this is a transfer response -func (trsp *TransferResponse1_1) IsRequest() bool { - return false -} - -// IsNew returns true if this is the first response sent -func (trsp *TransferResponse1_1) IsNew() bool { - return trsp.Type == uint64(types.NewMessage) -} - -// IsUpdate returns true if this response is an update -func (trsp *TransferResponse1_1) IsUpdate() bool { - return trsp.Type == uint64(types.UpdateMessage) -} - -// IsPaused returns true if the responder is paused -func (trsp *TransferResponse1_1) IsPaused() bool { - return trsp.Paus -} - -// IsCancel returns true if the responder has cancelled this response -func (trsp *TransferResponse1_1) IsCancel() bool { - return trsp.Type == uint64(types.CancelMessage) -} - -// IsComplete returns true if the responder has completed this response -func (trsp *TransferResponse1_1) IsComplete() bool { - return trsp.Type == uint64(types.CompleteMessage) -} - -func (trsp *TransferResponse1_1) IsVoucherResult() bool { - return trsp.Type == uint64(types.VoucherResultMessage) || trsp.Type == uint64(types.NewMessage) || trsp.Type == uint64(types.CompleteMessage) || - trsp.Type == uint64(types.RestartMessage) -} - -// Accepted returns true if the request is accepted in the response -func (trsp *TransferResponse1_1) Accepted() bool { - return trsp.Acpt -} - -func (trsp *TransferResponse1_1) VoucherResultType() datatransfer.TypeIdentifier { - return trsp.VTyp -} - -func (trsp *TransferResponse1_1) VoucherResult(decoder encoding.Decoder) (encoding.Encodable, error) { - if trsp.VRes == nil { - return nil, xerrors.New("No voucher present to read") - } - return decoder.DecodeFromCbor(trsp.VRes.Raw) -} - -func (trq *TransferResponse1_1) IsRestart() bool { - return trq.Type == uint64(types.RestartMessage) -} - -func (trsp *TransferResponse1_1) EmptyVoucherResult() bool { - return trsp.VTyp == datatransfer.EmptyTypeIdentifier -} - -func (trsp *TransferResponse1_1) MessageForProtocol(targetProtocol protocol.ID) (datatransfer.Message, error) { - switch targetProtocol { - case datatransfer.ProtocolDataTransfer1_2: - return trsp, nil - default: - return nil, xerrors.Errorf("protocol %s not supported", targetProtocol) - } -} - -func (trsp *TransferResponse1_1) ToIPLD() (datamodel.Node, error) { - buf := new(bytes.Buffer) - err := trsp.ToNet(buf) - 
if err != nil { - return nil, err - } - nb := basicnode.Prototype.Any.NewBuilder() - err = dagcbor.Decode(nb, buf) - if err != nil { - return nil, err - } - return nb.Build(), nil -} - -// ToNet serializes a transfer response. It's a wrapper for MarshalCBOR to provide -// symmetry with FromNet -func (trsp *TransferResponse1_1) ToNet(w io.Writer) error { - msg := TransferMessage1_1{ - IsRq: false, - Request: nil, - Response: trsp, - } - return msg.MarshalCBOR(w) -} diff --git a/message/message1_1/transfer_response_cbor_gen.go b/message/message1_1/transfer_response_cbor_gen.go deleted file mode 100644 index 29907e2c..00000000 --- a/message/message1_1/transfer_response_cbor_gen.go +++ /dev/null @@ -1,265 +0,0 @@ -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. - -package message1_1 - -import ( - "fmt" - "io" - "sort" - - datatransfer "github.com/filecoin-project/go-data-transfer" - cid "github.com/ipfs/go-cid" - cbg "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" -) - -var _ = xerrors.Errorf -var _ = cid.Undef -var _ = sort.Sort - -func (t *TransferResponse1_1) MarshalCBOR(w io.Writer) error { - if t == nil { - _, err := w.Write(cbg.CborNull) - return err - } - if _, err := w.Write([]byte{166}); err != nil { - return err - } - - scratch := make([]byte, 9) - - // t.Type (uint64) (uint64) - if len("Type") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Type\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Type"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Type")); err != nil { - return err - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Type)); err != nil { - return err - } - - // t.Acpt (bool) (bool) - if len("Acpt") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Acpt\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Acpt"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Acpt")); err != nil { - return err - } - - if err := cbg.WriteBool(w, t.Acpt); err != nil { - return err - } - - // t.Paus (bool) (bool) - if len("Paus") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"Paus\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Paus"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("Paus")); err != nil { - return err - } - - if err := cbg.WriteBool(w, t.Paus); err != nil { - return err - } - - // t.XferID (uint64) (uint64) - if len("XferID") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"XferID\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("XferID"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("XferID")); err != nil { - return err - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.XferID)); err != nil { - return err - } - - // t.VRes (typegen.Deferred) (struct) - if len("VRes") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"VRes\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("VRes"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("VRes")); err != nil { - return err - } - - if err := t.VRes.MarshalCBOR(w); err != nil { - return err - } - - // t.VTyp (datatransfer.TypeIdentifier) (string) - if len("VTyp") > 
cbg.MaxLength { - return xerrors.Errorf("Value in field \"VTyp\" was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("VTyp"))); err != nil { - return err - } - if _, err := io.WriteString(w, string("VTyp")); err != nil { - return err - } - - if len(t.VTyp) > cbg.MaxLength { - return xerrors.Errorf("Value in field t.VTyp was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.VTyp))); err != nil { - return err - } - if _, err := io.WriteString(w, string(t.VTyp)); err != nil { - return err - } - return nil -} - -func (t *TransferResponse1_1) UnmarshalCBOR(r io.Reader) error { - *t = TransferResponse1_1{} - - br := cbg.GetPeeker(r) - scratch := make([]byte, 8) - - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajMap { - return fmt.Errorf("cbor input should be of type map") - } - - if extra > cbg.MaxLength { - return fmt.Errorf("TransferResponse1_1: map struct too large (%d)", extra) - } - - var name string - n := extra - - for i := uint64(0); i < n; i++ { - - { - sval, err := cbg.ReadStringBuf(br, scratch) - if err != nil { - return err - } - - name = string(sval) - } - - switch name { - // t.Type (uint64) (uint64) - case "Type": - - { - - maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") - } - t.Type = uint64(extra) - - } - // t.Acpt (bool) (bool) - case "Acpt": - - maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajOther { - return fmt.Errorf("booleans must be major type 7") - } - switch extra { - case 20: - t.Acpt = false - case 21: - t.Acpt = true - default: - return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) - } - // t.Paus (bool) (bool) - case "Paus": - - maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajOther { - return fmt.Errorf("booleans must be major type 7") - } - switch extra { - case 20: - t.Paus = false - case 21: - t.Paus = true - default: - return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) - } - // t.XferID (uint64) (uint64) - case "XferID": - - { - - maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") - } - t.XferID = uint64(extra) - - } - // t.VRes (typegen.Deferred) (struct) - case "VRes": - - { - - t.VRes = new(cbg.Deferred) - - if err := t.VRes.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("failed to read deferred field: %w", err) - } - } - // t.VTyp (datatransfer.TypeIdentifier) (string) - case "VTyp": - - { - sval, err := cbg.ReadStringBuf(br, scratch) - if err != nil { - return err - } - - t.VTyp = datatransfer.TypeIdentifier(sval) - } - - default: - // Field doesn't exist on this type, so ignore it - cbg.ScanForLinks(r, func(cid.Cid) {}) - } - } - - return nil -} diff --git a/message/message1_1/transfer_response_test.go b/message/message1_1/transfer_response_test.go deleted file mode 100644 index 8e98e668..00000000 --- a/message/message1_1/transfer_response_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package message1_1_test - -import ( - "math/rand" - "testing" - - "github.com/stretchr/testify/require" - - datatransfer "github.com/filecoin-project/go-data-transfer" - 
"github.com/filecoin-project/go-data-transfer/message/message1_1" - "github.com/filecoin-project/go-data-transfer/testutil" -) - -func TestResponseMessageForProtocol(t *testing.T) { - id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, false, true, voucherResult.Type(), voucherResult) // not accepted - require.NoError(t, err) - - // v1.2 protocol - out, err := response.MessageForProtocol(datatransfer.ProtocolDataTransfer1_2) - require.NoError(t, err) - require.Equal(t, response, out) - - resp, ok := (out).(datatransfer.Response) - require.True(t, ok) - require.True(t, resp.IsPaused()) - require.Equal(t, voucherResult.Type(), resp.VoucherResultType()) - require.True(t, resp.IsVoucherResult()) - - // random protocol - out, err = response.MessageForProtocol("RAND") - require.Error(t, err) - require.Nil(t, out) -} diff --git a/message/message1_1prime/message.go b/message/message1_1prime/message.go index 4bad9c4e..a6fccd71 100644 --- a/message/message1_1prime/message.go +++ b/message/message1_1prime/message.go @@ -7,22 +7,23 @@ import ( "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/ipld/go-ipld-prime/datamodel" - "github.com/ipld/go-ipld-prime/node/bindnode" "github.com/ipld/go-ipld-prime/schema" xerrors "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/message/types" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message/types" ) +var emptyTypedVoucher = datatransfer.TypedVoucher{ + Voucher: ipld.Null, + Type: datatransfer.EmptyTypeIdentifier, +} + // NewRequest generates a new request for the data transfer protocol -func NewRequest(id datatransfer.TransferID, isRestart bool, isPull bool, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable, baseCid cid.Cid, selector ipld.Node) (datatransfer.Request, error) { - vnode, err := encoding.EncodeToNode(voucher) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func NewRequest(id datatransfer.TransferID, isRestart bool, isPull bool, voucher *datatransfer.TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (datatransfer.Request, error) { + if voucher == nil { + voucher = &emptyTypedVoucher } - if baseCid == cid.Undef { return nil, xerrors.Errorf("base CID must be defined") } @@ -34,13 +35,17 @@ func NewRequest(id datatransfer.TransferID, isRestart bool, isPull bool, vtype d typ = uint64(types.NewMessage) } + if voucher == nil { + voucher = &emptyTypedVoucher + } + return &TransferRequest1_1{ MessageType: typ, Pull: isPull, - VoucherPtr: &vnode, - SelectorPtr: &selector, + VoucherPtr: voucher.Voucher, + SelectorPtr: selector, BaseCidPtr: &baseCid, - VoucherTypeIdentifier: vtype, + VoucherTypeIdentifier: voucher.Type, TransferId: uint64(id), }, nil } @@ -71,65 +76,86 @@ func UpdateRequest(id datatransfer.TransferID, isPaused bool) datatransfer.Reque } // VoucherRequest generates a new request for the data transfer protocol -func VoucherRequest(id datatransfer.TransferID, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable) (datatransfer.Request, error) { - vnode, err := encoding.EncodeToNode(voucher) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func VoucherRequest(id datatransfer.TransferID, voucher *datatransfer.TypedVoucher) 
datatransfer.Request { + if voucher == nil { + voucher = &emptyTypedVoucher } return &TransferRequest1_1{ MessageType: uint64(types.VoucherMessage), - VoucherPtr: &vnode, - VoucherTypeIdentifier: vtype, + VoucherPtr: voucher.Voucher, + VoucherTypeIdentifier: voucher.Type, TransferId: uint64(id), - }, nil + } } // RestartResponse builds a new Data Transfer response -func RestartResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vnode, err := encoding.EncodeToNode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func RestartResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) datatransfer.Response { + if voucherResult == nil { + voucherResult = &emptyTypedVoucher } return &TransferResponse1_1{ RequestAccepted: accepted, MessageType: uint64(types.RestartMessage), Paused: isPaused, TransferId: uint64(id), - VoucherTypeIdentifier: voucherResultType, - VoucherResultPtr: &vnode, - }, nil + VoucherResultPtr: voucherResult.Voucher, + VoucherTypeIdentifier: voucherResult.Type, + } +} + +// ValidationResultResponse response generates a response based on a validation result +// messageType determines what kind of response is created +func ValidationResultResponse( + messageType types.MessageType, + id datatransfer.TransferID, + validationResult datatransfer.ValidationResult, + validationErr error, + paused bool) datatransfer.Response { + + voucherResult := &emptyTypedVoucher + if validationResult.VoucherResult != nil { + voucherResult = validationResult.VoucherResult + } + return &TransferResponse1_1{ + // TODO: when we area able to change the protocol, it would be helpful to record + // Validation errors vs rejections + RequestAccepted: validationErr == nil && validationResult.Accepted, + MessageType: uint64(messageType), + Paused: paused, + TransferId: uint64(id), + VoucherTypeIdentifier: voucherResult.Type, + VoucherResultPtr: voucherResult.Voucher, + } } // NewResponse builds a new Data Transfer response -func NewResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vnode, err := encoding.EncodeToNode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func NewResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) datatransfer.Response { + if voucherResult == nil { + voucherResult = &emptyTypedVoucher } return &TransferResponse1_1{ RequestAccepted: accepted, MessageType: uint64(types.NewMessage), Paused: isPaused, TransferId: uint64(id), - VoucherTypeIdentifier: voucherResultType, - VoucherResultPtr: &vnode, - }, nil + VoucherTypeIdentifier: voucherResult.Type, + VoucherResultPtr: voucherResult.Voucher, + } } // VoucherResultResponse builds a new response for a voucher result -func VoucherResultResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vnode, err := encoding.EncodeToNode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func VoucherResultResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) datatransfer.Response { + if 
voucherResult == nil { + voucherResult = &emptyTypedVoucher } return &TransferResponse1_1{ RequestAccepted: accepted, MessageType: uint64(types.VoucherResultMessage), Paused: isPaused, TransferId: uint64(id), - VoucherTypeIdentifier: voucherResultType, - VoucherResultPtr: &vnode, - }, nil + VoucherTypeIdentifier: voucherResult.Type, + VoucherResultPtr: voucherResult.Voucher, + } } // UpdateResponse returns a new update response @@ -150,31 +176,32 @@ func CancelResponse(id datatransfer.TransferID) datatransfer.Response { } // CompleteResponse returns a new complete response message -func CompleteResponse(id datatransfer.TransferID, isAccepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vnode, err := encoding.EncodeToNode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func CompleteResponse(id datatransfer.TransferID, isAccepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) datatransfer.Response { + if voucherResult == nil { + voucherResult = &emptyTypedVoucher } return &TransferResponse1_1{ MessageType: uint64(types.CompleteMessage), RequestAccepted: isAccepted, Paused: isPaused, - VoucherTypeIdentifier: voucherResultType, - VoucherResultPtr: &vnode, + VoucherTypeIdentifier: voucherResult.Type, + VoucherResultPtr: voucherResult.Voucher, TransferId: uint64(id), - }, nil + } } // FromNet can read a network stream to deserialize a GraphSyncMessage func FromNet(r io.Reader) (datatransfer.Message, error) { - builder := Prototype.TransferMessage.Representation().NewBuilder() - err := dagcbor.Decode(builder, r) + tm, err := bindnodeRegistry.TypeFromReader(r, &TransferMessage1_1{}, dagcbor.Decode) if err != nil { return nil, err } - node := builder.Build() - tresp := bindnode.Unwrap(node).(*TransferMessage1_1) + tresp := tm.(*TransferMessage1_1) + + return fromMessage(tresp) +} +func fromMessage(tresp *TransferMessage1_1) (datatransfer.Message, error) { if (tresp.IsRequest && tresp.Request == nil) || (!tresp.IsRequest && tresp.Response == nil) { return nil, xerrors.Errorf("invalid/malformed message") } @@ -185,23 +212,45 @@ func FromNet(r io.Reader) (datatransfer.Message, error) { return tresp.Response, nil } +func fromWrappedMessage(wtresp *WrappedTransferMessage1_1) (datatransfer.TransportedMessage, error) { + tresp := wtresp.Message + if (tresp.IsRequest && tresp.Request == nil) || (!tresp.IsRequest && tresp.Response == nil) { + return nil, xerrors.Errorf("invalid/malformed message") + } + + if tresp.IsRequest { + return &WrappedTransferRequest1_1{ + tresp.Request, + wtresp.TransportVersion, + wtresp.TransportID, + }, nil + } + return &WrappedTransferResponse1_1{ + tresp.Response, + wtresp.TransportID, + wtresp.TransportVersion, + }, nil +} + +// FromNetWrraped can read a network stream to deserialize a message + transport ID +func FromNetWrapped(r io.Reader) (datatransfer.TransportedMessage, error) { + tm, err := bindnodeRegistry.TypeFromReader(r, &WrappedTransferMessage1_1{}, dagcbor.Decode) + if err != nil { + return nil, err + } + wtresp := tm.(*WrappedTransferMessage1_1) + return fromWrappedMessage(wtresp) +} + // FromNet can read a network stream to deserialize a GraphSyncMessage func FromIPLD(node datamodel.Node) (datatransfer.Message, error) { if tn, ok := node.(schema.TypedNode); ok { // shouldn't need this if from Graphsync node = tn.Representation() } - builder := Prototype.TransferMessage.Representation().NewBuilder() - err := 
builder.AssignNode(node) + tm, err := bindnodeRegistry.TypeFromNode(node, &TransferMessage1_1{}) if err != nil { return nil, err } - tresp := bindnode.Unwrap(builder.Build()).(*TransferMessage1_1) - if (tresp.IsRequest && tresp.Request == nil) || (!tresp.IsRequest && tresp.Response == nil) { - return nil, xerrors.Errorf("invalid/malformed message") - } - - if tresp.IsRequest { - return tresp.Request, nil - } - return tresp.Response, nil + tresp := tm.(*TransferMessage1_1) + return fromMessage(tresp) } diff --git a/message/message1_1prime/message_test.go b/message/message1_1prime/message_test.go index 2e13a4d5..b84b39c0 100644 --- a/message/message1_1prime/message_test.go +++ b/message/message1_1prime/message_test.go @@ -8,18 +8,15 @@ import ( "testing" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime/codec/dagcbor" basicnode "github.com/ipld/go-ipld-prime/node/basic" - "github.com/ipld/go-ipld-prime/node/bindnode" "github.com/ipld/go-ipld-prime/traversal/selector/builder" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" - message1_1 "github.com/filecoin-project/go-data-transfer/message/message1_1prime" - "github.com/filecoin-project/go-data-transfer/testutil" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + message1_1 "github.com/filecoin-project/go-data-transfer/v2/message/message1_1prime" + "github.com/filecoin-project/go-data-transfer/v2/testutil" ) func TestNewRequest(t *testing.T) { @@ -27,8 +24,8 @@ func TestNewRequest(t *testing.T) { selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := true id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + voucher := testutil.NewTestTypedVoucher() + request, err := message1_1.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) assert.Equal(t, id, request.TransferID()) assert.False(t, request.IsCancel()) @@ -36,8 +33,7 @@ func TestNewRequest(t *testing.T) { assert.True(t, request.IsPull()) assert.True(t, request.IsRequest()) assert.Equal(t, baseCid.String(), request.BaseCid().String()) - encoding.NewDecoder(request) - testutil.AssertFakeDTVoucher(t, request, voucher) + testutil.AssertTestVoucher(t, request, voucher) receivedSelector, err := request.Selector() require.NoError(t, err) require.Equal(t, selector, receivedSelector) @@ -56,8 +52,8 @@ func TestRestartRequest(t *testing.T) { selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := true id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, true, isPull, voucher.Type(), voucher, baseCid, selector) + voucher := testutil.NewTestTypedVoucher() + request, err := message1_1.NewRequest(id, true, isPull, &voucher, baseCid, selector) require.NoError(t, err) assert.Equal(t, id, request.TransferID()) assert.False(t, request.IsCancel()) @@ -65,7 +61,7 @@ func TestRestartRequest(t *testing.T) { assert.True(t, request.IsPull()) assert.True(t, request.IsRequest()) assert.Equal(t, baseCid.String(), request.BaseCid().String()) - testutil.AssertFakeDTVoucher(t, request, voucher) + testutil.AssertTestVoucher(t, request, voucher) receivedSelector, err := request.Selector() 
require.NoError(t, err) require.Equal(t, selector, receivedSelector) @@ -115,19 +111,9 @@ func TestRestartExistingChannelRequest(t *testing.T) { }) } -func TestTransferRequest_MarshalCBOR(t *testing.T) { - // sanity check MarshalCBOR does its thing w/o error - req, err := NewTestTransferRequest() - require.NoError(t, err) - wbuf := new(bytes.Buffer) - node := bindnode.Wrap(&req, message1_1.Prototype.TransferRequest.Type()) - err = dagcbor.Encode(node.Representation(), wbuf) - require.NoError(t, err) - assert.Greater(t, wbuf.Len(), 0) -} func TestTransferRequest_UnmarshalCBOR(t *testing.T) { t.Run("round-trip", func(t *testing.T) { - req, err := NewTestTransferRequest() + req, err := NewTestTransferRequest("test data here") require.NoError(t, err) wbuf := new(bytes.Buffer) // use ToNet / FromNet @@ -144,14 +130,15 @@ func TestTransferRequest_UnmarshalCBOR(t *testing.T) { assert.Equal(t, req.IsPull(), desReq.IsPull()) assert.Equal(t, req.IsCancel(), desReq.IsCancel()) assert.Equal(t, req.BaseCid(), desReq.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, &req, desReq) + testutil.AssertEqualTestVoucher(t, &req, desReq) testutil.AssertEqualSelector(t, &req, desReq) }) t.Run("cbor-gen compat", func(t *testing.T) { - req, err := NewTestTransferRequest() + vouchByts, _ := hex.DecodeString("f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e35") + req, err := NewTestTransferRequest(string(vouchByts)) require.NoError(t, err) - msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964d82a58230012204bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a6454797065006450617573f46450617274f46450756c6cf46453746f72a1612ea065566f756368817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706a46616b65445454797065665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") + msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964d82a58230012204bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a6454797065006450617573f46450617274f46450756c6cf46453746f72a1612ea065566f756368817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706b54657374566f7563686572665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") desMsg, err := message1_1.FromNet(bytes.NewReader(msg)) require.NoError(t, err) @@ -164,23 +151,22 @@ func TestTransferRequest_UnmarshalCBOR(t *testing.T) { assert.Equal(t, req.IsCancel(), desReq.IsCancel()) c, _ := cid.Parse("QmTTA2daxGqo5denp6SwLzzkLJm3fuisYEi9CoWsuHpzfb") assert.Equal(t, c, desReq.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, &req, desReq) + testutil.AssertEqualTestVoucher(t, &req, desReq) testutil.AssertEqualSelector(t, &req, desReq) }) } func TestResponses(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, false, true, voucherResult.Type(), voucherResult) // not accepted - require.NoError(t, err) + voucherResult := testutil.NewTestTypedVoucher() + response := message1_1.NewResponse(id, false, true, &voucherResult) // not 
accepted assert.Equal(t, response.TransferID(), id) assert.False(t, response.Accepted()) assert.True(t, response.IsNew()) assert.False(t, response.IsUpdate()) assert.True(t, response.IsPaused()) assert.False(t, response.IsRequest()) - testutil.AssertFakeDTVoucherResult(t, response, voucherResult) + testutil.AssertTestVoucherResult(t, response, voucherResult) // Sanity check to make sure we can cast to datatransfer.Message msg, ok := response.(datatransfer.Message) require.True(t, ok) @@ -194,9 +180,8 @@ func TestResponses(t *testing.T) { func TestTransferResponse_MarshalCBOR(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, true, false, voucherResult.Type(), voucherResult) // accepted - require.NoError(t, err) + voucherResult := testutil.NewTestTypedVoucher() + response := message1_1.NewResponse(id, true, false, &voucherResult) // accepted // sanity check that we can marshal data wbuf := new(bytes.Buffer) @@ -207,9 +192,8 @@ func TestTransferResponse_MarshalCBOR(t *testing.T) { func TestTransferResponse_UnmarshalCBOR(t *testing.T) { t.Run("round-trip", func(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, true, false, voucherResult.Type(), voucherResult) // accepted - require.NoError(t, err) + voucherResult := testutil.NewTestTypedVoucher() + response := message1_1.NewResponse(id, true, false, &voucherResult) // accepted wbuf := new(bytes.Buffer) require.NoError(t, response.ToNet(wbuf)) @@ -229,13 +213,11 @@ func TestTransferResponse_UnmarshalCBOR(t *testing.T) { assert.True(t, desResp.IsNew()) assert.False(t, desResp.IsUpdate()) assert.False(t, desMsg.IsPaused()) - testutil.AssertFakeDTVoucherResult(t, desResp, voucherResult) + testutil.AssertTestVoucherResult(t, desResp, voucherResult) }) t.Run("cbor-gen compat", func(t *testing.T) { - voucherResult := testutil.NewFakeDTType() - voucherResult.Data = "\xf5_\xf8\xf1%\b\xb6>\xf2\xbf\xec\xa7Uz\xe9\r\xf61\x1a^\xc1c\x1bJ\x1f\xa8C1\v\xd9ç\x10\xea\xac塽\xd7*п\xe0Iw\x1c\x11\xe7V3\x8b\xd98e\xe6E\xf1\xad웜\x99\xef@\u007f\xbdOƅ\x9ey\x04ŭ}ɽ\x10\xa5\xcc\x16\x97=[(\xec\x1am\xd4=\x9f\x82\xf9\xf1\x8c=\x03A\x8e5" - - msg, _ := hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66454797065006441637074f56450617573f4665866657249441a4d6582216456526573817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706a46616b65445454797065") + voucherResult := testutil.NewTestTypedVoucherWith("\xf5_\xf8\xf1%\b\xb6>\xf2\xbf\xec\xa7Uz\xe9\r\xf61\x1a^\xc1c\x1bJ\x1f\xa8C1\v\xd9ç\x10\xea\xac塽\xd7*п\xe0Iw\x1c\x11\xe7V3\x8b\xd98e\xe6E\xf1\xad웜\x99\xef@\u007f\xbdOƅ\x9ey\x04ŭ}ɽ\x10\xa5\xcc\x16\x97=[(\xec\x1am\xd4=\x9f\x82\xf9\xf1\x8c=\x03A\x8e5") + msg, _ := hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66454797065006441637074f56450617573f4665866657249441a4d6582216456526573817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706b54657374566f7563686572") desMsg, err := message1_1.FromNet(bytes.NewReader(msg)) require.NoError(t, err) assert.False(t, desMsg.IsRequest()) @@ -250,7 +232,7 @@ func TestTransferResponse_UnmarshalCBOR(t 
*testing.T) { assert.True(t, desResp.IsNew()) assert.False(t, desResp.IsUpdate()) assert.False(t, desMsg.IsPaused()) - testutil.AssertFakeDTVoucherResult(t, desResp, voucherResult) + testutil.AssertTestVoucherResult(t, desResp, voucherResult) }) } @@ -381,13 +363,12 @@ func TestCancelResponse(t *testing.T) { func TestCompleteResponse(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - response, err := message1_1.CompleteResponse(id, true, true, datatransfer.EmptyTypeIdentifier, nil) - require.NoError(t, err) + response := message1_1.CompleteResponse(id, true, true, nil) assert.Equal(t, response.TransferID(), id) assert.False(t, response.IsNew()) assert.False(t, response.IsUpdate()) assert.True(t, response.IsPaused()) - assert.True(t, response.IsVoucherResult()) + assert.True(t, response.IsValidationResult()) assert.True(t, response.EmptyVoucherResult()) assert.True(t, response.IsComplete()) assert.False(t, response.IsRequest()) @@ -407,9 +388,9 @@ func TestToNetFromNetEquivalency(t *testing.T) { isPull := false id := datatransfer.TransferID(rand.Int31()) accepted := false - voucher := testutil.NewFakeDTType() - voucherResult := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + voucher := testutil.NewTestTypedVoucher() + voucherResult := testutil.NewTestTypedVoucher() + request, err := message1_1.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) buf := new(bytes.Buffer) err = request.ToNet(buf) @@ -426,11 +407,10 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedRequest.IsPull(), request.IsPull()) require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) require.Equal(t, deserializedRequest.BaseCid(), request.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, request, deserializedRequest) + testutil.AssertEqualTestVoucher(t, request, deserializedRequest) testutil.AssertEqualSelector(t, request, deserializedRequest) - response, err := message1_1.NewResponse(id, accepted, false, voucherResult.Type(), voucherResult) - require.NoError(t, err) + response := message1_1.NewResponse(id, accepted, false, &voucherResult) err = response.ToNet(buf) require.NoError(t, err) deserialized, err = message1_1.FromNet(buf) @@ -444,7 +424,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedResponse.IsRequest(), response.IsRequest()) require.Equal(t, deserializedResponse.IsUpdate(), response.IsUpdate()) require.Equal(t, deserializedResponse.IsPaused(), response.IsPaused()) - testutil.AssertEqualFakeDTVoucherResult(t, response, deserializedResponse) + testutil.AssertEqualTestVoucherResult(t, response, deserializedResponse) request = message1_1.CancelRequest(id) err = request.ToNet(buf) @@ -465,15 +445,17 @@ func TestToNetFromNetEquivalency(t *testing.T) { isPull := false id := datatransfer.TransferID(1298498081) accepted := false - voucher := testutil.NewFakeDTType() - voucherResult := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + vouchByts, _ := hex.DecodeString("f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e35") + voucher := testutil.NewTestTypedVoucherWith(string(vouchByts)) + vouchResultByts, _ := 
hex.DecodeString("4204cb9a1e34c5f08e9b20aa76090e70020bb56c0ca3d3af7296cd1058a5112890fed218488f084d8df9e4835fb54ad045ffd936e3bf7261b0426c51352a097816ed74482bb9084b4a7ed8adc517f3371e0e0434b511625cd1a41792243dccdcfe88094b") + voucherResult := testutil.NewTestTypedVoucherWith(string(vouchResultByts)) + request, err := message1_1.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) buf := new(bytes.Buffer) err = request.ToNet(buf) require.NoError(t, err) require.Greater(t, buf.Len(), 0) - msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964d82a58230012204bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a6454797065006450617573f46450617274f46450756c6cf46453746f72a1612ea065566f756368817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706a46616b65445454797065665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") + msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964d82a58230012204bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a6454797065006450617573f46450617274f46450756c6cf46453746f72a1612ea065566f756368817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706b54657374566f7563686572665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") deserialized, err := message1_1.FromNet(bytes.NewReader(msg)) require.NoError(t, err) @@ -486,14 +468,13 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) c, _ := cid.Parse("QmTTA2daxGqo5denp6SwLzzkLJm3fuisYEi9CoWsuHpzfb") assert.Equal(t, c, deserializedRequest.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, request, deserializedRequest) + testutil.AssertEqualTestVoucher(t, request, deserializedRequest) testutil.AssertEqualSelector(t, request, deserializedRequest) - response, err := message1_1.NewResponse(id, accepted, false, voucherResult.Type(), voucherResult) - require.NoError(t, err) + response := message1_1.NewResponse(id, accepted, false, &voucherResult) err = response.ToNet(buf) require.NoError(t, err) - msg, _ = hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66454797065006441637074f46450617573f4665866657249441a4d65822164565265738178644204cb9a1e34c5f08e9b20aa76090e70020bb56c0ca3d3af7296cd1058a5112890fed218488f084d8df9e4835fb54ad045ffd936e3bf7261b0426c51352a097816ed74482bb9084b4a7ed8adc517f3371e0e0434b511625cd1a41792243dccdcfe88094b64565479706a46616b65445454797065") + msg, _ = hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66454797065006441637074f46450617573f4665866657249441a4d65822164565265738178644204cb9a1e34c5f08e9b20aa76090e70020bb56c0ca3d3af7296cd1058a5112890fed218488f084d8df9e4835fb54ad045ffd936e3bf7261b0426c51352a097816ed74482bb9084b4a7ed8adc517f3371e0e0434b511625cd1a41792243dccdcfe88094b64565479706b54657374566f7563686572") deserialized, err = message1_1.FromNet(bytes.NewReader(msg)) require.NoError(t, err) @@ -505,7 +486,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedResponse.IsRequest(), response.IsRequest()) require.Equal(t, deserializedResponse.IsUpdate(), response.IsUpdate()) require.Equal(t, deserializedResponse.IsPaused(), 
response.IsPaused()) - testutil.AssertEqualFakeDTVoucherResult(t, response, deserializedResponse) + testutil.AssertEqualTestVoucherResult(t, response, deserializedResponse) request = message1_1.CancelRequest(id) err = request.ToNet(buf) @@ -517,6 +498,74 @@ func TestToNetFromNetEquivalency(t *testing.T) { deserializedRequest, ok = deserialized.(datatransfer.Request) require.True(t, ok) + require.Equal(t, deserializedRequest.TransferID(), request.TransferID()) + require.Equal(t, deserializedRequest.IsCancel(), request.IsCancel()) + require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) + }) + t.Run("round-trip with wrapping", func(t *testing.T) { + transportID := datatransfer.TransportID("applesauce") + transportVersion := datatransfer.Version{Major: 1, Minor: 5, Patch: 0} + baseCid := testutil.GenerateCids(1)[0] + selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() + isPull := false + id := datatransfer.TransferID(rand.Int31()) + accepted := false + voucher := testutil.NewTestTypedVoucher() + voucherResult := testutil.NewTestTypedVoucher() + request, err := message1_1.NewRequest(id, false, isPull, &voucher, baseCid, selector) + require.NoError(t, err) + wrequest := request.WrappedForTransport(transportID, transportVersion) + buf := new(bytes.Buffer) + err = wrequest.ToNet(buf) + require.NoError(t, err) + require.Greater(t, buf.Len(), 0) + deserialized, err := message1_1.FromNetWrapped(buf) + require.NoError(t, err) + + require.Equal(t, transportID, deserialized.TransportID()) + require.Equal(t, transportVersion, deserialized.TransportVersion()) + deserializedRequest, ok := deserialized.(datatransfer.Request) + require.True(t, ok) + + require.Equal(t, deserializedRequest.TransferID(), request.TransferID()) + require.Equal(t, deserializedRequest.IsCancel(), request.IsCancel()) + require.Equal(t, deserializedRequest.IsPull(), request.IsPull()) + require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) + require.Equal(t, deserializedRequest.BaseCid(), request.BaseCid()) + testutil.AssertEqualTestVoucher(t, request, deserializedRequest) + testutil.AssertEqualSelector(t, request, deserializedRequest) + + response := message1_1.NewResponse(id, accepted, false, &voucherResult) + wresponse := response.WrappedForTransport(transportID, transportVersion) + err = wresponse.ToNet(buf) + require.NoError(t, err) + deserialized, err = message1_1.FromNetWrapped(buf) + require.NoError(t, err) + require.Equal(t, transportID, deserialized.TransportID()) + require.Equal(t, transportVersion, deserialized.TransportVersion()) + + deserializedResponse, ok := deserialized.(datatransfer.Response) + require.True(t, ok) + + require.Equal(t, deserializedResponse.TransferID(), response.TransferID()) + require.Equal(t, deserializedResponse.Accepted(), response.Accepted()) + require.Equal(t, deserializedResponse.IsRequest(), response.IsRequest()) + require.Equal(t, deserializedResponse.IsUpdate(), response.IsUpdate()) + require.Equal(t, deserializedResponse.IsPaused(), response.IsPaused()) + testutil.AssertEqualTestVoucherResult(t, response, deserializedResponse) + + request = message1_1.CancelRequest(id) + wrequest = request.WrappedForTransport(transportID, transportVersion) + err = wrequest.ToNet(buf) + require.NoError(t, err) + deserialized, err = message1_1.FromNetWrapped(buf) + require.NoError(t, err) + require.Equal(t, transportID, deserialized.TransportID()) + require.Equal(t, transportVersion, deserialized.TransportVersion()) + + deserializedRequest, ok 
= deserialized.(datatransfer.Request) + require.True(t, ok) + require.Equal(t, deserializedRequest.TransferID(), request.TransferID()) require.Equal(t, deserializedRequest.IsCancel(), request.IsCancel()) require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) @@ -537,13 +586,13 @@ func TestFromNetMessageValidation(t *testing.T) { assert.Nil(t, msg) } -func NewTestTransferRequest() (message1_1.TransferRequest1_1, error) { +func NewTestTransferRequest(data string) (message1_1.TransferRequest1_1, error) { bcid := testutil.GenerateCids(1)[0] selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := false id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - req, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, bcid, selector) + voucher := testutil.NewTestTypedVoucherWith(data) + req, err := message1_1.NewRequest(id, false, isPull, &voucher, bcid, selector) if err != nil { return message1_1.TransferRequest1_1{}, err } diff --git a/message/message1_1prime/schema.go b/message/message1_1prime/schema.go deleted file mode 100644 index c779b1fc..00000000 --- a/message/message1_1prime/schema.go +++ /dev/null @@ -1,29 +0,0 @@ -package message1_1 - -import ( - _ "embed" - - "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/node/bindnode" - "github.com/ipld/go-ipld-prime/schema" -) - -//go:embed schema.ipldsch -var embedSchema []byte - -var Prototype struct { - TransferMessage schema.TypedPrototype - TransferRequest schema.TypedPrototype - TransferResponse schema.TypedPrototype -} - -func init() { - ts, err := ipld.LoadSchemaBytes(embedSchema) - if err != nil { - panic(err) - } - - Prototype.TransferMessage = bindnode.Prototype((*TransferMessage1_1)(nil), ts.TypeByName("TransferMessage")) - Prototype.TransferRequest = bindnode.Prototype((*TransferRequest1_1)(nil), ts.TypeByName("TransferRequest")) - Prototype.TransferResponse = bindnode.Prototype((*TransferResponse1_1)(nil), ts.TypeByName("TransferResponse")) -} diff --git a/message/message1_1prime/schema.ipldsch b/message/message1_1prime/schema.ipldsch index 71413514..32683663 100644 --- a/message/message1_1prime/schema.ipldsch +++ b/message/message1_1prime/schema.ipldsch @@ -1,6 +1,7 @@ type PeerID string # peer.ID, really should be bytes (this is non-utf8) but is string for backward compat type TransferID int type TypeIdentifier string +type TransportID string type ChannelID struct { Initiator PeerID @@ -30,8 +31,20 @@ type TransferResponse struct { VoucherTypeIdentifier TypeIdentifier (rename "VTyp") } -type TransferMessage struct { +type TransferMessage1_1 struct { IsRequest Bool (rename "IsRq") Request nullable TransferRequest Response nullable TransferResponse } + +type Version struct { + Major Int + Minor Int + Patch Int +} representation tuple + +type WrappedTransferMessage1_1 struct { + TransportID TransportID (rename "ID") + TransportVersion Version (rename "TV") + Message TransferMessage1_1 (rename "Msg") +} \ No newline at end of file diff --git a/message/message1_1prime/transfer_message.go b/message/message1_1prime/transfer_message.go index 9e4aa261..0f28e52d 100644 --- a/message/message1_1prime/transfer_message.go +++ b/message/message1_1prime/transfer_message.go @@ -1,16 +1,22 @@ package message1_1 import ( + _ "embed" "io" "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/ipld/go-ipld-prime/datamodel" - "github.com/ipld/go-ipld-prime/node/bindnode" + bindnoderegistry 
"github.com/ipld/go-ipld-prime/node/bindnode/registry" "github.com/ipld/go-ipld-prime/schema" - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) +var bindnodeRegistry = bindnoderegistry.NewRegistry() + +//go:embed schema.ipldsch +var embedSchema []byte + // TransferMessage1_1 is the transfer message for the 1.1 Data Transfer Protocol. type TransferMessage1_1 struct { IsRequest bool @@ -30,15 +36,38 @@ func (tm *TransferMessage1_1) TransferID() datatransfer.TransferID { } func (tm *TransferMessage1_1) toIPLD() schema.TypedNode { - return bindnode.Wrap(tm, Prototype.TransferMessage.Type()) + return bindnodeRegistry.TypeToNode(tm) } -// ToNet serializes a transfer message type. +// ToIPLD converts a transfer message type to an ipld Node func (tm *TransferMessage1_1) ToIPLD() (datamodel.Node, error) { return tm.toIPLD().Representation(), nil } // ToNet serializes a transfer message type. func (tm *TransferMessage1_1) ToNet(w io.Writer) error { - return dagcbor.Encode(tm.toIPLD().Representation(), w) + return bindnodeRegistry.TypeToWriter(tm.toIPLD(), w, dagcbor.Encode) +} + +func init() { + if err := bindnodeRegistry.RegisterType((*TransferMessage1_1)(nil), string(embedSchema), "TransferMessage1_1"); err != nil { + panic(err.Error()) + } + if err := bindnodeRegistry.RegisterType((*WrappedTransferMessage1_1)(nil), string(embedSchema), "WrappedTransferMessage1_1"); err != nil { + panic(err.Error()) + } +} + +type WrappedTransferMessage1_1 struct { + TransportID string + TransportVersion datatransfer.Version + Message TransferMessage1_1 +} + +func (wtm *WrappedTransferMessage1_1) BindnodeSchema() string { + return string(embedSchema) +} + +func (wtm *WrappedTransferMessage1_1) toIPLD() schema.TypedNode { + return bindnodeRegistry.TypeToNode(wtm) } diff --git a/message/message1_1prime/transfer_request.go b/message/message1_1prime/transfer_request.go index 495f0668..007d651d 100644 --- a/message/message1_1prime/transfer_request.go +++ b/message/message1_1prime/transfer_request.go @@ -4,15 +4,14 @@ import ( "io" "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/ipld/go-ipld-prime/datamodel" "github.com/ipld/go-ipld-prime/schema" - "github.com/libp2p/go-libp2p-core/protocol" xerrors "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/message/types" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message/types" ) // TransferRequest1_1 is a struct for the 1.1 Data Transfer Protocol that fulfills the datatransfer.Request interface. 
@@ -23,22 +22,34 @@ type TransferRequest1_1 struct { Pause bool Partial bool Pull bool - SelectorPtr *datamodel.Node - VoucherPtr *datamodel.Node + SelectorPtr datamodel.Node + VoucherPtr datamodel.Node VoucherTypeIdentifier datatransfer.TypeIdentifier TransferId uint64 RestartChannel datatransfer.ChannelID } -func (trq *TransferRequest1_1) MessageForProtocol(targetProtocol protocol.ID) (datatransfer.Message, error) { - switch targetProtocol { - case datatransfer.ProtocolDataTransfer1_2: +func (trq *TransferRequest1_1) MessageForVersion(version datatransfer.Version) (datatransfer.Message, error) { + switch version { + case datatransfer.DataTransfer1_2: return trq, nil default: return nil, xerrors.Errorf("protocol not supported") } } +func (trq *TransferRequest1_1) Version() datatransfer.Version { + return datatransfer.DataTransfer1_2 +} + +func (trq *TransferRequest1_1) WrappedForTransport(transportID datatransfer.TransportID, transportVersion datatransfer.Version) datatransfer.TransportedMessage { + return &WrappedTransferRequest1_1{ + TransferRequest1_1: trq, + transportID: string(transportID), + transportVersion: transportVersion, + } +} + // IsRequest always returns true in this case because this is a transfer request func (trq *TransferRequest1_1) IsRequest() bool { return true @@ -91,11 +102,22 @@ func (trq *TransferRequest1_1) VoucherType() datatransfer.TypeIdentifier { } // Voucher returns the Voucher bytes -func (trq *TransferRequest1_1) Voucher(decoder encoding.Decoder) (encoding.Encodable, error) { +func (trq *TransferRequest1_1) Voucher() (datamodel.Node, error) { if trq.VoucherPtr == nil { return nil, xerrors.New("No voucher present to read") } - return decoder.DecodeFromNode(*trq.VoucherPtr) + return trq.VoucherPtr, nil +} + +// TypedVoucher is a convenience method that returns the voucher and its type +// as a TypedVoucher object +// TODO(rvagg): tests for this +func (trq *TransferRequest1_1) TypedVoucher() (datatransfer.TypedVoucher, error) { + voucher, err := trq.Voucher() + if err != nil { + return datatransfer.TypedVoucher{}, err + } + return datatransfer.TypedVoucher{Voucher: voucher, Type: trq.VoucherType()}, nil } func (trq *TransferRequest1_1) EmptyVoucher() bool { @@ -115,7 +137,7 @@ func (trq *TransferRequest1_1) Selector() (datamodel.Node, error) { if trq.SelectorPtr == nil { return nil, xerrors.New("No selector present to read") } - return *trq.SelectorPtr, nil + return trq.SelectorPtr, nil } // IsCancel returns true if this is a cancel request @@ -128,20 +150,57 @@ func (trq *TransferRequest1_1) IsPartial() bool { return trq.Partial } -func (trsp *TransferRequest1_1) toIPLD() schema.TypedNode { +func (trq *TransferRequest1_1) toIPLD() schema.TypedNode { msg := TransferMessage1_1{ IsRequest: true, - Request: trsp, + Request: trq, Response: nil, } return msg.toIPLD() } -func (trq *TransferRequest1_1) ToIPLD() (datamodel.Node, error) { - return trq.toIPLD().Representation(), nil +func (trq *TransferRequest1_1) ToIPLD() datamodel.Node { + return trq.toIPLD().Representation() } // ToNet serializes a transfer request. 
func (trq *TransferRequest1_1) ToNet(w io.Writer) error { - return dagcbor.Encode(trq.toIPLD().Representation(), w) + return ipld.EncodeStreaming(w, trq.toIPLD(), dagcbor.Encode) +} + +// WrappedTransferRequest1_1 is used to serialize a request along with a +// transport id +type WrappedTransferRequest1_1 struct { + *TransferRequest1_1 + transportVersion datatransfer.Version + transportID string +} + +func (trq *WrappedTransferRequest1_1) TransportID() datatransfer.TransportID { + return datatransfer.TransportID(trq.transportID) +} + +func (trq *WrappedTransferRequest1_1) TransportVersion() datatransfer.Version { + return trq.transportVersion +} + +func (trq *WrappedTransferRequest1_1) toIPLD() schema.TypedNode { + msg := WrappedTransferMessage1_1{ + TransportID: trq.transportID, + TransportVersion: trq.transportVersion, + Message: TransferMessage1_1{ + IsRequest: true, + Request: trq.TransferRequest1_1, + Response: nil, + }, + } + return msg.toIPLD() +} + +func (trq *WrappedTransferRequest1_1) ToIPLD() datamodel.Node { + return trq.toIPLD().Representation() +} + +func (trq *WrappedTransferRequest1_1) ToNet(w io.Writer) error { + return ipld.EncodeStreaming(w, trq.toIPLD(), dagcbor.Encode) } diff --git a/message/message1_1prime/transfer_request_test.go b/message/message1_1prime/transfer_request_test.go index 96d43d29..410917b8 100644 --- a/message/message1_1prime/transfer_request_test.go +++ b/message/message1_1prime/transfer_request_test.go @@ -8,23 +8,24 @@ import ( "github.com/ipld/go-ipld-prime/traversal/selector/builder" "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" - message1_1 "github.com/filecoin-project/go-data-transfer/message/message1_1prime" - "github.com/filecoin-project/go-data-transfer/testutil" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + message1_1 "github.com/filecoin-project/go-data-transfer/v2/message/message1_1prime" + "github.com/filecoin-project/go-data-transfer/v2/testutil" ) -func TestRequestMessageForProtocol(t *testing.T) { +func TestRequestMessageForVersion(t *testing.T) { baseCid := testutil.GenerateCids(1)[0] selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := true id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() + voucher := testutil.NewTestTypedVoucher() // for the new protocols - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + request, err := message1_1.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) - out12, err := request.MessageForProtocol(datatransfer.ProtocolDataTransfer1_2) + // v1.2 new protocol + out12, err := request.MessageForVersion(datatransfer.DataTransfer1_2) require.NoError(t, err) require.Equal(t, request, out12) @@ -37,5 +38,18 @@ func TestRequestMessageForProtocol(t *testing.T) { n, err := req.Selector() require.NoError(t, err) require.Equal(t, selector, n) - require.Equal(t, voucher.Type(), req.VoucherType()) + require.Equal(t, testutil.TestVoucherType, req.VoucherType()) + + wrappedOut12 := out12.WrappedForTransport(datatransfer.LegacyTransportID, datatransfer.LegacyTransportVersion) + require.Equal(t, datatransfer.LegacyTransportID, wrappedOut12.TransportID()) + require.Equal(t, datatransfer.LegacyTransportVersion, wrappedOut12.TransportVersion()) + + // random protocol should fail + _, err = request.MessageForVersion(datatransfer.Version{ + Major: rand.Uint64(), + Minor: rand.Uint64(), + Patch: 
rand.Uint64(), + }) + require.Error(t, err) + } diff --git a/message/message1_1prime/transfer_response.go b/message/message1_1prime/transfer_response.go index c468906d..3e3c41f5 100644 --- a/message/message1_1prime/transfer_response.go +++ b/message/message1_1prime/transfer_response.go @@ -3,15 +3,14 @@ package message1_1 import ( "io" + "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/ipld/go-ipld-prime/datamodel" "github.com/ipld/go-ipld-prime/schema" - "github.com/libp2p/go-libp2p-core/protocol" xerrors "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" - "github.com/filecoin-project/go-data-transfer/message/types" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message/types" ) // TransferResponse1_1 is a private struct that satisfies the datatransfer.Response interface @@ -21,7 +20,7 @@ type TransferResponse1_1 struct { RequestAccepted bool Paused bool TransferId uint64 - VoucherResultPtr *datamodel.Node + VoucherResultPtr datamodel.Node VoucherTypeIdentifier datatransfer.TypeIdentifier } @@ -59,7 +58,7 @@ func (trsp *TransferResponse1_1) IsComplete() bool { return trsp.MessageType == uint64(types.CompleteMessage) } -func (trsp *TransferResponse1_1) IsVoucherResult() bool { +func (trsp *TransferResponse1_1) IsValidationResult() bool { return trsp.MessageType == uint64(types.VoucherResultMessage) || trsp.MessageType == uint64(types.NewMessage) || trsp.MessageType == uint64(types.CompleteMessage) || trsp.MessageType == uint64(types.RestartMessage) } @@ -73,11 +72,11 @@ func (trsp *TransferResponse1_1) VoucherResultType() datatransfer.TypeIdentifier return trsp.VoucherTypeIdentifier } -func (trsp *TransferResponse1_1) VoucherResult(decoder encoding.Decoder) (encoding.Encodable, error) { +func (trsp *TransferResponse1_1) VoucherResult() (datamodel.Node, error) { if trsp.VoucherResultPtr == nil { return nil, xerrors.New("No voucher present to read") } - return decoder.DecodeFromNode(*trsp.VoucherResultPtr) + return trsp.VoucherResultPtr, nil } func (trq *TransferResponse1_1) IsRestart() bool { @@ -88,15 +87,26 @@ func (trsp *TransferResponse1_1) EmptyVoucherResult() bool { return trsp.VoucherTypeIdentifier == datatransfer.EmptyTypeIdentifier } -func (trsp *TransferResponse1_1) MessageForProtocol(targetProtocol protocol.ID) (datatransfer.Message, error) { - switch targetProtocol { - case datatransfer.ProtocolDataTransfer1_2: +func (trsp *TransferResponse1_1) MessageForVersion(version datatransfer.Version) (datatransfer.Message, error) { + switch version { + case datatransfer.DataTransfer1_2: return trsp, nil default: - return nil, xerrors.Errorf("protocol %s not supported", targetProtocol) + return nil, xerrors.Errorf("protocol %s not supported", version) } } +func (trsp *TransferResponse1_1) Version() datatransfer.Version { + return datatransfer.DataTransfer1_2 +} + +func (trsp *TransferResponse1_1) WrappedForTransport(transportID datatransfer.TransportID, transportVersion datatransfer.Version) datatransfer.TransportedMessage { + return &WrappedTransferResponse1_1{ + TransferResponse1_1: trsp, + transportID: string(transportID), + transportVersion: transportVersion, + } +} func (trsp *TransferResponse1_1) toIPLD() schema.TypedNode { msg := TransferMessage1_1{ IsRequest: false, @@ -106,11 +116,47 @@ func (trsp *TransferResponse1_1) toIPLD() schema.TypedNode { return msg.toIPLD() } -func (trsp 
*TransferResponse1_1) ToIPLD() (datamodel.Node, error) { - return trsp.toIPLD().Representation(), nil +func (trsp *TransferResponse1_1) ToIPLD() datamodel.Node { + return trsp.toIPLD().Representation() } // ToNet serializes a transfer response. func (trsp *TransferResponse1_1) ToNet(w io.Writer) error { - return dagcbor.Encode(trsp.toIPLD().Representation(), w) + return ipld.EncodeStreaming(w, trsp.toIPLD(), dagcbor.Encode) +} + +// WrappedTransferResponse1_1 is used to serialize a response along with a +// transport id +type WrappedTransferResponse1_1 struct { + *TransferResponse1_1 + transportID string + transportVersion datatransfer.Version +} + +func (trsp *WrappedTransferResponse1_1) TransportID() datatransfer.TransportID { + return datatransfer.TransportID(trsp.transportID) +} +func (trsp *WrappedTransferResponse1_1) TransportVersion() datatransfer.Version { + return trsp.transportVersion +} + +func (trsp *WrappedTransferResponse1_1) toIPLD() schema.TypedNode { + msg := WrappedTransferMessage1_1{ + TransportID: trsp.transportID, + TransportVersion: trsp.transportVersion, + Message: TransferMessage1_1{ + IsRequest: false, + Request: nil, + Response: trsp.TransferResponse1_1, + }, + } + return msg.toIPLD() +} + +func (trsp *WrappedTransferResponse1_1) ToIPLD() datamodel.Node { + return trsp.toIPLD().Representation() +} + +func (trsp *WrappedTransferResponse1_1) ToNet(w io.Writer) error { + return ipld.EncodeStreaming(w, trsp.toIPLD(), dagcbor.Encode) } diff --git a/message/message1_1prime/transfer_response_test.go b/message/message1_1prime/transfer_response_test.go index fb4b6fa6..85741882 100644 --- a/message/message1_1prime/transfer_response_test.go +++ b/message/message1_1prime/transfer_response_test.go @@ -6,30 +6,36 @@ import ( "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" - message1_1 "github.com/filecoin-project/go-data-transfer/message/message1_1prime" - "github.com/filecoin-project/go-data-transfer/testutil" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + message1_1 "github.com/filecoin-project/go-data-transfer/v2/message/message1_1prime" + "github.com/filecoin-project/go-data-transfer/v2/testutil" ) -func TestResponseMessageForProtocol(t *testing.T) { +func TestResponseMessageForVersion(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, false, true, voucherResult.Type(), voucherResult) // not accepted - require.NoError(t, err) + voucherResult := testutil.NewTestTypedVoucher() + response := message1_1.NewResponse(id, false, true, &voucherResult) // not accepted - // v1.2 protocol - out, err := response.MessageForProtocol(datatransfer.ProtocolDataTransfer1_2) + // v1.2 new protocol + out, err := response.MessageForVersion(datatransfer.DataTransfer1_2) require.NoError(t, err) require.Equal(t, response, out) resp, ok := (out).(datatransfer.Response) require.True(t, ok) require.True(t, resp.IsPaused()) - require.Equal(t, voucherResult.Type(), resp.VoucherResultType()) - require.True(t, resp.IsVoucherResult()) + require.Equal(t, testutil.TestVoucherType, resp.VoucherResultType()) + require.True(t, resp.IsValidationResult()) + + wrappedOut := out.WrappedForTransport(datatransfer.LegacyTransportID, datatransfer.LegacyTransportVersion) + require.Equal(t, datatransfer.LegacyTransportID, wrappedOut.TransportID()) + require.Equal(t, datatransfer.LegacyTransportVersion, wrappedOut.TransportVersion()) - // random 
protocol - out, err = response.MessageForProtocol("RAND") + // random protocol should fail + _, err = response.MessageForVersion(datatransfer.Version{ + Major: rand.Uint64(), + Minor: rand.Uint64(), + Patch: rand.Uint64(), + }) require.Error(t, err) - require.Nil(t, out) } diff --git a/network/interface.go b/network/interface.go deleted file mode 100644 index 0cb3f528..00000000 --- a/network/interface.go +++ /dev/null @@ -1,58 +0,0 @@ -package network - -import ( - "context" - - "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" - - datatransfer "github.com/filecoin-project/go-data-transfer" -) - -// DataTransferNetwork provides network connectivity for GraphSync. -type DataTransferNetwork interface { - Protect(id peer.ID, tag string) - Unprotect(id peer.ID, tag string) bool - - // SendMessage sends a GraphSync message to a peer. - SendMessage( - context.Context, - peer.ID, - datatransfer.Message) error - - // SetDelegate registers the Reciver to handle messages received from the - // network. - SetDelegate(Receiver) - - // ConnectTo establishes a connection to the given peer - ConnectTo(context.Context, peer.ID) error - - // ConnectWithRetry establishes a connection to the given peer, retrying if - // necessary, and opens a stream on the data-transfer protocol to verify - // the peer will accept messages on the protocol - ConnectWithRetry(ctx context.Context, p peer.ID) error - - // ID returns the peer id of this libp2p host - ID() peer.ID - - // Protocol returns the protocol version of the peer, connecting to - // the peer if necessary - Protocol(context.Context, peer.ID) (protocol.ID, error) -} - -// Receiver is an interface for receiving messages from the GraphSyncNetwork. -type Receiver interface { - ReceiveRequest( - ctx context.Context, - sender peer.ID, - incoming datatransfer.Request) - - ReceiveResponse( - ctx context.Context, - sender peer.ID, - incoming datatransfer.Response) - - ReceiveRestartExistingChannelRequest(ctx context.Context, sender peer.ID, incoming datatransfer.Request) - - ReceiveError(error) -} diff --git a/registry/registry.go b/registry/registry.go index 3386d1da..00f3815d 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -5,8 +5,7 @@ import ( "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) // Processor is an interface that processes a certain type of encodable objects @@ -14,11 +13,6 @@ import ( // left to the user of the registry type Processor interface{} -type registryEntry struct { - decoder encoding.Decoder - processor Processor -} - // Registry maintans a register of types of encodable objects and a corresponding // processor for those objects // The encodable types must have a method Type() that specifies and identifier @@ -26,54 +20,41 @@ type registryEntry struct { // on this unique identifier type Registry struct { registryLk sync.RWMutex - entries map[datatransfer.TypeIdentifier]registryEntry + entries map[datatransfer.TypeIdentifier]Processor } // NewRegistry initialzes a new registy func NewRegistry() *Registry { return &Registry{ - entries: make(map[datatransfer.TypeIdentifier]registryEntry), + entries: make(map[datatransfer.TypeIdentifier]Processor), } } // Register registers the given processor for the given entry type -func (r *Registry) Register(entry datatransfer.Registerable, processor Processor) error { - identifier := entry.Type() 
- decoder, err := encoding.NewDecoder(entry) - if err != nil { - return xerrors.Errorf("registering entry type %s: %w", identifier, err) - } +func (r *Registry) Register(identifier datatransfer.TypeIdentifier, processor Processor) error { r.registryLk.Lock() defer r.registryLk.Unlock() if _, ok := r.entries[identifier]; ok { return xerrors.Errorf("identifier already registered: %s", identifier) } - r.entries[identifier] = registryEntry{decoder, processor} + r.entries[identifier] = processor return nil } -// Decoder gets a decoder for the given identifier -func (r *Registry) Decoder(identifier datatransfer.TypeIdentifier) (encoding.Decoder, bool) { - r.registryLk.RLock() - entry, has := r.entries[identifier] - r.registryLk.RUnlock() - return entry.decoder, has -} - // Processor gets the processing interface for the given identifer func (r *Registry) Processor(identifier datatransfer.TypeIdentifier) (Processor, bool) { r.registryLk.RLock() entry, has := r.entries[identifier] r.registryLk.RUnlock() - return entry.processor, has + return entry, has } // Each iterates through all of the entries in this registry -func (r *Registry) Each(process func(datatransfer.TypeIdentifier, encoding.Decoder, Processor) error) error { +func (r *Registry) Each(process func(datatransfer.TypeIdentifier, Processor) error) error { r.registryLk.RLock() defer r.registryLk.RUnlock() - for identifier, entry := range r.entries { - err := process(identifier, entry.decoder, entry.processor) + for identifier, processor := range r.entries { + err := process(identifier, processor) if err != nil { return err } diff --git a/registry/registry_test.go b/registry/registry_test.go index f3eaeb96..84fd95c8 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -5,34 +5,22 @@ import ( "github.com/stretchr/testify/require" - "github.com/filecoin-project/go-data-transfer/registry" - "github.com/filecoin-project/go-data-transfer/testutil" + "github.com/filecoin-project/go-data-transfer/v2/registry" + "github.com/filecoin-project/go-data-transfer/v2/testutil" ) func TestRegistry(t *testing.T) { r := registry.NewRegistry() t.Run("it registers", func(t *testing.T) { - err := r.Register(&testutil.FakeDTType{}, func() {}) + err := r.Register(testutil.TestVoucherType, func() {}) require.NoError(t, err) }) t.Run("it errors when registred again", func(t *testing.T) { - err := r.Register(&testutil.FakeDTType{}, func() {}) - require.EqualError(t, err, "identifier already registered: FakeDTType") - }) - t.Run("it errors when decoder setup fails", func(t *testing.T) { - err := r.Register(testutil.FakeDTType{}, func() {}) - require.EqualError(t, err, "registering entry type FakeDTType: type must be a pointer") - }) - t.Run("it reads decoders", func(t *testing.T) { - decoder, has := r.Decoder("FakeDTType") - require.True(t, has) - require.NotNil(t, decoder) - decoder, has = r.Decoder("OtherType") - require.False(t, has) - require.Nil(t, decoder) + err := r.Register(testutil.TestVoucherType, func() {}) + require.EqualError(t, err, "identifier already registered: TestVoucher") }) t.Run("it reads processors", func(t *testing.T) { - processor, has := r.Processor("FakeDTType") + processor, has := r.Processor("TestVoucher") require.True(t, has) require.NotNil(t, processor) processor, has = r.Processor("OtherType") diff --git a/scripts/fiximports b/scripts/fiximports index e51d54ab..5e2d3e2d 100755 --- a/scripts/fiximports +++ b/scripts/fiximports @@ -8,5 +8,5 @@ find . 
-type f -name \*.go -not -name \*_cbor_gen.go | xargs -I '{}' sed -i.bak }' '{}' git clean -fd find . -type f -name \*.go -not -name \*_cbor_gen.go | xargs -I '{}' goimports -w -local "github.com/filecoin-project" '{}' -find . -type f -name \*.go -not -name \*_cbor_gen.go | xargs -I '{}' goimports -w -local "github.com/filecoin-project/go-data-transfer" '{}' +find . -type f -name \*.go -not -name \*_cbor_gen.go | xargs -I '{}' goimports -w -local "github.com/filecoin-project/go-data-transfer/v2" '{}' diff --git a/statuses.go b/statuses.go index 6a4c89be..3b263de0 100644 --- a/statuses.go +++ b/statuses.go @@ -1,5 +1,7 @@ package datatransfer +import "github.com/filecoin-project/go-statemachine/fsm" + // Status is the status of transfer for a given channel type Status uint64 @@ -40,13 +42,13 @@ const ( // Cancelled means the data transfer ended prematurely Cancelled - // InitiatorPaused means the data sender has paused the channel (only the sender can unpause this) + // DEPRECATED: Use InitiatorPaused() method on ChannelState InitiatorPaused - // ResponderPaused means the data receiver has paused the channel (only the receiver can unpause this) + // DEPRECATED: Use ResponderPaused() method on ChannelState ResponderPaused - // BothPaused means both sender and receiver have paused the channel seperately (both must unpause) + // DEPRECATED: Use BothPaused() method on ChannelState BothPaused // ResponderFinalizing is a unique state where the responder is awaiting a final voucher @@ -58,8 +60,83 @@ const ( // ChannelNotFoundError means the searched for data transfer does not exist ChannelNotFoundError + + // Queued indicates a data transfer request has been accepted, but is not actively transferring yet + Queued + + // AwaitingAcceptance indicates a transfer request is actively being processed by the transport + // even if the remote has not yet responded that it's accepted the transfer. Such a state can + // occur, for example, in a requestor-initiated transfer that starts processing prior to receiving + // acceptance from the server. 
+ AwaitingAcceptance ) +type statusList []Status + +func (sl statusList) Contains(s Status) bool { + for _, ts := range sl { + if ts == s { + return true + } + } + return false +} + +func (sl statusList) AsFSMStates() []fsm.StateKey { + sk := make([]fsm.StateKey, 0, len(sl)) + for _, s := range sl { + sk = append(sk, s) + } + return sk +} + +var NotAcceptedStates = statusList{ + Requested, + AwaitingAcceptance, + Cancelled, + Cancelling, + Failed, + Failing, + ChannelNotFoundError} + +func (s Status) IsAccepted() bool { + return !NotAcceptedStates.Contains(s) +} + +var FinalizationStatuses = statusList{Finalizing, Completed, Completing} + +func (s Status) InFinalization() bool { + return FinalizationStatuses.Contains(s) +} + +var TransferCompleteStates = statusList{ + TransferFinished, + ResponderFinalizingTransferFinished, + Finalizing, + Completed, + Completing, + Failing, + Failed, + Cancelling, + Cancelled, + ChannelNotFoundError, +} + +func (s Status) TransferComplete() bool { + return TransferCompleteStates.Contains(s) +} + +var TransferringStates = statusList{ + Ongoing, + ResponderCompleted, + ResponderFinalizing, + AwaitingAcceptance, +} + +func (s Status) Transferring() bool { + return TransferringStates.Contains(s) +} + // Statuses are human readable names for data transfer states var Statuses = map[Status]string{ // Requested means a data transfer was requested by has not yet been approved @@ -80,4 +157,6 @@ var Statuses = map[Status]string{ ResponderFinalizing: "ResponderFinalizing", ResponderFinalizingTransferFinished: "ResponderFinalizingTransferFinished", ChannelNotFoundError: "ChannelNotFoundError", + Queued: "Queued", + AwaitingAcceptance: "AwaitingAcceptance", } diff --git a/testutil/fakedttype.go b/testutil/fakedttype.go index 9fde52f0..03db9273 100644 --- a/testutil/fakedttype.go +++ b/testutil/fakedttype.go @@ -3,71 +3,80 @@ package testutil import ( "testing" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/fluent/qp" + basicnode "github.com/ipld/go-ipld-prime/node/basic" "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) -//go:generate cbor-gen-for FakeDTType +const TestVoucherType = datatransfer.TypeIdentifier("TestVoucher") -// FakeDTType simple fake type for using with registries -type FakeDTType struct { - Data string -} - -// Type satisfies registry.Entry -func (ft FakeDTType) Type() datatransfer.TypeIdentifier { - return "FakeDTType" -} - -// AssertFakeDTVoucher asserts that a data transfer requests contains the expected fake data transfer voucher type -func AssertFakeDTVoucher(t *testing.T, request datatransfer.Request, expected *FakeDTType) { - require.Equal(t, datatransfer.TypeIdentifier("FakeDTType"), request.VoucherType()) - fakeDTDecoder, err := encoding.NewDecoder(&FakeDTType{}) +// AssertTestVoucher asserts that a data transfer requests contains the expected fake data transfer voucher type +func AssertTestVoucher(t *testing.T, request datatransfer.Request, expected datatransfer.TypedVoucher) { + require.Equal(t, expected.Type, request.VoucherType()) + voucher, err := request.Voucher() require.NoError(t, err) - decoded, err := request.Voucher(fakeDTDecoder) - require.NoError(t, err) - require.Equal(t, expected, decoded) + require.True(t, ipld.DeepEqual(expected.Voucher, voucher)) } -// AssertEqualFakeDTVoucher asserts 
that two requests have the same fake data transfer voucher -func AssertEqualFakeDTVoucher(t *testing.T, expectedRequest datatransfer.Request, request datatransfer.Request) { +// AssertEqualTestVoucher asserts that two requests have the same fake data transfer voucher +func AssertEqualTestVoucher(t *testing.T, expectedRequest datatransfer.Request, request datatransfer.Request) { require.Equal(t, expectedRequest.VoucherType(), request.VoucherType()) - fakeDTDecoder, err := encoding.NewDecoder(&FakeDTType{}) - require.NoError(t, err) - expectedDecoded, err := request.Voucher(fakeDTDecoder) + require.Equal(t, TestVoucherType, request.VoucherType()) + expected, err := expectedRequest.Voucher() require.NoError(t, err) - decoded, err := request.Voucher(fakeDTDecoder) + actual, err := request.Voucher() require.NoError(t, err) - require.Equal(t, expectedDecoded, decoded) + require.True(t, ipld.DeepEqual(expected, actual)) } -// AssertFakeDTVoucherResult asserts that a data transfer response contains the expected fake data transfer voucher result type -func AssertFakeDTVoucherResult(t *testing.T, response datatransfer.Response, expected *FakeDTType) { - require.Equal(t, datatransfer.TypeIdentifier("FakeDTType"), response.VoucherResultType()) - fakeDTDecoder, err := encoding.NewDecoder(&FakeDTType{}) - require.NoError(t, err) - decoded, err := response.VoucherResult(fakeDTDecoder) +// AssertTestVoucherResult asserts that a data transfer response contains the expected fake data transfer voucher result type +func AssertTestVoucherResult(t *testing.T, response datatransfer.Response, expected datatransfer.TypedVoucher) { + require.Equal(t, expected.Type, response.VoucherResultType()) + voucherResult, err := response.VoucherResult() require.NoError(t, err) - require.Equal(t, expected, decoded) + require.True(t, ipld.DeepEqual(expected.Voucher, voucherResult)) } -// AssertEqualFakeDTVoucherResult asserts that two responses have the same fake data transfer voucher result -func AssertEqualFakeDTVoucherResult(t *testing.T, expectedResponse datatransfer.Response, response datatransfer.Response) { +// AssertEqualTestVoucherResult asserts that two responses have the same fake data transfer voucher result +func AssertEqualTestVoucherResult(t *testing.T, expectedResponse datatransfer.Response, response datatransfer.Response) { require.Equal(t, expectedResponse.VoucherResultType(), response.VoucherResultType()) - fakeDTDecoder, err := encoding.NewDecoder(&FakeDTType{}) + expectedVoucherResult, err := expectedResponse.VoucherResult() require.NoError(t, err) - expectedDecoded, err := response.VoucherResult(fakeDTDecoder) + actualVoucherResult, err := response.VoucherResult() require.NoError(t, err) - decoded, err := response.VoucherResult(fakeDTDecoder) - require.NoError(t, err) - require.Equal(t, expectedDecoded, decoded) + require.True(t, ipld.DeepEqual(expectedVoucherResult, actualVoucherResult)) +} + +// NewTestVoucher returns a fake voucher with random data +func NewTestVoucher() datamodel.Node { + n, err := qp.BuildList(basicnode.Prototype.Any, 1, func(ma datamodel.ListAssembler) { + qp.ListEntry(ma, qp.String(string(RandomBytes(100)))) + }) + if err != nil { + panic(err) + } + return n } -// NewFakeDTType returns a fake dt type with random data -func NewFakeDTType() *FakeDTType { - return &FakeDTType{Data: string(RandomBytes(100))} +func NewTestTypedVoucher() datatransfer.TypedVoucher { + return datatransfer.TypedVoucher{Voucher: NewTestVoucher(), Type: TestVoucherType} } -var _ datatransfer.Registerable = 
&FakeDTType{} +// NewTestVoucher returns a fake voucher with random data +func NewTestVoucherWith(data string) datamodel.Node { + n, err := qp.BuildList(basicnode.Prototype.Any, 1, func(ma datamodel.ListAssembler) { + qp.ListEntry(ma, qp.String(data)) + }) + if err != nil { + panic(err) + } + return n +} + +func NewTestTypedVoucherWith(data string) datatransfer.TypedVoucher { + return datatransfer.TypedVoucher{Voucher: NewTestVoucherWith(data), Type: TestVoucherType} +} diff --git a/testutil/fakedttype_cbor_gen.go b/testutil/fakedttype_cbor_gen.go deleted file mode 100644 index d7913605..00000000 --- a/testutil/fakedttype_cbor_gen.go +++ /dev/null @@ -1,75 +0,0 @@ -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. - -package testutil - -import ( - "fmt" - "io" - "sort" - - cid "github.com/ipfs/go-cid" - cbg "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" -) - -var _ = xerrors.Errorf -var _ = cid.Undef -var _ = sort.Sort - -var lengthBufFakeDTType = []byte{129} - -func (t *FakeDTType) MarshalCBOR(w io.Writer) error { - if t == nil { - _, err := w.Write(cbg.CborNull) - return err - } - if _, err := w.Write(lengthBufFakeDTType); err != nil { - return err - } - - scratch := make([]byte, 9) - - // t.Data (string) (string) - if len(t.Data) > cbg.MaxLength { - return xerrors.Errorf("Value in field t.Data was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Data))); err != nil { - return err - } - if _, err := io.WriteString(w, string(t.Data)); err != nil { - return err - } - return nil -} - -func (t *FakeDTType) UnmarshalCBOR(r io.Reader) error { - *t = FakeDTType{} - - br := cbg.GetPeeker(r) - scratch := make([]byte, 8) - - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajArray { - return fmt.Errorf("cbor input should be of type array") - } - - if extra != 1 { - return fmt.Errorf("cbor input had wrong number of fields") - } - - // t.Data (string) (string) - - { - sval, err := cbg.ReadStringBuf(br, scratch) - if err != nil { - return err - } - - t.Data = string(sval) - } - return nil -} diff --git a/testutil/faketransport.go b/testutil/faketransport.go index ecada472..58726448 100644 --- a/testutil/faketransport.go +++ b/testutil/faketransport.go @@ -3,24 +3,23 @@ package testutil import ( "context" - "github.com/ipld/go-ipld-prime" - "github.com/libp2p/go-libp2p-core/peer" - - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) // OpenedChannel records a call to open a channel type OpenedChannel struct { - DataSender peer.ID - ChannelID datatransfer.ChannelID - Root ipld.Link - Selector ipld.Node - Channel datatransfer.ChannelState - Message datatransfer.Message + Channel datatransfer.Channel + Message datatransfer.Request +} + +// RestartedChannel records a call to restart a channel +type RestartedChannel struct { + Channel datatransfer.ChannelState + Message datatransfer.Request } -// ResumedChannel records a call to resume a channel -type ResumedChannel struct { +// MessageSent records a message sent +type MessageSent struct { ChannelID datatransfer.ChannelID Message datatransfer.Message } @@ -28,19 +27,20 @@ type ResumedChannel struct { // CustomizedTransfer is just a way to record calls made to transport configurer type CustomizedTransfer struct { ChannelID datatransfer.ChannelID - Voucher datatransfer.Voucher + Voucher datatransfer.TypedVoucher } +var _ 
datatransfer.Transport = &FakeTransport{} + // FakeTransport is a fake transport with mocked results type FakeTransport struct { OpenedChannels []OpenedChannel OpenChannelErr error - ClosedChannels []datatransfer.ChannelID - CloseChannelErr error - PausedChannels []datatransfer.ChannelID - PauseChannelErr error - ResumedChannels []ResumedChannel - ResumeChannelErr error + RestartedChannels []RestartedChannel + RestartChannelErr error + MessagesSent []MessageSent + UpdateError error + ChannelsUpdated []datatransfer.ChannelID CleanedUpChannels []datatransfer.ChannelID CustomizedTransfers []CustomizedTransfer EventHandler datatransfer.EventsHandler @@ -52,20 +52,54 @@ func NewFakeTransport() *FakeTransport { return &FakeTransport{} } +// ID is a unique identifier for this transport +func (ft *FakeTransport) ID() datatransfer.TransportID { + return "fake" +} + +// Versions indicates what versions of this transport are supported +func (ft *FakeTransport) Versions() []datatransfer.Version { + return []datatransfer.Version{{Major: 1, Minor: 1, Patch: 0}} +} + +// Capabilities tells datatransfer what kinds of capabilities this transport supports +func (ft *FakeTransport) Capabilities() datatransfer.TransportCapabilities { + return datatransfer.TransportCapabilities{ + Restartable: true, + Pausable: true, + } +} + // OpenChannel initiates an outgoing request for the other peer to send data // to us on this channel // Note: from a data transfer symantic standpoint, it doesn't matter if the // request is push or pull -- OpenChannel is called by the party that is // intending to receive data -func (ft *FakeTransport) OpenChannel(ctx context.Context, dataSender peer.ID, channelID datatransfer.ChannelID, root ipld.Link, stor ipld.Node, channel datatransfer.ChannelState, msg datatransfer.Message) error { - ft.OpenedChannels = append(ft.OpenedChannels, OpenedChannel{dataSender, channelID, root, stor, channel, msg}) +func (ft *FakeTransport) OpenChannel(ctx context.Context, channel datatransfer.Channel, msg datatransfer.Request) error { + ft.OpenedChannels = append(ft.OpenedChannels, OpenedChannel{channel, msg}) return ft.OpenChannelErr } -// CloseChannel closes the given channel -func (ft *FakeTransport) CloseChannel(ctx context.Context, chid datatransfer.ChannelID) error { - ft.ClosedChannels = append(ft.ClosedChannels, chid) - return ft.CloseChannelErr +// RestartChannel restarts a channel +func (ft *FakeTransport) RestartChannel(ctx context.Context, channelState datatransfer.ChannelState, msg datatransfer.Request) error { + ft.RestartedChannels = append(ft.RestartedChannels, RestartedChannel{channelState, msg}) + return ft.RestartChannelErr +} + +// ChannelUpdated records a channel update, capturing any message sent along with it +func (ft *FakeTransport) ChannelUpdated(ctx context.Context, chid datatransfer.ChannelID, msg datatransfer.Message) error { + + if msg != nil { + ft.MessagesSent = append(ft.MessagesSent, MessageSent{chid, msg}) + } + ft.ChannelsUpdated = append(ft.ChannelsUpdated, chid) + return nil +} + +// SendMessage sends a data transfer message over the channel to the other peer +func (ft *FakeTransport) SendMessage(ctx context.Context, chid datatransfer.ChannelID, msg datatransfer.Message) error { + ft.MessagesSent = append(ft.MessagesSent, MessageSent{chid, msg}) + return ft.UpdateError } // SetEventHandler sets the handler for events on channels @@ -74,27 +108,16 @@ func (ft *FakeTransport) SetEventHandler(events datatransfer.EventsHandler) erro return ft.SetEventHandlerErr } +// Shutdown closes this transport func (ft 
*FakeTransport) Shutdown(ctx context.Context) error { return nil } -// PauseChannel paused the given channel ID -func (ft *FakeTransport) PauseChannel(ctx context.Context, chid datatransfer.ChannelID) error { - ft.PausedChannels = append(ft.PausedChannels, chid) - return ft.PauseChannelErr -} - -// ResumeChannel resumes the given channel -func (ft *FakeTransport) ResumeChannel(ctx context.Context, msg datatransfer.Message, chid datatransfer.ChannelID) error { - ft.ResumedChannels = append(ft.ResumedChannels, ResumedChannel{chid, msg}) - return ft.ResumeChannelErr -} - // CleanupChannel cleans up the given channel func (ft *FakeTransport) CleanupChannel(chid datatransfer.ChannelID) { ft.CleanedUpChannels = append(ft.CleanedUpChannels, chid) } -func (ft *FakeTransport) RecordCustomizedTransfer(chid datatransfer.ChannelID, voucher datatransfer.Voucher) { +func (ft *FakeTransport) RecordCustomizedTransfer(chid datatransfer.ChannelID, voucher datatransfer.TypedVoucher) { ft.CustomizedTransfers = append(ft.CustomizedTransfers, CustomizedTransfer{chid, voucher}) } diff --git a/testutil/message.go b/testutil/message.go deleted file mode 100644 index eefc7551..00000000 --- a/testutil/message.go +++ /dev/null @@ -1,30 +0,0 @@ -package testutil - -import ( - "testing" - - basicnode "github.com/ipld/go-ipld-prime/node/basic" - "github.com/ipld/go-ipld-prime/traversal/selector/builder" - "github.com/stretchr/testify/require" - - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/message" -) - -// NewDTRequest makes a new DT Request message -func NewDTRequest(t *testing.T, transferID datatransfer.TransferID) datatransfer.Request { - voucher := NewFakeDTType() - baseCid := GenerateCids(1)[0] - selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() - r, err := message.NewRequest(transferID, false, false, voucher.Type(), voucher, baseCid, selector) - require.NoError(t, err) - return r -} - -// NewDTResponse makes a new DT Request message -func NewDTResponse(t *testing.T, transferID datatransfer.TransferID) datatransfer.Response { - vresult := NewFakeDTType() - r, err := message.NewResponse(transferID, false, false, vresult.Type(), vresult) - require.NoError(t, err) - return r -} diff --git a/testutil/mockchannelstate.go b/testutil/mockchannelstate.go new file mode 100644 index 00000000..92fc5f06 --- /dev/null +++ b/testutil/mockchannelstate.go @@ -0,0 +1,252 @@ +package testutil + +import ( + cid "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/libp2p/go-libp2p-core/peer" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" +) + +type MockChannelStateParams struct { + ReceivedIndex datamodel.Node + SentIndex datamodel.Node + QueuedIndex datamodel.Node + ChannelID datatransfer.ChannelID + Queued uint64 + Sent uint64 + Received uint64 + Complete bool + BaseCID cid.Cid + Selector ipld.Node + Voucher datatransfer.TypedVoucher + IsPull bool + Self peer.ID + DataLimit uint64 + InitiatorPaused bool + ResponderPaused bool +} + +func NewMockChannelState(params MockChannelStateParams) *MockChannelState { + return &MockChannelState{ + receivedIndex: params.ReceivedIndex, + sentIndex: params.SentIndex, + queuedIndex: params.QueuedIndex, + dataLimit: params.DataLimit, + chid: params.ChannelID, + queued: params.Queued, + sent: params.Sent, + received: params.Received, + complete: params.Complete, + isPull: params.IsPull, + self: params.Self, + baseCID: 
params.BaseCID, + initiatorPaused: params.InitiatorPaused, + responderPaused: params.ResponderPaused, + } +} + +type MockChannelState struct { + receivedIndex datamodel.Node + sentIndex datamodel.Node + queuedIndex datamodel.Node + dataLimit uint64 + chid datatransfer.ChannelID + queued uint64 + sent uint64 + received uint64 + complete bool + isPull bool + baseCID cid.Cid + selector ipld.Node + voucher datatransfer.TypedVoucher + self peer.ID + initiatorPaused bool + responderPaused bool +} + +var _ datatransfer.ChannelState = (*MockChannelState)(nil) + +func (m *MockChannelState) Queued() uint64 { + return m.queued +} + +func (m *MockChannelState) SetQueued(queued uint64) { + m.queued = queued +} + +func (m *MockChannelState) Sent() uint64 { + return m.sent +} + +func (m *MockChannelState) SetSent(sent uint64) { + m.sent = sent +} + +func (m *MockChannelState) Received() uint64 { + return m.received +} + +func (m *MockChannelState) SetReceived(received uint64) { + m.received = received +} + +func (m *MockChannelState) ChannelID() datatransfer.ChannelID { + return m.chid +} + +func (m *MockChannelState) SetComplete(complete bool) { + m.complete = complete +} +func (m *MockChannelState) Status() datatransfer.Status { + if m.complete { + return datatransfer.Completed + } + return datatransfer.Ongoing +} + +func (m *MockChannelState) SetReceivedIndex(receivedIndex datamodel.Node) { + m.receivedIndex = receivedIndex +} + +func (m *MockChannelState) ReceivedIndex() datamodel.Node { + if m.receivedIndex == nil { + return datamodel.Null + } + return m.receivedIndex +} + +func (m *MockChannelState) QueuedIndex() datamodel.Node { + if m.queuedIndex == nil { + return datamodel.Null + } + return m.queuedIndex +} + +func (m *MockChannelState) SetQueuedIndex(queuedIndex datamodel.Node) { + m.queuedIndex = queuedIndex +} + +func (m *MockChannelState) SentIndex() datamodel.Node { + if m.sentIndex == nil { + return datamodel.Null + } + return m.sentIndex +} + +func (m *MockChannelState) SetSentIndex(sentIndex datamodel.Node) { + m.sentIndex = sentIndex +} + +func (m *MockChannelState) TransferID() datatransfer.TransferID { + return m.chid.ID +} + +func (m *MockChannelState) BaseCID() cid.Cid { + return m.baseCID +} + +func (m *MockChannelState) Selector() datamodel.Node { + return m.selector +} + +func (m *MockChannelState) Voucher() datatransfer.TypedVoucher { + return m.voucher +} + +func (m *MockChannelState) Sender() peer.ID { + if m.isPull { + return m.chid.Responder + } + return m.chid.Initiator +} + +func (m *MockChannelState) Recipient() peer.ID { + if m.isPull { + return m.chid.Initiator + } + return m.chid.Responder +} + +func (m *MockChannelState) TotalSize() uint64 { + panic("implement me") +} + +func (m *MockChannelState) IsPull() bool { + return m.isPull +} + +func (m *MockChannelState) OtherPeer() peer.ID { + if m.self == m.chid.Initiator { + return m.chid.Responder + } + return m.chid.Initiator +} + +func (m *MockChannelState) SelfPeer() peer.ID { + return m.self +} + +func (m *MockChannelState) Message() string { + panic("implement me") +} + +func (m *MockChannelState) Vouchers() []datatransfer.TypedVoucher { + panic("implement me") +} + +func (m *MockChannelState) VoucherResults() []datatransfer.TypedVoucher { + panic("implement me") +} + +func (m *MockChannelState) LastVoucher() datatransfer.TypedVoucher { + panic("implement me") +} + +func (m *MockChannelState) LastVoucherResult() datatransfer.TypedVoucher { + panic("implement me") +} + +func (m *MockChannelState) Stages() 
*datatransfer.ChannelStages { + panic("implement me") +} + +func (m *MockChannelState) SetDataLimit(dataLimit uint64) { + m.dataLimit = dataLimit +} + +func (m *MockChannelState) DataLimit() uint64 { + return m.dataLimit +} + +func (m *MockChannelState) RequiresFinalization() bool { + panic("implement me") +} + +func (m *MockChannelState) SetResponderPaused(responderPaused bool) { + m.responderPaused = responderPaused +} + +func (m *MockChannelState) ResponderPaused() bool { + return m.responderPaused +} + +func (m *MockChannelState) SetInitiatorPaused(initiatorPaused bool) { + m.initiatorPaused = initiatorPaused +} + +func (m *MockChannelState) InitiatorPaused() bool { + return m.initiatorPaused +} + +func (m *MockChannelState) BothPaused() bool { + return m.initiatorPaused && m.responderPaused +} + +func (m *MockChannelState) SelfPaused() bool { + if m.self == m.chid.Initiator { + return m.initiatorPaused + } + return m.responderPaused +} diff --git a/testutil/stubbedvalidator.go b/testutil/stubbedvalidator.go index 8cef087e..1647bdc0 100644 --- a/testutil/stubbedvalidator.go +++ b/testutil/stubbedvalidator.go @@ -5,11 +5,11 @@ import ( "testing" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) // NewStubbedValidator returns a new instance of a stubbed validator @@ -19,12 +19,11 @@ func NewStubbedValidator() *StubbedValidator { // ValidatePush returns a stubbed result for a push validation func (sv *StubbedValidator) ValidatePush( - isRestart bool, chid datatransfer.ChannelID, sender peer.ID, - voucher datatransfer.Voucher, + voucher datamodel.Node, baseCid cid.Cid, - selector ipld.Node) (datatransfer.VoucherResult, error) { + selector datamodel.Node) (datatransfer.ValidationResult, error) { sv.didPush = true sv.ValidationsReceived = append(sv.ValidationsReceived, ReceivedValidation{false, sender, voucher, baseCid, selector}) return sv.result, sv.pushError @@ -32,19 +31,18 @@ func (sv *StubbedValidator) ValidatePush( // ValidatePull returns a stubbed result for a pull validation func (sv *StubbedValidator) ValidatePull( - isRestart bool, chid datatransfer.ChannelID, receiver peer.ID, - voucher datatransfer.Voucher, + voucher datamodel.Node, baseCid cid.Cid, - selector ipld.Node) (datatransfer.VoucherResult, error) { + selector datamodel.Node) (datatransfer.ValidationResult, error) { sv.didPull = true sv.ValidationsReceived = append(sv.ValidationsReceived, ReceivedValidation{true, receiver, voucher, baseCid, selector}) return sv.result, sv.pullError } // StubResult returns thes given voucher result when a validate call is made -func (sv *StubbedValidator) StubResult(voucherResult datatransfer.VoucherResult) { +func (sv *StubbedValidator) StubResult(voucherResult datatransfer.ValidationResult) { sv.result = voucherResult } @@ -58,11 +56,6 @@ func (sv *StubbedValidator) StubSuccessPush() { sv.pushError = nil } -// StubPausePush sets ValidatePush to pause -func (sv *StubbedValidator) StubPausePush() { - sv.pushError = datatransfer.ErrPause -} - // ExpectErrorPush expects ValidatePush to error func (sv *StubbedValidator) ExpectErrorPush() { sv.expectPush = true @@ -75,12 +68,6 @@ func (sv *StubbedValidator) ExpectSuccessPush() { sv.StubSuccessPush() } -// ExpectPausePush expects ValidatePush to pause -func (sv *StubbedValidator) 
ExpectPausePush() { - sv.expectPush = true - sv.StubPausePush() -} - // StubErrorPull sets ValidatePull to error func (sv *StubbedValidator) StubErrorPull() { sv.pullError = errors.New("something went wrong") @@ -91,11 +78,6 @@ func (sv *StubbedValidator) StubSuccessPull() { sv.pullError = nil } -// StubPausePull sets ValidatePull to pause -func (sv *StubbedValidator) StubPausePull() { - sv.pullError = datatransfer.ErrPause -} - // ExpectErrorPull expects ValidatePull to error func (sv *StubbedValidator) ExpectErrorPull() { sv.expectPull = true @@ -108,12 +90,6 @@ func (sv *StubbedValidator) ExpectSuccessPull() { sv.StubSuccessPull() } -// ExpectPausePull expects ValidatePull to pause -func (sv *StubbedValidator) ExpectPausePull() { - sv.expectPull = true - sv.StubPausePull() -} - // VerifyExpectations verifies the specified calls were made func (sv *StubbedValidator) VerifyExpectations(t *testing.T) { if sv.expectPush { @@ -122,238 +98,74 @@ func (sv *StubbedValidator) VerifyExpectations(t *testing.T) { if sv.expectPull { require.True(t, sv.didPull) } + if sv.expectRevalidate { + require.True(t, sv.didRevalidate) + } } -// ReceivedValidation records a call to either ValidatePush or ValidatePull -type ReceivedValidation struct { - IsPull bool - Other peer.ID - Voucher datatransfer.Voucher - BaseCid cid.Cid - Selector ipld.Node -} - -// StubbedValidator is a validator that returns predictable results -type StubbedValidator struct { - result datatransfer.VoucherResult - didPush bool - didPull bool - expectPush bool - expectPull bool - pushError error - pullError error - ValidationsReceived []ReceivedValidation -} - -// StubbedRevalidator is a revalidator that returns predictable results -type StubbedRevalidator struct { - revalidationResult datatransfer.VoucherResult - checkResult datatransfer.VoucherResult - didRevalidate bool - didPushCheck bool - didPullCheck bool - didComplete bool - expectRevalidate bool - expectPushCheck bool - expectPullCheck bool - expectComplete bool - revalidationError error - pushCheckError error - pullCheckError error - completeError error -} - -// NewStubbedRevalidator returns a new instance of a stubbed revalidator -func NewStubbedRevalidator() *StubbedRevalidator { - return &StubbedRevalidator{} -} - -// OnPullDataSent returns a stubbed result for checking when pull data is sent -func (srv *StubbedRevalidator) OnPullDataSent(chid datatransfer.ChannelID, additionalBytesSent uint64) (bool, datatransfer.VoucherResult, error) { - srv.didPullCheck = true - return srv.expectPullCheck, srv.revalidationResult, srv.pullCheckError -} - -// OnPushDataReceived returns a stubbed result for checking when push data is received -func (srv *StubbedRevalidator) OnPushDataReceived(chid datatransfer.ChannelID, additionalBytesReceived uint64) (bool, datatransfer.VoucherResult, error) { - srv.didPushCheck = true - return srv.expectPushCheck, srv.revalidationResult, srv.pushCheckError -} - -// OnComplete returns a stubbed result for checking when the requests completes -func (srv *StubbedRevalidator) OnComplete(chid datatransfer.ChannelID) (bool, datatransfer.VoucherResult, error) { - srv.didComplete = true - return srv.expectComplete, srv.revalidationResult, srv.completeError -} - -// Revalidate returns a stubbed result for revalidating a request -func (srv *StubbedRevalidator) Revalidate(chid datatransfer.ChannelID, voucher datatransfer.Voucher) (datatransfer.VoucherResult, error) { - srv.didRevalidate = true - return srv.checkResult, srv.revalidationError -} - -// 
StubRevalidationResult returns the given voucher result when a call is made to Revalidate -func (srv *StubbedRevalidator) StubRevalidationResult(voucherResult datatransfer.VoucherResult) { - srv.revalidationResult = voucherResult -} - -// StubCheckResult returns the given voucher result when a call is made to -// OnPullDataSent, OnPushDataReceived, or OnComplete -func (srv *StubbedRevalidator) StubCheckResult(voucherResult datatransfer.VoucherResult) { - srv.checkResult = voucherResult -} - -// StubErrorPushCheck sets OnPushDataReceived to error -func (srv *StubbedRevalidator) StubErrorPushCheck() { - srv.pushCheckError = errors.New("something went wrong") -} - -// StubSuccessPushCheck sets OnPushDataReceived to succeed -func (srv *StubbedRevalidator) StubSuccessPushCheck() { - srv.pushCheckError = nil -} - -// StubPausePushCheck sets OnPushDataReceived to pause -func (srv *StubbedRevalidator) StubPausePushCheck() { - srv.pushCheckError = datatransfer.ErrPause -} - -// ExpectErrorPushCheck expects OnPushDataReceived to error -func (srv *StubbedRevalidator) ExpectErrorPushCheck() { - srv.expectPushCheck = true - srv.StubErrorPushCheck() -} - -// ExpectSuccessPushCheck expects OnPushDataReceived to succeed -func (srv *StubbedRevalidator) ExpectSuccessPushCheck() { - srv.expectPushCheck = true - srv.StubSuccessPushCheck() -} - -// ExpectPausePushCheck expects OnPushDataReceived to pause -func (srv *StubbedRevalidator) ExpectPausePushCheck() { - srv.expectPushCheck = true - srv.StubPausePushCheck() -} - -// StubErrorPullCheck sets OnPullDataSent to error -func (srv *StubbedRevalidator) StubErrorPullCheck() { - srv.pullCheckError = errors.New("something went wrong") -} - -// StubSuccessPullCheck sets OnPullDataSent to succeed -func (srv *StubbedRevalidator) StubSuccessPullCheck() { - srv.pullCheckError = nil -} - -// StubPausePullCheck sets OnPullDataSent to pause -func (srv *StubbedRevalidator) StubPausePullCheck() { - srv.pullCheckError = datatransfer.ErrPause -} - -// ExpectErrorPullCheck expects OnPullDataSent to error -func (srv *StubbedRevalidator) ExpectErrorPullCheck() { - srv.expectPullCheck = true - srv.StubErrorPullCheck() -} - -// ExpectSuccessPullCheck expects OnPullDataSent to succeed -func (srv *StubbedRevalidator) ExpectSuccessPullCheck() { - srv.expectPullCheck = true - srv.StubSuccessPullCheck() -} - -// ExpectPausePullCheck expects OnPullDataSent to pause -func (srv *StubbedRevalidator) ExpectPausePullCheck() { - srv.expectPullCheck = true - srv.StubPausePullCheck() -} - -// StubErrorComplete sets OnComplete to error -func (srv *StubbedRevalidator) StubErrorComplete() { - srv.completeError = errors.New("something went wrong") -} - -// StubSuccessComplete sets OnComplete to succeed -func (srv *StubbedRevalidator) StubSuccessComplete() { - srv.completeError = nil -} - -// StubPauseComplete sets OnComplete to pause -func (srv *StubbedRevalidator) StubPauseComplete() { - srv.completeError = datatransfer.ErrPause -} - -// ExpectErrorComplete expects OnComplete to error -func (srv *StubbedRevalidator) ExpectErrorComplete() { - srv.expectComplete = true - srv.StubErrorComplete() -} - -// ExpectSuccessComplete expects OnComplete to succeed -func (srv *StubbedRevalidator) ExpectSuccessComplete() { - srv.expectComplete = true - srv.StubSuccessComplete() -} - -// ExpectPauseComplete expects OnComplete to pause -func (srv *StubbedRevalidator) ExpectPauseComplete() { - srv.expectComplete = true - srv.StubPauseComplete() +func (sv *StubbedValidator) ValidateRestart(chid 
datatransfer.ChannelID, channelState datatransfer.ChannelState) (datatransfer.ValidationResult, error) { + sv.didRevalidate = true + sv.RevalidationsReceived = append(sv.RevalidationsReceived, ReceivedRestartValidation{chid, channelState}) + return sv.revalidationResult, sv.revalidationError } -// StubErrorRevalidation sets Revalidate to error -func (srv *StubbedRevalidator) StubErrorRevalidation() { - srv.revalidationError = errors.New("something went wrong") +// StubRestartResult returns the given voucher result when a call is made to ValidateRestart +func (sv *StubbedValidator) StubRestartResult(voucherResult datatransfer.ValidationResult) { + sv.revalidationResult = voucherResult } -// StubSuccessRevalidation sets Revalidate to succeed -func (srv *StubbedRevalidator) StubSuccessRevalidation() { - srv.revalidationError = nil +// StubErrorValidateRestart sets ValidateRestart to error +func (sv *StubbedValidator) StubErrorValidateRestart() { + sv.revalidationError = errors.New("something went wrong") } -// StubPauseRevalidation sets Revalidate to pause -func (srv *StubbedRevalidator) StubPauseRevalidation() { - srv.revalidationError = datatransfer.ErrPause +// StubSuccessValidateRestart sets ValidateRestart to succeed +func (sv *StubbedValidator) StubSuccessValidateRestart() { + sv.revalidationError = nil } -// ExpectSuccessErrResume configures Revalidate to return an ErrResume -// and expect a Revalidate call. -func (srv *StubbedRevalidator) ExpectSuccessErrResume() { - srv.expectRevalidate = true - srv.revalidationError = datatransfer.ErrResume +// ExpectErrorValidateRestart expects ValidateRestart to error +func (sv *StubbedValidator) ExpectErrorValidateRestart() { + sv.expectRevalidate = true + sv.StubErrorValidateRestart() } -// ExpectErrorRevalidation expects Revalidate to error -func (srv *StubbedRevalidator) ExpectErrorRevalidation() { - srv.expectRevalidate = true - srv.StubErrorRevalidation() +// ExpectSuccessValidateRestart expects ValidateRestart to succeed +func (sv *StubbedValidator) ExpectSuccessValidateRestart() { + sv.expectRevalidate = true + sv.StubSuccessValidateRestart() } -// ExpectSuccessRevalidation expects Revalidate to succeed -func (srv *StubbedRevalidator) ExpectSuccessRevalidation() { - srv.expectRevalidate = true - srv.StubSuccessRevalidation() +// ReceivedValidation records a call to either ValidatePush or ValidatePull +type ReceivedValidation struct { + IsPull bool + Other peer.ID + Voucher datamodel.Node + BaseCid cid.Cid + Selector datamodel.Node } -// ExpectPauseRevalidation expects Revalidate to pause -func (srv *StubbedRevalidator) ExpectPauseRevalidation() { - srv.expectRevalidate = true - srv.StubPauseRevalidation() +// ReceivedRestartValidation records a call to ValidateRestart +type ReceivedRestartValidation struct { + ChannelID datatransfer.ChannelID + ChannelState datatransfer.ChannelState } -// VerifyExpectations verifies the specified calls were made -func (srv *StubbedRevalidator) VerifyExpectations(t *testing.T) { - if srv.expectRevalidate { - require.True(t, srv.didRevalidate) - } - if srv.expectPushCheck { - require.True(t, srv.didPushCheck) - } - if srv.expectPullCheck { - require.True(t, srv.didPullCheck) - } - if srv.expectComplete { - require.True(t, srv.didComplete) - } -} +// StubbedValidator is a validator that returns predictable results +type StubbedValidator struct { + result datatransfer.ValidationResult + revalidationResult datatransfer.ValidationResult + expectRevalidate bool + didRevalidate bool + didPush bool + didPull bool + 
expectPush bool + expectPull bool + pushError error + pullError error + revalidationError error + ValidationsReceived []ReceivedValidation + RevalidationsReceived []ReceivedRestartValidation +} + +var _ datatransfer.RequestValidator = (*StubbedValidator)(nil) diff --git a/testutil/testnet.go b/testutil/testnet.go deleted file mode 100644 index 9b99c9b6..00000000 --- a/testutil/testnet.go +++ /dev/null @@ -1,71 +0,0 @@ -package testutil - -import ( - "context" - - "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" - - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/network" -) - -// FakeSentMessage is a recording of a message sent on the FakeNetwork -type FakeSentMessage struct { - PeerID peer.ID - Message datatransfer.Message -} - -// FakeNetwork is a network that satisfies the DataTransferNetwork interface but -// does not actually do anything -type FakeNetwork struct { - PeerID peer.ID - SentMessages []FakeSentMessage - Delegate network.Receiver -} - -// NewFakeNetwork returns a new fake data transfer network instance -func NewFakeNetwork(id peer.ID) *FakeNetwork { - return &FakeNetwork{PeerID: id} -} - -var _ network.DataTransferNetwork = (*FakeNetwork)(nil) - -// SendMessage sends a GraphSync message to a peer. -func (fn *FakeNetwork) SendMessage(ctx context.Context, p peer.ID, m datatransfer.Message) error { - fn.SentMessages = append(fn.SentMessages, FakeSentMessage{p, m}) - return nil -} - -// SetDelegate registers the Reciver to handle messages received from the -// network. -func (fn *FakeNetwork) SetDelegate(receiver network.Receiver) { - fn.Delegate = receiver -} - -// ConnectTo establishes a connection to the given peer -func (fn *FakeNetwork) ConnectTo(_ context.Context, _ peer.ID) error { - panic("not implemented") -} - -func (fn *FakeNetwork) ConnectWithRetry(ctx context.Context, p peer.ID) error { - panic("implement me") -} - -// ID returns a stubbed id for host of this network -func (fn *FakeNetwork) ID() peer.ID { - return fn.PeerID -} - -// Protect does nothing on the fake network -func (fn *FakeNetwork) Protect(id peer.ID, tag string) { -} - -// Unprotect does nothing on the fake network -func (fn *FakeNetwork) Unprotect(id peer.ID, tag string) bool { - return false -} - -func (fn *FakeNetwork) Protocol(ctx context.Context, id peer.ID) (protocol.ID, error) { - return datatransfer.ProtocolDataTransfer1_2, nil -} diff --git a/testutil/testutil.go b/testutil/testutil.go index 731ee3f8..0c143969 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -4,20 +4,17 @@ import ( "bytes" "context" "fmt" + "math/rand" "testing" blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-cid" blocksutil "github.com/ipfs/go-ipfs-blocksutil" - "github.com/ipld/go-ipld-prime" - basicnode "github.com/ipld/go-ipld-prime/node/basic" - "github.com/ipld/go-ipld-prime/traversal/selector" - "github.com/ipld/go-ipld-prime/traversal/selector/builder" "github.com/jbenet/go-random" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) var blockGenerator = blocksutil.NewBlockGenerator() @@ -101,13 +98,6 @@ func AssertEqualSelector(t *testing.T, expectedRequest datatransfer.Request, req require.Equal(t, expectedSelector, selector) } -// AllSelector just returns a new instance of a "whole dag selector" -func AllSelector() ipld.Node { - ssb := 
builder.NewSelectorSpecBuilder(basicnode.Prototype.Any) - return ssb.ExploreRecursive(selector.RecursionLimitNone(), - ssb.ExploreAll(ssb.ExploreRecursiveEdge())).Node() -} - // StartAndWaitForReady is a utility function to start a module and verify it reaches the ready state func StartAndWaitForReady(ctx context.Context, t *testing.T, manager datatransfer.Manager) { ready := make(chan error, 1) @@ -122,3 +112,9 @@ func StartAndWaitForReady(ctx context.Context, t *testing.T, manager datatransfe require.NoError(t, err) } } + +// GenerateChannelID generates a new data transfer channel id for use in tests +func GenerateChannelID() datatransfer.ChannelID { + p := GeneratePeers(2) + return datatransfer.ChannelID{Initiator: p[0], Responder: p[1], ID: datatransfer.TransferID(rand.Int31())} +} diff --git a/tracing/tracing.go b/tracing/tracing.go index 2ba36800..204793e6 100644 --- a/tracing/tracing.go +++ b/tracing/tracing.go @@ -8,7 +8,7 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) type SpansIndex struct { diff --git a/tracing/tracing_test.go b/tracing/tracing_test.go index f8f8c00f..c342e24d 100644 --- a/tracing/tracing_test.go +++ b/tracing/tracing_test.go @@ -7,9 +7,9 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/rand" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/testutil" - "github.com/filecoin-project/go-data-transfer/tracing" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/testutil" + "github.com/filecoin-project/go-data-transfer/v2/tracing" ) func TestSpansIndex(t *testing.T) { diff --git a/transport.go b/transport.go index 6d0b99fa..91aebdfc 100644 --- a/transport.go +++ b/transport.go @@ -3,72 +3,113 @@ package datatransfer import ( "context" - ipld "github.com/ipld/go-ipld-prime" - peer "github.com/libp2p/go-libp2p-core/peer" + "github.com/ipld/go-ipld-prime/datamodel" ) -// EventsHandler are semantic data transfer events that happen as a result of graphsync hooks -type EventsHandler interface { - // OnChannelOpened is called when we send a request for data to the other - // peer on the given channel ID - // return values are: - // - error = ignore incoming data for this channel - OnChannelOpened(chid ChannelID) error - // OnResponseReceived is called when we receive a response to a request - // - nil = continue receiving data - // - error = cancel this request - OnResponseReceived(chid ChannelID, msg Response) error - // OnDataReceive is called when we receive data for the given channel ID - // return values are: - // - nil = proceed with sending data - // - error = cancel this request - // - err == ErrPause - pause this request - OnDataReceived(chid ChannelID, link ipld.Link, size uint64, index int64, unique bool) error - - // OnDataQueued is called when data is queued for sending for the given channel ID - // return values are: - // message = data transfer message along with data - // err = error - // - nil = proceed with sending data - // - error = cancel this request - // - err == ErrPause - pause this request - OnDataQueued(chid ChannelID, link ipld.Link, size uint64, index int64, unique bool) (Message, error) - - // OnDataSent is called when we send data for the given channel ID - OnDataSent(chid ChannelID, link ipld.Link, size uint64, index int64, unique bool) 
error - - // OnTransferQueued is called when a new data transfer request is queued in the transport layer. - OnTransferQueued(chid ChannelID) - - // OnRequestReceived is called when we receive a new request to send data - // for the given channel ID - // return values are: - // message = data transfer message along with reply - // err = error - // - nil = proceed with sending data - // - error = cancel this request - // - err == ErrPause - pause this request (only for new requests) - // - err == ErrResume - resume this request (only for update requests) - OnRequestReceived(chid ChannelID, msg Request) (Response, error) - // OnChannelCompleted is called when we finish transferring data for the given channel ID - // Error returns are logged but otherwise have no effect - OnChannelCompleted(chid ChannelID, err error) error +// TransportID identifies a unique transport +type TransportID string + +// LegacyTransportID is the only transport for the fil/data-transfer protocol -- +// i.e. graphsync +const LegacyTransportID TransportID = "graphsync" + +// LegacyTransportVersion is the only transport version for the fil/data-transfer protocol -- +// i.e. graphsync 1.0.0 +var LegacyTransportVersion Version = Version{1, 0, 0} - // OnRequestCancelled is called when a request we opened (with the given channel Id) to - // receive data is cancelled by us. - // Error returns are logged but otherwise have no effect - OnRequestCancelled(chid ChannelID, err error) error +type TransportEvent interface { + transportEvent() +} - // OnRequestDisconnected is called when a network error occurs trying to send a request - OnRequestDisconnected(chid ChannelID, err error) error +// TransportOpenedChannel occurs when the transport begins processing the +// request (prior to that it may simply be queued) -- only applies to initiator +type TransportOpenedChannel struct{} - // OnSendDataError is called when a network error occurs sending data - // at the transport layer - OnSendDataError(chid ChannelID, err error) error +// TransportInitiatedTransfer occurs when the transport actually begins sending/receiving data +type TransportInitiatedTransfer struct{} - // OnReceiveDataError is called when a network error occurs receiving data - // at the transport layer - OnReceiveDataError(chid ChannelID, err error) error +// TransportReceivedData occurs when we receive data for the given channel ID +// index is a transport dependent of serializing "here's where I am in this transport" +type TransportReceivedData struct { + Size uint64 + Index datamodel.Node +} + +// TransportSentData occurs when we send data for the given channel ID +// index is a transport dependent of serializing "here's where I am in this transport" +type TransportSentData struct { + Size uint64 + Index datamodel.Node +} + +// TransportQueuedData occurs when data is queued for sending for the given channel ID +// index is a transport dependent of serializing "here's where I am in this transport" +type TransportQueuedData struct { + Size uint64 + Index datamodel.Node +} + +// TransportReachedDataLimit occurs when a channel hits a previously set data limit +type TransportReachedDataLimit struct{} + +// TransportTransferCancelled occurs when a request we opened (with the given channel Id) to +// receive data is cancelled by us. 
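+// The ErrorMessage field carries the transport's description of why the transfer was cancelled (the graphsync transport in this changeset reports "graphsync request cancelled").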
+type TransportTransferCancelled struct { + ErrorMessage string +} + +// TransportErrorSendingData occurs when a network error occurs trying to send a request +type TransportErrorSendingData struct { + ErrorMessage string +} + +// TransportErrorReceivingData occurs when a network error occurs receiving data +// at the transport layer +type TransportErrorReceivingData struct { + ErrorMessage string +} + +// TransportCompletedTransfer occurs when we finish transferring data for the given channel ID +type TransportCompletedTransfer struct { + Success bool + ErrorMessage string +} + +type TransportReceivedRestartExistingChannelRequest struct{} + +// TransportErrorSendingMessage occurs when a network error occurs trying to send a request +type TransportErrorSendingMessage struct { + ErrorMessage string +} + +type TransportPaused struct{} + +type TransportResumed struct{} + +// EventsHandler are semantic data transfer events that happen as a result of transport events +type EventsHandler interface { + // ChannelState queries for the current channel state + ChannelState(ctx context.Context, chid ChannelID) (ChannelState, error) + + // OnTransportEvent is dispatched when an event occurs on the transport + // It MAY be dispatched asynchronously by the transport to the time the + // event occurs + // However, the other handler functions may ONLY be called on the same channel + // after all events are dispatched. In other words, the transport MUST allow + // the handler to process all events before calling the other functions which + // have a synchronous return + OnTransportEvent(chid ChannelID, event TransportEvent) + + // OnRequestReceived occurs when we receive a request for the given channel ID + // return values are a message to send an error if the transport should be closed + // TODO: in a future improvement, a received request should become a + // just TransportEvent, and should be handled asynchronously + OnRequestReceived(chid ChannelID, msg Request) (Response, error) + + // OnRequestReceived occurs when we receive a response to a request + // TODO: in a future improvement, a received response should become a + // just TransportEvent, and should be handled asynchronously + OnResponseReceived(chid ChannelID, msg Response) error // OnContextAugment allows the transport to attach data transfer tracing information // to its local context, in order to create a hierarchical trace @@ -78,53 +119,75 @@ type EventsHandler interface { /* Transport defines the interface for a transport layer for data transfer. Where the data transfer manager will coordinate setting up push and -pull requests, validation, etc, the transport layer is responsible for moving +pull requests, persistence, validation, etc, the transport layer is responsible for moving data back and forth, and may be medium specific. For example, some transports may have the ability to pause and resume requests, while others may not. -Some may support individual data events, while others may only support message +Some may dispatch data update events, while others may only support message events. Some transport layers may opt to use the actual data transfer network protocols directly while others may be able to encode messages in their own data protocol. Transport is the minimum interface that must be satisfied to serve as a datatransfer -transport layer. Transports must be able to open (open is always called by the receiving peer) -and close channels, and set at an event handler */ +transport layer. 
Transports must be able to open and close channels, set at an event handler, +and send messages. Beyond that, additional commands may or may not be supported. +Whether a command is supported can be determined ahead by calling Capabilities(). +*/ type Transport interface { - // OpenChannel initiates an outgoing request for the other peer to send data - // to us on this channel - // Note: from a data transfer symantic standpoint, it doesn't matter if the - // request is push or pull -- OpenChannel is called by the party that is - // intending to receive data + // ID is a unique identifier for this transport + ID() TransportID + + // Versions indicates what versions of this transport are supported + Versions() []Version + + // Capabilities tells datatransfer what kinds of capabilities this transport supports + Capabilities() TransportCapabilities + // OpenChannel opens a channel on a given transport to move data back and forth. + // OpenChannel MUST ALWAYS called by the initiator. OpenChannel( ctx context.Context, - dataSender peer.ID, - channelID ChannelID, - root ipld.Link, - stor ipld.Node, - channel ChannelState, - msg Message, + channel Channel, + req Request, ) error - // CloseChannel closes the given channel - CloseChannel(ctx context.Context, chid ChannelID) error + // ChannelUpdated notifies the transport that state of the channel has been updated, + // along with an optional message to send over the transport to tell + // the other peer about the update + ChannelUpdated(ctx context.Context, chid ChannelID, message Message) error // SetEventHandler sets the handler for events on channels SetEventHandler(events EventsHandler) error - // CleanupChannel is called on the otherside of a cancel - removes any associated - // data for the channel + // CleanupChannel removes any associated data on a closed channel CleanupChannel(chid ChannelID) + // SendMessage sends a data transfer message over the channel to the other peer + SendMessage(ctx context.Context, chid ChannelID, msg Message) error + // Shutdown unregisters the current EventHandler and ends all active data transfers Shutdown(ctx context.Context) error + + // Optional Methods: Some channels may not support these + + // Restart restarts a channel on the initiator side + // RestartChannel MUST ALWAYS called by the initiator + RestartChannel(ctx context.Context, channel ChannelState, req Request) error } -// PauseableTransport is a transport that can also pause and resume channels -type PauseableTransport interface { - Transport - // PauseChannel paused the given channel ID - PauseChannel(ctx context.Context, - chid ChannelID, - ) error - // ResumeChannel resumes the given channel - ResumeChannel(ctx context.Context, - msg Message, - chid ChannelID, - ) error +// TransportCapabilities describes additional capabilities supported by ChannelActions +type TransportCapabilities struct { + // Restarable indicates ChannelActions will support RestartActions + Restartable bool + // Pausable indicates ChannelActions will support PauseActions + Pausable bool } + +func (TransportOpenedChannel) transportEvent() {} +func (TransportInitiatedTransfer) transportEvent() {} +func (TransportReceivedData) transportEvent() {} +func (TransportSentData) transportEvent() {} +func (TransportQueuedData) transportEvent() {} +func (TransportReachedDataLimit) transportEvent() {} +func (TransportTransferCancelled) transportEvent() {} +func (TransportErrorSendingData) transportEvent() {} +func (TransportErrorReceivingData) transportEvent() {} +func 
(TransportCompletedTransfer) transportEvent() {} +func (TransportReceivedRestartExistingChannelRequest) transportEvent() {} +func (TransportErrorSendingMessage) transportEvent() {} +func (TransportPaused) transportEvent() {} +func (TransportResumed) transportEvent() {} diff --git a/transport/graphsync/dtchannel/dtchannel.go b/transport/graphsync/dtchannel/dtchannel.go new file mode 100644 index 00000000..81a9944f --- /dev/null +++ b/transport/graphsync/dtchannel/dtchannel.go @@ -0,0 +1,535 @@ +package dtchannel + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/ipfs/go-graphsync" + "github.com/ipfs/go-graphsync/donotsendfirstblocks" + logging "github.com/ipfs/go-log/v2" + ipld "github.com/ipld/go-ipld-prime" + peer "github.com/libp2p/go-libp2p-core/peer" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/executor" +) + +const maxGSCancelWait = time.Second + +var log = logging.Logger("dt_graphsync") + +// state is the state graphsync data transfer channel +type state uint64 + +// +const ( + channelClosed state = iota + channelOpen + channelPaused +) + +// Info needed to keep track of a data transfer channel +type Channel struct { + isSender bool + channelID datatransfer.ChannelID + gs graphsync.GraphExchange + + lk sync.RWMutex + state state + requestID *graphsync.RequestID + completed chan struct{} + requesterCancelled bool + pendingExtensions []graphsync.ExtensionData + + storeLk sync.RWMutex + storeRegistered bool + + receivedIndex int64 + sentIndex int64 + queuedIndex int64 + dataLimit uint64 + progress uint64 +} + +func NewChannel(channelID datatransfer.ChannelID, gs graphsync.GraphExchange) *Channel { + return &Channel{ + channelID: channelID, + gs: gs, + } +} + +// Open a graphsync request for data to the remote peer +func (c *Channel) Open( + ctx context.Context, + requestID graphsync.RequestID, + dataSender peer.ID, + root ipld.Link, + stor ipld.Node, + exts []graphsync.ExtensionData, +) (*executor.Executor, error) { + c.lk.Lock() + defer c.lk.Unlock() + + // If there is an existing graphsync request for this channelID + if c.requestID != nil { + // Cancel the existing graphsync request + completed := c.completed + errch := c.cancel(ctx) + + // Wait for the complete callback to be called + c.lk.Unlock() + err := waitForCompleteHook(ctx, completed) + c.lk.Lock() + if err != nil { + return nil, fmt.Errorf("%s: waiting for cancelled graphsync request to complete: %w", c.channelID, err) + } + + // Wait for the cancel request method to complete + select { + case err = <-errch: + case <-ctx.Done(): + err = fmt.Errorf("timed out waiting for graphsync request to be cancelled") + } + if err != nil { + return nil, fmt.Errorf("%s: restarting graphsync request: %w", c.channelID, err) + } + } + + // add do not send cids ext as needed + if c.receivedIndex > 0 { + data := donotsendfirstblocks.EncodeDoNotSendFirstBlocks(c.receivedIndex) + exts = append(exts, graphsync.ExtensionData{ + Name: graphsync.ExtensionsDoNotSendFirstBlocks, + Data: data, + }) + } + + // Set up a completed channel that will be closed when the request + // completes (or is cancelled) + completed := make(chan struct{}) + var onCompleteOnce sync.Once + onComplete := func() { + // Ensure the channel is only closed once + onCompleteOnce.Do(func() { + c.MarkTransferComplete() + log.Infow("closing the completion ch for data-transfer channel", "chid", c.channelID) + close(completed) + }) + } + c.completed = completed + + 
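+ // completed is closed exactly once by onComplete, which the executor invokes after the graphsync response and error channels drain; a later restart waits on it via waitForCompleteHook before issuing a new request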
// Open a new graphsync request + msg := fmt.Sprintf("Opening graphsync request to %s for root %s", dataSender, root) + if c.receivedIndex > 0 { + msg += fmt.Sprintf(" with %d Blocks already received", c.receivedIndex) + } + log.Info(msg) + c.requestID = &requestID + ctx = context.WithValue(ctx, graphsync.RequestIDContextKey{}, *c.requestID) + responseChan, errChan := c.gs.Request(ctx, dataSender, root, stor, exts...) + c.state = channelOpen + // Save a mapping from the graphsync key to the channel ID so that + // subsequent graphsync callbacks are associated with this channel + + e := executor.NewExecutor(c.channelID, responseChan, errChan, onComplete) + return e, nil +} + +func waitForCompleteHook(ctx context.Context, completed chan struct{}) error { + // Wait for the cancel to propagate through to graphsync, and for + // the graphsync request to complete + select { + case <-completed: + return nil + case <-time.After(maxGSCancelWait): + // Fail-safe: give up waiting after a certain amount of time + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// gsReqOpened is called when graphsync makes a request to the remote peer to ask for data +func (c *Channel) GsReqOpened(sender peer.ID, requestID graphsync.RequestID, hookActions graphsync.OutgoingRequestHookActions) { + // Tell graphsync to store the received blocks in the registered store + if c.hasStore() { + hookActions.UsePersistenceOption("data-transfer-" + c.channelID.String()) + } + log.Infow("outgoing graphsync request", "peer", sender, "graphsync request id", requestID, "data transfer channel id", c.channelID) +} + +// gsDataRequestRcvd is called when the transport receives an incoming request +// for data. +func (c *Channel) GsDataRequestRcvd(sender peer.ID, requestID graphsync.RequestID, chst datatransfer.ChannelState, hookActions graphsync.IncomingRequestHookActions) { + c.lk.Lock() + defer c.lk.Unlock() + log.Debugf("%s: received request for data, req_id=%d", c.channelID, requestID) + // If the requester had previously cancelled their request, send any + // message that was queued since the cancel + if c.requesterCancelled { + c.requesterCancelled = false + + extensions := c.pendingExtensions + c.pendingExtensions = nil + for _, ext := range extensions { + hookActions.SendExtensionData(ext) + } + } + + // Tell graphsync to load blocks from the registered store + if c.hasStore() { + hookActions.UsePersistenceOption("data-transfer-" + c.channelID.String()) + } + + // Save a mapping from the graphsync key to the channel ID so that + // subsequent graphsync callbacks are associated with this channel + c.requestID = &requestID + log.Infow("incoming graphsync request", "peer", sender, "graphsync request id", requestID, "data transfer channel id", c.channelID) + + c.state = channelOpen + + err := c.updateFromChannelState(chst) + if err != nil { + hookActions.TerminateWithError(err) + return + } + + action := c.actionFromChannelState(chst) + switch action { + case Pause: + c.state = channelPaused + hookActions.PauseResponse() + case Close: + c.state = channelClosed + hookActions.TerminateWithError(datatransfer.ErrRejected) + return + default: + } + hookActions.ValidateRequest() +} + +func (c *Channel) MarkPaused() { + c.lk.Lock() + defer c.lk.Unlock() + c.state = channelPaused +} + +func (c *Channel) Paused() bool { + c.lk.RLock() + defer c.lk.RUnlock() + return c.state == channelPaused +} + +func (c *Channel) Pause(ctx context.Context) error { + c.lk.Lock() + defer c.lk.Unlock() + + // Check if the channel was already 
cancelled + if c.requestID == nil { + log.Debugf("%s: channel was cancelled so not pausing channel", c.channelID) + return nil + } + + if c.state != channelOpen { + log.Debugf("%s: channel is not open so not pausing channel", c.channelID) + return nil + } + + c.state = channelPaused + + // If the requester cancelled, bail out + if c.requesterCancelled { + log.Debugf("%s: requester has cancelled so not pausing response for now", c.channelID) + return nil + } + + // Pause the response + log.Debugf("%s: pausing response", c.channelID) + return c.gs.Pause(ctx, *c.requestID) +} + +func (c *Channel) Resume(ctx context.Context, extensions []graphsync.ExtensionData) error { + c.lk.Lock() + defer c.lk.Unlock() + + // Check if the channel was already cancelled + if c.requestID == nil { + log.Debugf("%s: channel was cancelled so not resuming channel", c.channelID) + return nil + } + if c.state != channelPaused { + log.Debugf("%s: channel is not paused so not resuming channel", c.channelID) + return nil + } + + c.state = channelOpen + + // If the requester cancelled, bail out + if c.requesterCancelled { + // If there was an associated message, we still want to send it to the + // remote peer. We're not sending any message now, so instead queue up + // the message to be sent next time the peer makes a request to us. + c.pendingExtensions = append(c.pendingExtensions, extensions...) + + log.Debugf("%s: requester has cancelled so not unpausing for now", c.channelID) + return nil + } + + log.Debugf("%s: unpausing response", c.channelID) + return c.gs.Unpause(ctx, *c.requestID, extensions...) +} + +type Action string + +const ( + NoAction Action = "" + Close Action = "close" + Pause Action = "pause" + Resume Action = "resume" +) + +// UpdateFromChannelState updates internal graphsync channel state form a datatransfer +// channel state +func (c *Channel) UpdateFromChannelState(chst datatransfer.ChannelState) error { + c.lk.Lock() + defer c.lk.Unlock() + return c.updateFromChannelState(chst) +} + +func (c *Channel) updateFromChannelState(chst datatransfer.ChannelState) error { + // read the sent value + sentNode := chst.SentIndex() + if !sentNode.IsNull() { + sentIndex, err := sentNode.AsInt() + if err != nil { + return err + } + if sentIndex > c.sentIndex { + c.sentIndex = sentIndex + } + } + + // read the received + receivedNode := chst.ReceivedIndex() + if !receivedNode.IsNull() { + receivedIndex, err := receivedNode.AsInt() + if err != nil { + return err + } + if receivedIndex > c.receivedIndex { + c.receivedIndex = receivedIndex + } + } + + // read the queued + queuedNode := chst.QueuedIndex() + if !queuedNode.IsNull() { + queuedIndex, err := queuedNode.AsInt() + if err != nil { + return err + } + if queuedIndex > c.queuedIndex { + c.queuedIndex = queuedIndex + } + } + + // set progress + var progress uint64 + if chst.Sender() == chst.SelfPeer() { + progress = chst.Queued() + } else { + progress = chst.Received() + } + if progress > c.progress { + c.progress = progress + } + + // set data limit + c.dataLimit = chst.DataLimit() + return nil +} + +// ActionFromChannelState comparse internal graphsync channel state with the data transfer +// state and determines what if any action should be taken on graphsync +func (c *Channel) ActionFromChannelState(chst datatransfer.ChannelState) Action { + c.lk.Lock() + defer c.lk.Unlock() + return c.actionFromChannelState(chst) +} + +func (c *Channel) actionFromChannelState(chst datatransfer.ChannelState) Action { + // if the state is closed, and we haven't closed, we 
need to close + if !c.requesterCancelled && c.state != channelClosed && chst.Status().TransferComplete() { + return Close + } + + // if the state is running, and we're paused, we need to pause + if c.requestID != nil && c.state == channelPaused && !chst.SelfPaused() { + return Resume + } + + // if the state is paused, and the transfer is running, we need to resume + if c.requestID != nil && c.state == channelOpen && chst.SelfPaused() { + return Pause + } + + return NoAction +} + +func (c *Channel) ReconcileChannelState(chst datatransfer.ChannelState) (Action, error) { + c.lk.Lock() + defer c.lk.Unlock() + err := c.updateFromChannelState(chst) + if err != nil { + return NoAction, err + } + return c.actionFromChannelState(chst), nil +} + +func (c *Channel) MarkTransferComplete() { + c.lk.Lock() + defer c.lk.Unlock() + c.state = channelClosed +} + +// Called when the responder gets a cancel message from the requester +func (c *Channel) OnRequesterCancelled() { + c.lk.Lock() + defer c.lk.Unlock() + + c.requesterCancelled = true +} + +func (c *Channel) hasStore() bool { + c.storeLk.RLock() + defer c.storeLk.RUnlock() + + return c.storeRegistered +} + +// Use the given loader and storer to get / put blocks for the data-transfer. +// Note that each data-transfer channel uses a separate blockstore. +func (c *Channel) UseStore(lsys ipld.LinkSystem) error { + c.storeLk.Lock() + defer c.storeLk.Unlock() + + // Register the channel's store with graphsync + err := c.gs.RegisterPersistenceOption("data-transfer-"+c.channelID.String(), lsys) + if err != nil { + return err + } + + c.storeRegistered = true + + return nil +} + +func (c *Channel) UpdateReceivedIndexIfGreater(nextIdx int64) bool { + c.lk.Lock() + defer c.lk.Unlock() + if c.receivedIndex < nextIdx { + c.receivedIndex = nextIdx + return true + } + return false +} + +func (c *Channel) UpdateQueuedIndexIfGreater(nextIdx int64) bool { + c.lk.Lock() + defer c.lk.Unlock() + if c.queuedIndex < nextIdx { + c.queuedIndex = nextIdx + return true + } + return false +} + +func (c *Channel) UpdateSentIndexIfGreater(nextIdx int64) bool { + c.lk.Lock() + defer c.lk.Unlock() + if c.sentIndex < nextIdx { + c.sentIndex = nextIdx + return true + } + return false +} + +func (c *Channel) UpdateProgress(additionalData uint64) bool { + c.lk.Lock() + defer c.lk.Unlock() + c.progress += additionalData + reachedLimit := c.dataLimit != 0 && c.progress >= c.dataLimit + if reachedLimit { + c.state = channelPaused + } + return reachedLimit +} + +func (c *Channel) Cleanup() { + c.lk.Lock() + defer c.lk.Unlock() + + log.Debugf("%s: cleaning up channel", c.channelID) + + if c.hasStore() { + // Unregister the channel's store from graphsync + opt := "data-transfer-" + c.channelID.String() + err := c.gs.UnregisterPersistenceOption(opt) + if err != nil { + log.Errorf("failed to unregister persistence option %s: %s", opt, err) + } + } + +} + +func (c *Channel) Close(ctx context.Context) error { + // Cancel the graphsync request + c.lk.Lock() + errch := c.cancel(ctx) + c.lk.Unlock() + + // Wait for the cancel message to complete + select { + case err := <-errch: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +// cancel the graphsync request. +// Note: must be called under the lock. 
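+// It returns a buffered channel that receives the result of the asynchronous graphsync Cancel call, so a caller can release the lock and then wait for the cancel to finish.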
+func (c *Channel) cancel(ctx context.Context) chan error { + errch := make(chan error, 1) + + // Check that the request has not already been cancelled + if c.requesterCancelled || c.state == channelClosed { + errch <- nil + return errch + } + + // Clear the graphsync key to indicate that the request has been cancelled + requestID := c.requestID + c.requestID = nil + c.state = channelClosed + go func() { + log.Debugf("%s: cancelling request", c.channelID) + err := c.gs.Cancel(ctx, *requestID) + + // Ignore "request not found" errors + if err != nil && !errors.Is(graphsync.RequestNotFoundErr{}, err) { + errch <- fmt.Errorf("cancelling graphsync request for channel %s: %w", c.channelID, err) + } else { + errch <- nil + } + }() + + return errch +} + +func (c *Channel) IsCurrentRequest(requestID graphsync.RequestID) bool { + return c.requestID != nil && *c.requestID == requestID +} diff --git a/transport/graphsync/exceptions_test.go b/transport/graphsync/exceptions_test.go new file mode 100644 index 00000000..c4845721 --- /dev/null +++ b/transport/graphsync/exceptions_test.go @@ -0,0 +1,287 @@ +package graphsync_test + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/ipfs/go-graphsync" + "github.com/ipld/go-ipld-prime/datamodel" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/testutil" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/testharness" +) + +func TestTransferExceptions(t *testing.T) { + ctx := context.Background() + testCases := []struct { + name string + parameters []testharness.Option + test func(t *testing.T, th *testharness.GsTestHarness) + }{ + { + name: "error executing pull graphsync request", + parameters: []testharness.Option{testharness.PullRequest()}, + test: func(t *testing.T, th *testharness.GsTestHarness) { + th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + receivedRequest := th.Fgs.ReceivedRequests[0] + close(receivedRequest.ResponseChan) + receivedRequest.ResponseErrChan <- errors.New("something went wrong") + close(receivedRequest.ResponseErrChan) + select { + case <-th.CompletedRequests: + case <-ctx.Done(): + t.Fatalf("did not complete request") + } + th.Events.AssertTransportEventEventually(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: false, ErrorMessage: fmt.Sprintf("channel %s: graphsync request failed to complete: something went wrong", th.Channel.ChannelID())}) + }, + }, + { + name: "unrecognized outgoing pull request", + parameters: []testharness.Option{testharness.PullRequest()}, + test: func(t *testing.T, th *testharness.GsTestHarness) { + // open a channel + th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + // configure a store + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + // setup a seperate request with different request ID but with contained message + otherRequest := testharness.NewFakeRequest(graphsync.NewRequestID(), map[graphsync.ExtensionName]datamodel.Node{ + extension.ExtensionDataTransfer1_1: th.NewRequest(t).ToIPLD(), + }, graphsync.RequestTypeNew) + // run outgoing request hook on this request + 
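+ // the hook should not associate this request with the open channel: expect no opened-channel event and no store configuration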
th.OutgoingRequestHook(otherRequest) + // no channel opened + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportOpenedChannel{}) + // no store configuration + require.Empty(t, th.OutgoingRequestHookActions.PersistenceOption) + // run outgoing request processing listener + th.OutgoingRequestProcessingListener(otherRequest) + // no transfer initiated event + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{}) + dtResponse := th.Response() + // create a response with the wrong request ID + otherResponse := testharness.NewFakeResponse(otherRequest.ID(), map[graphsync.ExtensionName]datamodel.Node{ + extension.ExtensionIncomingRequest1_1: dtResponse.ToIPLD(), + }, graphsync.PartialResponse) + // run incoming response hook + th.IncomingResponseHook(otherResponse) + // no response received + require.Nil(t, th.Events.ReceivedResponse) + // run blook hook + block := testharness.NewFakeBlockData(12345, 1, true) + th.IncomingBlockHook(otherResponse, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + }, + }, + { + name: "error cancelling on restart request", + parameters: []testharness.Option{testharness.PullRequest()}, + test: func(t *testing.T, th *testharness.GsTestHarness) { + // open a channel + _ = th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + th.Fgs.ReturnedCancelError = errors.New("something went wrong") + err := th.Transport.RestartChannel(th.Ctx, th.Channel, th.RestartRequest(t)) + require.EqualError(t, err, fmt.Sprintf("%s: restarting graphsync request: cancelling graphsync request for channel %s: %s", th.Channel.ChannelID(), th.Channel.ChannelID(), "something went wrong")) + }, + }, + { + name: "error reconnecting during restart", + parameters: []testharness.Option{testharness.PullRequest()}, + test: func(t *testing.T, th *testharness.GsTestHarness) { + // open a channel + _ = th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + expectedErr := errors.New("something went wrong") + th.DtNet.ReturnedConnectWithRetryError = expectedErr + err := th.Transport.RestartChannel(th.Ctx, th.Channel, th.RestartRequest(t)) + require.ErrorIs(t, err, expectedErr) + }, + }, + { + name: "unrecognized incoming graphsync request dt response", + test: func(t *testing.T, th *testharness.GsTestHarness) { + dtResponse := th.Response() + requestID := graphsync.NewRequestID() + request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtResponse.ToIPLD()}, graphsync.RequestTypeNew) + th.IncomingRequestHook(request) + require.False(t, th.IncomingRequestHookActions.Validated) + require.Error(t, th.IncomingRequestHookActions.TerminationError) + require.Equal(t, th.Events.ReceivedResponse, dtResponse) + }, + }, + { + name: "incoming graphsync request w/ dt response gets OnResponseReceived error", + test: func(t *testing.T, th *testharness.GsTestHarness) { + _ = th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + dtResponse := th.Response() + requestID := graphsync.NewRequestID() + request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtResponse.ToIPLD()}, graphsync.RequestTypeNew) + th.Events.ReturnedResponseReceivedError = errors.New("something went wrong") + th.IncomingRequestHook(request) + require.False(t, 
th.IncomingRequestHookActions.Validated) + require.EqualError(t, th.IncomingRequestHookActions.TerminationError, "something went wrong") + require.Equal(t, th.Events.ReceivedResponse, dtResponse) + }, + }, + { + name: "pull request cancelled", + parameters: []testharness.Option{testharness.PullRequest()}, + test: func(t *testing.T, th *testharness.GsTestHarness) { + _ = th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + require.Len(t, th.Fgs.ReceivedRequests, 1) + receivedRequest := th.Fgs.ReceivedRequests[0] + close(receivedRequest.ResponseChan) + receivedRequest.ResponseErrChan <- graphsync.RequestClientCancelledErr{} + close(receivedRequest.ResponseErrChan) + th.Events.AssertTransportEventEventually(t, th.Channel.ChannelID(), datatransfer.TransportTransferCancelled{ + ErrorMessage: "graphsync request cancelled", + }) + }, + }, + { + name: "error opening sending push message", + test: func(t *testing.T, th *testharness.GsTestHarness) { + th.DtNet.ReturnedSendMessageError = errors.New("something went wrong") + err := th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + require.EqualError(t, err, "something went wrong") + }, + }, + { + name: "unrecognized incoming graphsync push request", + test: func(t *testing.T, th *testharness.GsTestHarness) { + // open a channel + th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + // configure a store + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + voucherResult := testutil.NewTestTypedVoucher() + otherRequest := testharness.NewFakeRequest(graphsync.NewRequestID(), map[graphsync.ExtensionName]datamodel.Node{ + extension.ExtensionDataTransfer1_1: message.NewResponse(datatransfer.TransferID(rand.Uint64()), true, false, &voucherResult).ToIPLD(), + }, graphsync.RequestTypeNew) + // run incoming request hook on new request + th.IncomingRequestHook(otherRequest) + // should error + require.Error(t, th.IncomingRequestHookActions.TerminationError) + // run incoming request processing listener + th.IncomingRequestProcessingListener(otherRequest) + // no transfer initiated event + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{}) + // run block queued hook + block := testharness.NewFakeBlockData(12345, 1, true) + th.OutgoingBlockHook(otherRequest, block) + // no block queued event + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // run block sent hook + th.BlockSentListener(otherRequest, block) + // no block sent event + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // run complete listener + th.ResponseCompletedListener(otherRequest, graphsync.RequestCompletedFull) + // no complete event + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: true}) + }, + }, + { + name: "channel update on unrecognized channel", + test: func(t *testing.T, th *testharness.GsTestHarness) { + err := th.Transport.ChannelUpdated(th.Ctx, th.Channel.ChannelID(), th.NewRequest(t)) + require.Error(t, err) + }, + }, + { + name: "incoming request errors in OnRequestReceived", + parameters: []testharness.Option{testharness.PullRequest(), testharness.Responder()}, + test: func(t *testing.T, th *testharness.GsTestHarness) { + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) 
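+ // the store registration should be visible to graphsync as a persistence option named data-transfer-<channel ID>, which the next assertion checks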
+ th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String())) + requestID := graphsync.NewRequestID() + dtRequest := th.NewRequest(t) + request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtRequest.ToIPLD()}, graphsync.RequestTypeNew) + voucherResult := testutil.NewTestTypedVoucher() + dtResponse := message.NewResponse(th.Channel.TransferID(), false, false, &voucherResult) + th.Events.ReturnedRequestReceivedResponse = dtResponse + th.Events.ReturnedRequestReceivedError = errors.New("something went wrong") + th.Events.ReturnedOnContextAugmentFunc = func(ctx context.Context) context.Context { + return context.WithValue(ctx, ctxKey{}, "applesauce") + } + th.IncomingRequestHook(request) + require.Equal(t, dtRequest, th.Events.ReceivedRequest) + require.Empty(t, th.DtNet.ProtectedPeers) + require.Empty(t, th.IncomingRequestHookActions.PersistenceOption) + require.False(t, th.IncomingRequestHookActions.Validated) + require.False(t, th.IncomingRequestHookActions.Paused) + require.EqualError(t, th.IncomingRequestHookActions.TerminationError, "something went wrong") + sentResponse := th.IncomingRequestHookActions.DTMessage(t) + require.Equal(t, dtResponse, sentResponse) + th.IncomingRequestHookActions.RefuteAugmentedContextKey(t, ctxKey{}) + }, + }, + { + name: "incoming gs request with contained push request errors", + parameters: []testharness.Option{testharness.Responder()}, + test: func(t *testing.T, th *testharness.GsTestHarness) { + requestID := graphsync.NewRequestID() + dtRequest := th.NewRequest(t) + request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtRequest.ToIPLD()}, graphsync.RequestTypeNew) + dtResponse := th.Response() + th.Events.ReturnedRequestReceivedResponse = dtResponse + th.IncomingRequestHook(request) + require.EqualError(t, th.IncomingRequestHookActions.TerminationError, datatransfer.ErrUnsupported.Error()) + }, + }, + { + name: "incoming requests completes with error code for graphsync", + parameters: []testharness.Option{testharness.PullRequest(), testharness.Responder()}, + test: func(t *testing.T, th *testharness.GsTestHarness) { + requestID := graphsync.NewRequestID() + dtRequest := th.NewRequest(t) + request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtRequest.ToIPLD()}, graphsync.RequestTypeNew) + dtResponse := th.Response() + th.Events.ReturnedRequestReceivedResponse = dtResponse + th.IncomingRequestHook(request) + + th.ResponseCompletedListener(request, graphsync.RequestFailedUnknown) + select { + case <-th.CompletedResponses: + case <-ctx.Done(): + t.Fatalf("did not complete request") + } + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: false, ErrorMessage: fmt.Sprintf("graphsync response to peer %s did not complete: response status code %s", th.Channel.Recipient(), graphsync.RequestFailedUnknown.String())}) + + }, + }, + { + name: "incoming push request message errors in OnRequestReceived", + parameters: []testharness.Option{testharness.Responder()}, + test: func(t *testing.T, th *testharness.GsTestHarness) { + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String())) + voucherResult := testutil.NewTestTypedVoucher() 
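+ // stub the response that the events handler will hand back for the incoming push request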
+ dtResponse := message.NewResponse(th.Channel.TransferID(), false, false, &voucherResult) + th.Events.ReturnedRequestReceivedResponse = dtResponse + th.Events.ReturnedRequestReceivedError = errors.New("something went wrong") + th.DtNet.Delegates[0].Receiver.ReceiveRequest(ctx, th.Channel.OtherPeer(), th.NewRequest(t)) + require.Equal(t, th.NewRequest(t), th.Events.ReceivedRequest) + require.Empty(t, th.DtNet.ProtectedPeers) + require.Empty(t, th.Fgs.ReceivedRequests) + require.Len(t, th.DtNet.SentMessages, 1) + require.Equal(t, testharness.FakeSentMessage{Message: dtResponse, TransportID: "graphsync", PeerID: th.Channel.OtherPeer()}, th.DtNet.SentMessages[0]) + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + th := testharness.SetupHarness(ctx, testCase.parameters...) + testCase.test(t, th) + }) + } +} diff --git a/transport/graphsync/executor/executor.go b/transport/graphsync/executor/executor.go new file mode 100644 index 00000000..b0a8b4c0 --- /dev/null +++ b/transport/graphsync/executor/executor.go @@ -0,0 +1,109 @@ +package executor + +import ( + "fmt" + + "github.com/ipfs/go-graphsync" + logging "github.com/ipfs/go-log/v2" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" +) + +var log = logging.Logger("dt_graphsync") + +// EventsHandler handles the data transfer events that can be dispatched by the executor +type EventsHandler interface { + OnTransportEvent(datatransfer.ChannelID, datatransfer.TransportEvent) +} + +// Executor handles consuming channels on an outgoing GraphSync request +type Executor struct { + channelID datatransfer.ChannelID + responseChan <-chan graphsync.ResponseProgress + errChan <-chan error + onComplete func() +} + +// NewExecutor sets up a new executor to consume a graphsync request +func NewExecutor( + channelID datatransfer.ChannelID, + responseChan <-chan graphsync.ResponseProgress, + errChan <-chan error, + onComplete func()) *Executor { + return &Executor{channelID, responseChan, errChan, onComplete} +} + +// Start initiates consumption of a graphsync request +func (e *Executor) Start(events EventsHandler, + completedRequestListener func(channelID datatransfer.ChannelID)) { + go e.executeRequest(events, completedRequestListener) +} + +// Read from the graphsync response and error channels until they are closed, +// and return the last error on the error channel +func (e *Executor) consumeResponses() error { + var lastError error + for range e.responseChan { + } + log.Infof("channel %s: finished consuming graphsync response channel", e.channelID) + + for err := range e.errChan { + lastError = err + } + log.Infof("channel %s: finished consuming graphsync error channel", e.channelID) + + return lastError +} + +// Read from the graphsync response and error channels until they are closed +// or there is an error, then call the channel completed callback +func (e *Executor) executeRequest( + events EventsHandler, + completedRequestListener func(channelID datatransfer.ChannelID)) { + // Make sure to call the onComplete callback before returning + defer func() { + log.Infow("gs request complete for channel", "chid", e.channelID) + e.onComplete() + }() + + // Consume the response and error channels for the graphsync request + lastError := e.consumeResponses() + + // Request cancelled by client + if _, ok := lastError.(graphsync.RequestClientCancelledErr); ok { + terr := fmt.Errorf("graphsync request cancelled") +
log.Warnf("channel %s: %s", e.channelID, terr) + events.OnTransportEvent(e.channelID, datatransfer.TransportTransferCancelled{ErrorMessage: terr.Error()}) + return + } + + // Request cancelled by responder + if _, ok := lastError.(graphsync.RequestCancelledErr); ok { + log.Infof("channel %s: graphsync request cancelled by responder", e.channelID) + // TODO Should we do anything for RequestCancelledErr ? + return + } + + if lastError != nil { + log.Warnf("channel %s: graphsync error: %s", e.channelID, lastError) + } + + log.Debugf("channel %s: finished executing graphsync request", e.channelID) + + var completeErr error + if lastError != nil { + completeErr = fmt.Errorf("channel %s: graphsync request failed to complete: %w", e.channelID, lastError) + } + + // Used by the tests to listen for when a request completes + if completedRequestListener != nil { + completedRequestListener(e.channelID) + } + + if completeErr == nil { + events.OnTransportEvent(e.channelID, datatransfer.TransportCompletedTransfer{Success: true}) + } else { + events.OnTransportEvent(e.channelID, datatransfer.TransportCompletedTransfer{Success: false, ErrorMessage: completeErr.Error()}) + } + +} diff --git a/transport/graphsync/executor/executor_test.go b/transport/graphsync/executor/executor_test.go new file mode 100644 index 00000000..089a151d --- /dev/null +++ b/transport/graphsync/executor/executor_test.go @@ -0,0 +1,141 @@ +package executor_test + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + + "github.com/ipfs/go-graphsync" + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/testutil" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/executor" +) + +func TestExecutor(t *testing.T) { + ctx := context.Background() + chid := testutil.GenerateChannelID() + testCases := map[string]struct { + responseProgresses []graphsync.ResponseProgress + responseErrors []error + hasCompletedRequestHandler bool + expectedEventRecord fakeEvents + }{ + "simple no errors, no listener": { + expectedEventRecord: fakeEvents{ + completedChannel: chid, + completedError: nil, + }, + }, + "simple with error, no listener": { + responseErrors: []error{errors.New("something went wrong")}, + expectedEventRecord: fakeEvents{ + completedChannel: chid, + completedError: fmt.Errorf("channel %s: graphsync request failed to complete: %s", chid, errors.New("something went wrong")), + }, + }, + "client cancelled request error, no listener": { + responseErrors: []error{graphsync.RequestClientCancelledErr{}}, + expectedEventRecord: fakeEvents{ + cancelledChannel: chid, + cancelledErr: errors.New("graphsync request cancelled"), + }, + }, + // no events called here + "cancelled request error, no listener": { + responseErrors: []error{graphsync.RequestCancelledErr{}}, + }, + "has completed request handler": { + expectedEventRecord: fakeEvents{ + completedChannel: chid, + completedError: nil, + }, + hasCompletedRequestHandler: true, + }, + } + for testCase, data := range testCases { + t.Run(testCase, func(t *testing.T) { + responseChan := make(chan graphsync.ResponseProgress) + errChan := make(chan error) + events := &fakeEvents{} + fcrl := &fakeCompletedRequestListener{} + + completed := make(chan struct{}) + var onCompleteOnce sync.Once + + onComplete := func() { + onCompleteOnce.Do(func() { + close(completed) + }) + } + e := executor.NewExecutor(chid, responseChan, errChan, onComplete) + if 
data.hasCompletedRequestHandler { + e.Start(events, fcrl.complete) + } else { + e.Start(events, nil) + } + + for _, progress := range data.responseProgresses { + select { + case <-ctx.Done(): + t.Fatal("unable to queue all responses") + case responseChan <- progress: + } + } + close(responseChan) + + for _, err := range data.responseErrors { + select { + case <-ctx.Done(): + t.Fatal("unable to queue all errors") + case errChan <- err: + } + } + close(errChan) + + select { + case <-ctx.Done(): + t.Fatal("did not complete request") + case <-completed: + } + + require.Equal(t, data.expectedEventRecord, *events) + if data.hasCompletedRequestHandler { + require.Equal(t, chid, fcrl.calledChannel) + } else { + require.NotEqual(t, chid, fcrl.calledChannel) + } + }) + } +} + +type fakeEvents struct { + completedChannel datatransfer.ChannelID + completedError error + cancelledChannel datatransfer.ChannelID + cancelledErr error +} + +func (fe *fakeEvents) OnTransportEvent(chid datatransfer.ChannelID, transportEvent datatransfer.TransportEvent) { + switch evt := transportEvent.(type) { + case datatransfer.TransportCompletedTransfer: + fe.completedChannel = chid + if !evt.Success { + fe.completedError = errors.New(evt.ErrorMessage) + } + case datatransfer.TransportTransferCancelled: + fe.cancelledChannel = chid + fe.cancelledErr = errors.New(evt.ErrorMessage) + } +} + +type fakeCompletedRequestListener struct { + calledChannel datatransfer.ChannelID +} + +func (fcrl *fakeCompletedRequestListener) complete(channelID datatransfer.ChannelID) { + fcrl.calledChannel = channelID +} diff --git a/transport/graphsync/extension/gsextension.go b/transport/graphsync/extension/gsextension.go index 356ac41c..28aeb667 100644 --- a/transport/graphsync/extension/gsextension.go +++ b/transport/graphsync/extension/gsextension.go @@ -5,10 +5,9 @@ import ( "github.com/ipfs/go-graphsync" "github.com/ipld/go-ipld-prime/datamodel" - "github.com/libp2p/go-libp2p-core/protocol" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/message" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message" ) const ( @@ -21,10 +20,10 @@ const ( ) // ProtocolMap maps graphsync extensions to their libp2p protocols -var ProtocolMap = map[graphsync.ExtensionName]protocol.ID{ - ExtensionIncomingRequest1_1: datatransfer.ProtocolDataTransfer1_2, - ExtensionOutgoingBlock1_1: datatransfer.ProtocolDataTransfer1_2, - ExtensionDataTransfer1_1: datatransfer.ProtocolDataTransfer1_2, +var ProtocolMap = map[graphsync.ExtensionName]datatransfer.Version{ + ExtensionIncomingRequest1_1: datatransfer.DataTransfer1_2, + ExtensionOutgoingBlock1_1: datatransfer.DataTransfer1_2, + ExtensionDataTransfer1_1: datatransfer.DataTransfer1_2, } // ToExtensionData converts a message to a graphsync extension @@ -35,14 +34,11 @@ func ToExtensionData(msg datatransfer.Message, supportedExtensions []graphsync.E if !ok { return nil, errors.New("unsupported protocol") } - versionedMsg, err := msg.MessageForProtocol(protoID) + versionedMsg, err := msg.MessageForVersion(protoID) if err != nil { continue } - nd, err := versionedMsg.ToIPLD() - if err != nil { - return nil, err - } + nd := versionedMsg.ToIPLD() exts = append(exts, graphsync.ExtensionData{ Name: supportedExtension, Data: nd, diff --git a/transport/graphsync/graphsync.go b/transport/graphsync/graphsync.go index 1eb5c286..41d1e085 100644 --- a/transport/graphsync/graphsync.go +++ 
b/transport/graphsync/graphsync.go @@ -2,30 +2,33 @@ package graphsync import ( "context" - "errors" - "fmt" "sync" "time" "github.com/ipfs/go-graphsync" - "github.com/ipfs/go-graphsync/donotsendfirstblocks" logging "github.com/ipfs/go-log/v2" ipld "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" peer "github.com/libp2p/go-libp2p-core/peer" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/transport/graphsync/extension" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + dtchannel "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/dtchannel" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" + "github.com/filecoin-project/go-data-transfer/v2/transport/helpers/network" ) var log = logging.Logger("dt_graphsync") +var transportID datatransfer.TransportID = "graphsync" +var supportedVersions = []datatransfer.Version{{Major: 1, Minor: 0, Patch: 0}} + // When restarting a data transfer, we cancel the existing graphsync request // before opening a new one. // This constant defines the maximum time to wait for the request to be // cancelled. -const maxGSCancelWait = time.Second var defaultSupportedExtensions = []graphsync.ExtensionName{ extension.ExtensionDataTransfer1_1, @@ -33,12 +36,10 @@ var defaultSupportedExtensions = []graphsync.ExtensionName{ var incomingReqExtensions = []graphsync.ExtensionName{ extension.ExtensionIncomingRequest1_1, - extension.ExtensionDataTransfer1_1, } var outgoingBlkExtensions = []graphsync.ExtensionName{ extension.ExtensionOutgoingBlock1_1, - extension.ExtensionDataTransfer1_1, } // Option is an option for setting up the graphsync transport @@ -70,6 +71,7 @@ func RegisterCompletedResponseListener(l func(channelID datatransfer.ChannelID)) type Transport struct { events datatransfer.EventsHandler gs graphsync.GraphExchange + dtNet network.DataTransferNetwork peerID peer.ID supportedExtensions []graphsync.ExtensionName @@ -79,7 +81,7 @@ type Transport struct { // Map from data transfer channel ID to information about that channel dtChannelsLk sync.RWMutex - dtChannels map[datatransfer.ChannelID]*dtChannel + dtChannels map[datatransfer.ChannelID]*dtchannel.Channel // Used in graphsync callbacks to map from graphsync request to the // associated data-transfer channel ID. 
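The new executor package added above is small enough to exercise in isolation. The following is an illustrative sketch only, not part of this patch: it wires the NewExecutor and Start signatures from the diff to a stub EventsHandler, borrows testutil.GenerateChannelID from the executor tests, and closes both channels empty so the executor reports a successful completion.

```go
package main

import (
	"fmt"

	"github.com/ipfs/go-graphsync"

	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
	"github.com/filecoin-project/go-data-transfer/v2/testutil"
	"github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/executor"
)

// stubEvents is a stand-in EventsHandler that just prints the single
// transport event the executor dispatches once both channels are drained.
type stubEvents struct{}

func (stubEvents) OnTransportEvent(chid datatransfer.ChannelID, evt datatransfer.TransportEvent) {
	switch e := evt.(type) {
	case datatransfer.TransportCompletedTransfer:
		fmt.Printf("channel %s completed: success=%v err=%q\n", chid, e.Success, e.ErrorMessage)
	case datatransfer.TransportTransferCancelled:
		fmt.Printf("channel %s cancelled: %s\n", chid, e.ErrorMessage)
	}
}

func main() {
	chid := testutil.GenerateChannelID()

	// In the transport these channels come back from the graphsync request;
	// here they are created (and closed empty) by hand just to drive the
	// executor through one completion.
	responses := make(chan graphsync.ResponseProgress)
	errs := make(chan error)

	done := make(chan struct{})
	e := executor.NewExecutor(chid, responses, errs, func() { close(done) })

	// A nil completedRequestListener is fine; it is only set from the tests.
	e.Start(stubEvents{}, nil)

	close(responses)
	close(errs)
	<-done // onComplete fires after the completion event has been dispatched
}
```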
@@ -87,12 +89,13 @@ } // NewTransport makes a new hooks manager with the given hook events interface -func NewTransport(peerID peer.ID, gs graphsync.GraphExchange, options ...Option) *Transport { +func NewTransport(gs graphsync.GraphExchange, dtNet network.DataTransferNetwork, options ...Option) *Transport { t := &Transport{ gs: gs, - peerID: peerID, + dtNet: dtNet, + peerID: dtNet.ID(), supportedExtensions: defaultSupportedExtensions, - dtChannels: make(map[datatransfer.ChannelID]*dtChannel), + dtChannels: make(map[datatransfer.ChannelID]*dtchannel.Channel), requestIDToChannelID: newRequestIDToChannelIDMap(), } for _, option := range options { @@ -101,18 +104,81 @@ return t } -// OpenChannel initiates an outgoing request for the other peer to send data -// to us on this channel -// Note: from a data transfer symantic standpoint, it doesn't matter if the -// request is push or pull -- OpenChannel is called by the party that is -// intending to receive data +func (t *Transport) ID() datatransfer.TransportID { + return transportID +} + +func (t *Transport) Versions() []datatransfer.Version { + return supportedVersions +} + +func (t *Transport) Capabilities() datatransfer.TransportCapabilities { + return datatransfer.TransportCapabilities{ + Pausable: true, + Restartable: true, + } +} + +// OpenChannel opens a channel on a given transport to move data back and forth. +// OpenChannel MUST ALWAYS be called by the initiator. func (t *Transport) OpenChannel( + ctx context.Context, + channel datatransfer.Channel, + req datatransfer.Request) error { + t.dtNet.Protect(channel.OtherPeer(), channel.ChannelID().String()) + t.trackDTChannel(channel.ChannelID()) + if channel.IsPull() { + return t.openRequest(ctx, + channel.Sender(), + channel.ChannelID(), + cidlink.Link{Cid: channel.BaseCID()}, + channel.Selector(), + req) + } + return t.dtNet.SendMessage(ctx, channel.OtherPeer(), transportID, req) +} + +// RestartChannel restarts a channel on the initiator side +// RestartChannel MUST ALWAYS be called by the initiator func (t *Transport) RestartChannel( + ctx context.Context, + channelState datatransfer.ChannelState, + req datatransfer.Request) error { + log.Debugf("%s: re-establishing connection to %s", channelState.ChannelID(), channelState.OtherPeer()) + start := time.Now() + err := t.dtNet.ConnectWithRetry(ctx, channelState.OtherPeer(), transportID) + if err != nil { + return xerrors.Errorf("%s: failed to reconnect to peer %s after %s: %w", + channelState.ChannelID(), channelState.OtherPeer(), time.Since(start), err) + } + log.Debugf("%s: re-established connection to %s in %s", channelState.ChannelID(), channelState.OtherPeer(), time.Since(start)) + + t.dtNet.Protect(channelState.OtherPeer(), channelState.ChannelID().String()) + + ch := t.trackDTChannel(channelState.ChannelID()) + err = ch.UpdateFromChannelState(channelState) + if err != nil { + return err + } + + if channelState.IsPull() { + + return t.openRequest(ctx, + channelState.Sender(), + channelState.ChannelID(), + cidlink.Link{Cid: channelState.BaseCID()}, + channelState.Selector(), + req) + } + return t.dtNet.SendMessage(ctx, channelState.OtherPeer(), transportID, req) +} + +func (t *Transport) openRequest( ctx context.Context, dataSender peer.ID, channelID datatransfer.ChannelID, root ipld.Link, - stor ipld.Node, - channel datatransfer.ChannelState, + stor datamodel.Node, msg datatransfer.Message, ) error { if t.events == nil { @@ -123,151
+189,84 @@ func (t *Transport) OpenChannel( if err != nil { return err } - // If this is a restart request, the client can indicate the blocks that - // it has already received, so that the provider knows not to resend - // those blocks - restartExts, err := t.getRestartExtension(ctx, dataSender, channel) - if err != nil { - return err - } - exts = append(exts, restartExts...) // Start tracking the data-transfer channel ch := t.trackDTChannel(channelID) + requestID := graphsync.NewRequestID() + t.requestIDToChannelID.set(requestID, false, channelID) + // Open a graphsync request to the remote peer - req, err := ch.open(ctx, channelID, dataSender, root, stor, channel, exts) + execetor, err := ch.Open(ctx, requestID, dataSender, root, stor, exts) + if err != nil { return err } - // Process incoming data - go t.executeGsRequest(req) - + execetor.Start(t.events, t.completedRequestListener) return nil } -// Get the extension data for sending a Restart message, depending on the -// protocol version of the peer -func (t *Transport) getRestartExtension(ctx context.Context, p peer.ID, channel datatransfer.ChannelState) ([]graphsync.ExtensionData, error) { - if channel == nil { - return nil, nil - } - return getDoNotSendFirstBlocksExtension(channel) -} - -// Skip the first N blocks because they were already received -func getDoNotSendFirstBlocksExtension(channel datatransfer.ChannelState) ([]graphsync.ExtensionData, error) { - skipBlockCount := channel.ReceivedCidsTotal() - data := donotsendfirstblocks.EncodeDoNotSendFirstBlocks(skipBlockCount) - return []graphsync.ExtensionData{{ - Name: graphsync.ExtensionsDoNotSendFirstBlocks, - Data: data, - }}, nil -} - -// Read from the graphsync response and error channels until they are closed, -// and return the last error on the error channel -func (t *Transport) consumeResponses(req *gsReq) error { - var lastError error - for range req.responseChan { - } - log.Debugf("channel %s: finished consuming graphsync response channel", req.channelID) - - for err := range req.errChan { - lastError = err - } - log.Debugf("channel %s: finished consuming graphsync error channel", req.channelID) - - return lastError -} - -// Read from the graphsync response and error channels until they are closed -// or there is an error, then call the channel completed callback -func (t *Transport) executeGsRequest(req *gsReq) { - // Make sure to call the onComplete callback before returning - defer func() { - log.Infow("gs request complete for channel", "chid", req.channelID) - req.onComplete() - }() - - // Consume the response and error channels for the graphsync request - lastError := t.consumeResponses(req) - - // Request cancelled by client - if _, ok := lastError.(graphsync.RequestClientCancelledErr); ok { - terr := xerrors.Errorf("graphsync request cancelled") - log.Warnf("channel %s: %s", req.channelID, terr) - if err := t.events.OnRequestCancelled(req.channelID, terr); err != nil { - log.Error(err) - } - return - } - - // Request cancelled by responder - if _, ok := lastError.(graphsync.RequestCancelledErr); ok { - log.Infof("channel %s: graphsync request cancelled by responder", req.channelID) - // TODO Should we do anything for RequestCancelledErr ? 
- return - } - - if lastError != nil { - log.Warnf("channel %s: graphsync error: %s", req.channelID, lastError) - } - - log.Debugf("channel %s: finished executing graphsync request", req.channelID) - - var completeErr error - if lastError != nil { - completeErr = xerrors.Errorf("channel %s: graphsync request failed to complete: %w", req.channelID, lastError) - } - - // Used by the tests to listen for when a request completes - if t.completedRequestListener != nil { - t.completedRequestListener(req.channelID) - } - - err := t.events.OnChannelCompleted(req.channelID, completeErr) +func (t *Transport) reconcileChannelStates(ctx context.Context, chid datatransfer.ChannelID) (*dtchannel.Channel, dtchannel.Action, error) { + chst, err := t.events.ChannelState(ctx, chid) if err != nil { - log.Errorf("channel %s: processing OnChannelCompleted: %s", req.channelID, err) + return nil, dtchannel.NoAction, err } -} - -// PauseChannel pauses the given data-transfer channel -func (t *Transport) PauseChannel(ctx context.Context, chid datatransfer.ChannelID) error { ch, err := t.getDTChannel(chid) if err != nil { - return err + return nil, dtchannel.NoAction, err } - return ch.pause(ctx) + action, err := ch.ReconcileChannelState(chst) + return ch, action, err } -// ResumeChannel resumes the given data-transfer channel and sends the message -// if there is one -func (t *Transport) ResumeChannel( - ctx context.Context, - msg datatransfer.Message, - chid datatransfer.ChannelID, -) error { - ch, err := t.getDTChannel(chid) +// ChannelUpdated notifies the transport that state of the channel has been updated, +// along with an optional message to send over the transport to tell +// the other peer about the update +func (t *Transport) ChannelUpdated(ctx context.Context, chid datatransfer.ChannelID, message datatransfer.Message) error { + ch, action, err := t.reconcileChannelStates(ctx, chid) if err != nil { + if message != nil { + if sendErr := t.dtNet.SendMessage(ctx, t.otherPeer(chid), transportID, message); sendErr != nil { + return sendErr + } + } return err } - return ch.resume(ctx, msg) + return t.processAction(ctx, chid, ch, action, message) } -// CloseChannel closes the given data-transfer channel -func (t *Transport) CloseChannel(ctx context.Context, chid datatransfer.ChannelID) error { - ch, err := t.getDTChannel(chid) - if err != nil { - return err +func (t *Transport) processAction(ctx context.Context, chid datatransfer.ChannelID, ch *dtchannel.Channel, action dtchannel.Action, message datatransfer.Message) error { + if action == dtchannel.Resume { + var extensions []graphsync.ExtensionData + if message != nil { + var err error + extensions, err = extension.ToExtensionData(message, t.supportedExtensions) + if err != nil { + return err + } + } + return ch.Resume(ctx, extensions) } - err = ch.close(ctx) - if err != nil { - return xerrors.Errorf("closing channel: %w", err) + if message != nil { + if err := t.dtNet.SendMessage(ctx, t.otherPeer(chid), transportID, message); err != nil { + return err + } + } + switch action { + case dtchannel.Close: + return ch.Close(ctx) + case dtchannel.Pause: + return ch.Pause(ctx) + default: + return nil } - return nil +} + +// SendMessage sends a data transfer message over the channel to the other peer +func (t *Transport) SendMessage(ctx context.Context, chid datatransfer.ChannelID, msg datatransfer.Message) error { + return t.dtNet.SendMessage(ctx, t.otherPeer(chid), transportID, msg) } // CleanupChannel is called on the otherside of a cancel - removes any associated 
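To show how the refactor changes wiring on the caller side, here is a hedged construction sketch, illustrative and not part of the patch: NewTransport, ID, Versions and Capabilities follow the hunks above, while dtnetwork.NewFromLibp2pHost is an assumption about the constructor exposed by the relocated transport/helpers/network package.

```go
package example

import (
	"context"

	gsimpl "github.com/ipfs/go-graphsync/impl"
	gsnet "github.com/ipfs/go-graphsync/network"
	ipld "github.com/ipld/go-ipld-prime"
	"github.com/libp2p/go-libp2p-core/host"

	gstransport "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync"
	dtnetwork "github.com/filecoin-project/go-data-transfer/v2/transport/helpers/network"
)

// newGraphsyncTransport shows the post-refactor construction order: build the
// graphsync exchange, build the data transfer network, then hand both to
// NewTransport (the peer ID now comes from dtNet.ID() rather than a parameter).
func newGraphsyncTransport(ctx context.Context, h host.Host, lsys ipld.LinkSystem) *gstransport.Transport {
	gs := gsimpl.New(ctx, gsnet.NewFromLibp2pHost(h), lsys)

	// Assumption: NewFromLibp2pHost is the constructor of the relocated
	// transport/helpers/network package, as it was in earlier releases.
	dtNet := dtnetwork.NewFromLibp2pHost(h)

	tp := gstransport.NewTransport(gs, dtNet)

	// The transport now describes itself to the data transfer manager:
	_ = tp.ID()           // "graphsync"
	_ = tp.Versions()     // []datatransfer.Version{{Major: 1, Minor: 0, Patch: 0}}
	_ = tp.Capabilities() // Pausable and Restartable are both true

	return tp
}
```

As before, the manager is still expected to register its events handler on the transport before use; the corresponding SetEventHandler changes appear in the next hunk.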
@@ -283,10 +282,15 @@ func (t *Transport) CleanupChannel(chid datatransfer.ChannelID) { t.dtChannelsLk.Unlock() + // Clean up mapping from gs key to channel ID + t.requestIDToChannelID.deleteRefs(chid) + // Clean up the channel if ok { - ch.cleanup() + ch.Cleanup() } + + t.dtNet.Unprotect(t.otherPeer(chid), chid.String()) } // SetEventHandler sets the handler for events on channels @@ -296,18 +300,21 @@ func (t *Transport) SetEventHandler(events datatransfer.EventsHandler) error { } t.events = events - t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingRequestQueuedHook(t.gsReqQueuedHook)) + t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingRequestProcessingListener(t.gsRequestProcessingListener)) + t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterOutgoingRequestProcessingListener(t.gsRequestProcessingListener)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingRequestHook(t.gsReqRecdHook)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterCompletedResponseListener(t.gsCompletedResponseListener)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingBlockHook(t.gsIncomingBlockHook)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterOutgoingBlockHook(t.gsOutgoingBlockHook)) - t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterBlockSentListener(t.gsBlockSentHook)) + t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterBlockSentListener(t.gsBlockSentListener)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterOutgoingRequestHook(t.gsOutgoingRequestHook)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingResponseHook(t.gsIncomingResponseHook)) - t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterRequestUpdatedHook(t.gsRequestUpdatedHook)) + //t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterRequestUpdatedHook(t.gsRequestUpdatedHook)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterRequestorCancelledListener(t.gsRequestorCancelledListener)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterNetworkErrorListener(t.gsNetworkSendErrorListener)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterReceiverNetworkErrorListener(t.gsNetworkReceiveErrorListener)) + + t.dtNet.SetDelegate(transportID, supportedVersions, &receiver{t}) return nil } @@ -324,7 +331,7 @@ func (t *Transport) Shutdown(ctx context.Context) error { for _, ch := range t.dtChannels { ch := ch eg.Go(func() error { - return ch.shutdown(ctx) + return ch.Close(ctx) }) } @@ -338,7 +345,7 @@ func (t *Transport) Shutdown(ctx context.Context) error { // UseStore tells the graphsync transport to use the given loader and storer for this channelID func (t *Transport) UseStore(channelID datatransfer.ChannelID, lsys ipld.LinkSystem) error { ch := t.trackDTChannel(channelID) - return ch.useStore(lsys) + return ch.UseStore(lsys) } // ChannelGraphsyncRequests describes any graphsync request IDs associated with a given channel @@ -386,7 +393,7 @@ func (t *Transport) ChannelsForPeer(p peer.ID) ChannelsForPeer { channelGraphsyncRequests := collection[chid] // finally, determine if the request key matches the current GraphSync key we're tracking for // this channel, indicating it's the current graphsync request - if t.dtChannels[chid] != nil && t.dtChannels[chid].requestID != nil && (*t.dtChannels[chid].requestID) == requestID { + if t.dtChannels[chid] != nil && t.dtChannels[chid].IsCurrentRequest(requestID) { channelGraphsyncRequests.Current = requestID } else { // otherwise this 
id was a previous graphsync request on a channel that was restarted @@ -402,911 +409,4 @@ func (t *Transport) ChannelsForPeer(p peer.ID) ChannelsForPeer { } } -// gsOutgoingRequestHook is called when a graphsync request is made -func (t *Transport) gsOutgoingRequestHook(p peer.ID, request graphsync.RequestData, hookActions graphsync.OutgoingRequestHookActions) { - message, _ := extension.GetTransferData(request, t.supportedExtensions) - - // extension not found; probably not our request. - if message == nil { - return - } - - // A graphsync request is made when either - // - The local node opens a data-transfer pull channel, so the local node - // sends a graphsync request to ask the remote peer for the data - // - The remote peer opened a data-transfer push channel, and in response - // the local node sends a graphsync request to ask for the data - var initiator peer.ID - var responder peer.ID - if message.IsRequest() { - // This is a pull request so the data-transfer initiator is the local node - initiator = t.peerID - responder = p - } else { - // This is a push response so the data-transfer initiator is the remote - // peer: They opened the push channel, we respond by sending a - // graphsync request for the data - initiator = p - responder = t.peerID - } - chid := datatransfer.ChannelID{Initiator: initiator, Responder: responder, ID: message.TransferID()} - - // A data transfer channel was opened - err := t.events.OnChannelOpened(chid) - if err != nil { - // There was an error opening the channel, bail out - log.Errorf("processing OnChannelOpened for %s: %s", chid, err) - t.CleanupChannel(chid) - return - } - - // Start tracking the channel if we're not already - ch := t.trackDTChannel(chid) - - // Signal that the channel has been opened - ch.gsReqOpened(request.ID(), hookActions) -} - -// gsIncomingBlockHook is called when a block is received -func (t *Transport) gsIncomingBlockHook(p peer.ID, response graphsync.ResponseData, block graphsync.BlockData, hookActions graphsync.IncomingBlockHookActions) { - chid, ok := t.requestIDToChannelID.load(response.RequestID()) - if !ok { - return - } - - err := t.events.OnDataReceived(chid, block.Link(), block.BlockSize(), block.Index(), block.BlockSizeOnWire() != 0) - if err != nil && err != datatransfer.ErrPause { - hookActions.TerminateWithError(err) - return - } - - if err == datatransfer.ErrPause { - hookActions.PauseRequest() - } -} - -func (t *Transport) gsBlockSentHook(p peer.ID, request graphsync.RequestData, block graphsync.BlockData) { - // When a data transfer is restarted, the requester sends a list of CIDs - // that it already has. Graphsync calls the sent hook for all blocks even - // if they are in the list (meaning, they aren't actually sent over the - // wire). So here we check if the block was actually sent - // over the wire before firing the data sent event. - if block.BlockSizeOnWire() == 0 { - return - } - - chid, ok := t.requestIDToChannelID.load(request.ID()) - if !ok { - return - } - - if err := t.events.OnDataSent(chid, block.Link(), block.BlockSize(), block.Index(), block.BlockSizeOnWire() != 0); err != nil { - log.Errorf("failed to process data sent: %+v", err) - } -} - -func (t *Transport) gsOutgoingBlockHook(p peer.ID, request graphsync.RequestData, block graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { - // When a data transfer is restarted, the requester sends a list of CIDs - // that it already has. 
Graphsync calls the outgoing block hook for all - // blocks even if they are in the list (meaning, they aren't actually going - // to be sent over the wire). So here we check if the block is actually - // going to be sent over the wire before firing the data queued event. - if block.BlockSizeOnWire() == 0 { - return - } - - chid, ok := t.requestIDToChannelID.load(request.ID()) - if !ok { - return - } - - // OnDataQueued is called when a block is queued to be sent to the remote - // peer. It can return ErrPause to pause the response (eg if payment is - // required) and it can return a message that will be sent with the block - // (eg to ask for payment). - msg, err := t.events.OnDataQueued(chid, block.Link(), block.BlockSize(), block.Index(), block.BlockSizeOnWire() != 0) - if err != nil && err != datatransfer.ErrPause { - hookActions.TerminateWithError(err) - return - } - - if err == datatransfer.ErrPause { - hookActions.PauseResponse() - } - - if msg != nil { - // gsOutgoingBlockHook uses a unique extension name so it can be attached with data from a different hook - // outgoingBlkExtensions also includes the default extension name so it remains compatible with all data-transfer protocol versions out there - extensions, err := extension.ToExtensionData(msg, outgoingBlkExtensions) - if err != nil { - hookActions.TerminateWithError(err) - return - } - for _, extension := range extensions { - hookActions.SendExtensionData(extension) - } - } -} - -// gsReqQueuedHook is called when graphsync enqueues an incoming request for data -func (t *Transport) gsReqQueuedHook(p peer.ID, request graphsync.RequestData, hookActions graphsync.RequestQueuedHookActions) { - msg, err := extension.GetTransferData(request, t.supportedExtensions) - if err != nil { - log.Errorf("failed GetTransferData, req=%+v, err=%s", request, err) - } - // extension not found; probably not our request. - if msg == nil { - return - } - - var chid datatransfer.ChannelID - if msg.IsRequest() { - // when a data transfer request comes in on graphsync, the remote peer - // initiated a pull - chid = datatransfer.ChannelID{ID: msg.TransferID(), Initiator: p, Responder: t.peerID} - dtRequest := msg.(datatransfer.Request) - if dtRequest.IsNew() { - log.Infof("%s, pull request queued, req_id=%d", chid, request.ID()) - t.events.OnTransferQueued(chid) - } else { - log.Infof("%s, pull restart request queued, req_id=%d", chid, request.ID()) - } - } else { - // when a data transfer response comes in on graphsync, this node - // initiated a push, and the remote peer responded with a request - // for data - chid = datatransfer.ChannelID{ID: msg.TransferID(), Initiator: t.peerID, Responder: p} - response := msg.(datatransfer.Response) - if response.IsNew() { - log.Infof("%s, GS pull request queued in response to our push, req_id=%d", chid, request.ID()) - t.events.OnTransferQueued(chid) - } else { - log.Infof("%s, GS pull request queued in response to our restart push, req_id=%d", chid, request.ID()) - } - } - augmentContext := t.events.OnContextAugment(chid) - if augmentContext != nil { - hookActions.AugmentContext(augmentContext) - } -} - -// gsReqRecdHook is called when graphsync receives an incoming request for data -func (t *Transport) gsReqRecdHook(p peer.ID, request graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - // if this is a push request the sender is us. 
- msg, err := extension.GetTransferData(request, t.supportedExtensions) - if err != nil { - hookActions.TerminateWithError(err) - return - } - - // extension not found; probably not our request. - if msg == nil { - return - } - - // An incoming graphsync request for data is received when either - // - The remote peer opened a data-transfer pull channel, so the local node - // receives a graphsync request for the data - // - The local node opened a data-transfer push channel, and in response - // the remote peer sent a graphsync request for the data, and now the - // local node receives that request for data - var chid datatransfer.ChannelID - var responseMessage datatransfer.Message - var ch *dtChannel - if msg.IsRequest() { - // when a data transfer request comes in on graphsync, the remote peer - // initiated a pull - chid = datatransfer.ChannelID{ID: msg.TransferID(), Initiator: p, Responder: t.peerID} - - log.Debugf("%s: received request for data (pull), req_id=%d", chid, request.ID()) - - // Lock the channel for the duration of this method - ch = t.trackDTChannel(chid) - ch.lk.Lock() - defer ch.lk.Unlock() - - request := msg.(datatransfer.Request) - responseMessage, err = t.events.OnRequestReceived(chid, request) - } else { - // when a data transfer response comes in on graphsync, this node - // initiated a push, and the remote peer responded with a request - // for data - chid = datatransfer.ChannelID{ID: msg.TransferID(), Initiator: t.peerID, Responder: p} - - log.Debugf("%s: received request for data (push), req_id=%d", chid, request.ID()) - - // Lock the channel for the duration of this method - ch = t.trackDTChannel(chid) - ch.lk.Lock() - defer ch.lk.Unlock() - - response := msg.(datatransfer.Response) - err = t.events.OnResponseReceived(chid, response) - } - - // If we need to send a response, add the response message as an extension - if responseMessage != nil { - // gsReqRecdHook uses a unique extension name so it can be attached with data from a different hook - // incomingReqExtensions also includes default extension name so it remains compatible with previous data-transfer - // protocol versions out there. - extensions, extensionErr := extension.ToExtensionData(responseMessage, incomingReqExtensions) - if extensionErr != nil { - hookActions.TerminateWithError(err) - return - } - for _, extension := range extensions { - hookActions.SendExtensionData(extension) - } - } - - if err != nil && err != datatransfer.ErrPause { - hookActions.TerminateWithError(err) - return - } - - // Check if the callback indicated that the channel should be paused - // immediately (eg because data is still being unsealed) - paused := false - if err == datatransfer.ErrPause { - log.Debugf("%s: pausing graphsync response", chid) - - paused = true - hookActions.PauseResponse() - } - - // If this is a restart request, and the data transfer still hasn't got - // out of the paused state (eg because we're still unsealing), start this - // graphsync response in the paused state. - if ch.isOpen && !ch.xferStarted && !paused { - log.Debugf("%s: pausing graphsync response after restart", chid) - - paused = true - hookActions.PauseResponse() - } - - // If the transfer is not paused, record that the transfer has started - if !paused { - ch.xferStarted = true - } - - ch.gsDataRequestRcvd(request.ID(), hookActions) - - hookActions.ValidateRequest() -} - -// gsCompletedResponseListener is a graphsync.OnCompletedResponseListener. 
We use it learn when the data transfer is complete -// for the side that is responding to a graphsync request -func (t *Transport) gsCompletedResponseListener(p peer.ID, request graphsync.RequestData, status graphsync.ResponseStatusCode) { - chid, ok := t.requestIDToChannelID.load(request.ID()) - if !ok { - return - } - - if status == graphsync.RequestCancelled { - return - } - - var completeErr error - if status != graphsync.RequestCompletedFull { - statusStr := gsResponseStatusCodeString(status) - completeErr = xerrors.Errorf("graphsync response to peer %s did not complete: response status code %s", p, statusStr) - } - - // Used by the tests to listen for when a response completes - if t.completedResponseListener != nil { - t.completedResponseListener(chid) - } - - err := t.events.OnChannelCompleted(chid, completeErr) - if err != nil { - log.Error(err) - } -} - -// Remove this map once this PR lands: https://github.com/ipfs/go-graphsync/pull/148 -var gsResponseStatusCodes = map[graphsync.ResponseStatusCode]string{ - graphsync.RequestAcknowledged: "RequestAcknowledged", - graphsync.AdditionalPeers: "AdditionalPeers", - graphsync.NotEnoughGas: "NotEnoughGas", - graphsync.OtherProtocol: "OtherProtocol", - graphsync.PartialResponse: "PartialResponse", - graphsync.RequestPaused: "RequestPaused", - graphsync.RequestCompletedFull: "RequestCompletedFull", - graphsync.RequestCompletedPartial: "RequestCompletedPartial", - graphsync.RequestRejected: "RequestRejected", - graphsync.RequestFailedBusy: "RequestFailedBusy", - graphsync.RequestFailedUnknown: "RequestFailedUnknown", - graphsync.RequestFailedLegal: "RequestFailedLegal", - graphsync.RequestFailedContentNotFound: "RequestFailedContentNotFound", - graphsync.RequestCancelled: "RequestCancelled", -} - -func gsResponseStatusCodeString(code graphsync.ResponseStatusCode) string { - str, ok := gsResponseStatusCodes[code] - if ok { - return str - } - return gsResponseStatusCodes[graphsync.RequestFailedUnknown] -} - -func (t *Transport) gsRequestUpdatedHook(p peer.ID, request graphsync.RequestData, update graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) { - chid, ok := t.requestIDToChannelID.load(request.ID()) - if !ok { - return - } - - responseMessage, err := t.processExtension(chid, update, p, t.supportedExtensions) - - if responseMessage != nil { - extensions, extensionErr := extension.ToExtensionData(responseMessage, t.supportedExtensions) - if extensionErr != nil { - hookActions.TerminateWithError(err) - return - } - for _, extension := range extensions { - hookActions.SendExtensionData(extension) - } - } - - if err != nil && err != datatransfer.ErrPause { - hookActions.TerminateWithError(err) - } - -} - -// gsIncomingResponseHook is a graphsync.OnIncomingResponseHook. 
We use it to pass on responses -func (t *Transport) gsIncomingResponseHook(p peer.ID, response graphsync.ResponseData, hookActions graphsync.IncomingResponseHookActions) { - chid, ok := t.requestIDToChannelID.load(response.RequestID()) - if !ok { - return - } - - responseMessage, err := t.processExtension(chid, response, p, incomingReqExtensions) - - if responseMessage != nil { - extensions, extensionErr := extension.ToExtensionData(responseMessage, t.supportedExtensions) - if extensionErr != nil { - hookActions.TerminateWithError(err) - return - } - for _, extension := range extensions { - hookActions.UpdateRequestWithExtensions(extension) - } - } - - if err != nil { - hookActions.TerminateWithError(err) - } - - // In a case where the transfer sends blocks immediately this extension may contain both a - // response message and a revalidation request so we trigger OnResponseReceived again for this - // specific extension name - _, err = t.processExtension(chid, response, p, []graphsync.ExtensionName{extension.ExtensionOutgoingBlock1_1}) - - if err != nil { - hookActions.TerminateWithError(err) - } -} - -func (t *Transport) processExtension(chid datatransfer.ChannelID, gsMsg extension.GsExtended, p peer.ID, exts []graphsync.ExtensionName) (datatransfer.Message, error) { - - // if this is a push request the sender is us. - msg, err := extension.GetTransferData(gsMsg, exts) - if err != nil { - return nil, err - } - - // extension not found; probably not our request. - if msg == nil { - return nil, nil - } - - if msg.IsRequest() { - - // only accept request message updates when original message was also request - if (chid != datatransfer.ChannelID{ID: msg.TransferID(), Initiator: p, Responder: t.peerID}) { - return nil, errors.New("received request on response channel") - } - dtRequest := msg.(datatransfer.Request) - return t.events.OnRequestReceived(chid, dtRequest) - } - - // only accept response message updates when original message was also response - if (chid != datatransfer.ChannelID{ID: msg.TransferID(), Initiator: t.peerID, Responder: p}) { - return nil, errors.New("received response on request channel") - } - - dtResponse := msg.(datatransfer.Response) - return nil, t.events.OnResponseReceived(chid, dtResponse) -} - -func (t *Transport) gsRequestorCancelledListener(p peer.ID, request graphsync.RequestData) { - chid, ok := t.requestIDToChannelID.load(request.ID()) - if !ok { - return - } - - ch, err := t.getDTChannel(chid) - if err != nil { - if !xerrors.Is(datatransfer.ErrChannelNotFound, err) { - log.Errorf("requestor cancelled: getting channel %s: %s", chid, err) - } - return - } - - log.Debugf("%s: requester cancelled data-transfer", chid) - ch.onRequesterCancelled() -} - -// Called when there is a graphsync error sending data -func (t *Transport) gsNetworkSendErrorListener(p peer.ID, request graphsync.RequestData, gserr error) { - // Fire an error if the graphsync request was made by this node or the remote peer - chid, ok := t.requestIDToChannelID.load(request.ID()) - if !ok { - return - } - - err := t.events.OnSendDataError(chid, gserr) - if err != nil { - log.Errorf("failed to fire transport send error %s: %s", gserr, err) - } -} - -// Called when there is a graphsync error receiving data -func (t *Transport) gsNetworkReceiveErrorListener(p peer.ID, gserr error) { - // Fire a receive data error on all ongoing graphsync transfers with that - // peer - t.requestIDToChannelID.forEach(func(k graphsync.RequestID, sending bool, chid datatransfer.ChannelID) { - if chid.Initiator != p 
&& chid.Responder != p { - return - } - - err := t.events.OnReceiveDataError(chid, gserr) - if err != nil { - log.Errorf("failed to fire transport receive error %s: %s", gserr, err) - } - }) -} - -func (t *Transport) newDTChannel(chid datatransfer.ChannelID) *dtChannel { - return &dtChannel{ - t: t, - channelID: chid, - opened: make(chan graphsync.RequestID, 1), - } -} - -func (t *Transport) trackDTChannel(chid datatransfer.ChannelID) *dtChannel { - t.dtChannelsLk.Lock() - defer t.dtChannelsLk.Unlock() - - ch, ok := t.dtChannels[chid] - if !ok { - ch = t.newDTChannel(chid) - t.dtChannels[chid] = ch - } - - return ch -} - -func (t *Transport) getDTChannel(chid datatransfer.ChannelID) (*dtChannel, error) { - if t.events == nil { - return nil, datatransfer.ErrHandlerNotSet - } - - t.dtChannelsLk.RLock() - defer t.dtChannelsLk.RUnlock() - - ch, ok := t.dtChannels[chid] - if !ok { - return nil, xerrors.Errorf("channel %s: %w", chid, datatransfer.ErrChannelNotFound) - } - return ch, nil -} - -// Info needed to keep track of a data transfer channel -type dtChannel struct { - channelID datatransfer.ChannelID - t *Transport - - lk sync.RWMutex - isOpen bool - requestID *graphsync.RequestID - completed chan struct{} - requesterCancelled bool - xferStarted bool - pendingExtensions []graphsync.ExtensionData - - opened chan graphsync.RequestID - - storeLk sync.RWMutex - storeRegistered bool -} - -// Info needed to monitor an ongoing graphsync request -type gsReq struct { - channelID datatransfer.ChannelID - responseChan <-chan graphsync.ResponseProgress - errChan <-chan error - onComplete func() -} - -// Open a graphsync request for data to the remote peer -func (c *dtChannel) open( - ctx context.Context, - chid datatransfer.ChannelID, - dataSender peer.ID, - root ipld.Link, - stor ipld.Node, - channel datatransfer.ChannelState, - exts []graphsync.ExtensionData, -) (*gsReq, error) { - c.lk.Lock() - defer c.lk.Unlock() - - // If there is an existing graphsync request for this channelID - if c.requestID != nil { - // Cancel the existing graphsync request - completed := c.completed - errch := c.cancel(ctx) - - // Wait for the complete callback to be called - err := waitForCompleteHook(ctx, completed) - if err != nil { - return nil, xerrors.Errorf("%s: waiting for cancelled graphsync request to complete: %w", chid, err) - } - - // Wait for the cancel request method to complete - select { - case err = <-errch: - case <-ctx.Done(): - err = xerrors.Errorf("timed out waiting for graphsync request to be cancelled") - } - if err != nil { - return nil, xerrors.Errorf("%s: restarting graphsync request: %w", chid, err) - } - } - - // Set up a completed channel that will be closed when the request - // completes (or is cancelled) - completed := make(chan struct{}) - var onCompleteOnce sync.Once - onComplete := func() { - // Ensure the channel is only closed once - onCompleteOnce.Do(func() { - log.Debugw("closing the completion ch for data-transfer channel", "chid", chid) - close(completed) - }) - } - c.completed = completed - - // Open a new graphsync request - msg := fmt.Sprintf("Opening graphsync request to %s for root %s", dataSender, root) - if channel != nil { - msg += fmt.Sprintf(" with %d Blocks already received", channel.ReceivedCidsTotal()) - } - log.Info(msg) - responseChan, errChan := c.t.gs.Request(ctx, dataSender, root, stor, exts...) 
- - // Wait for graphsync "request opened" callback - select { - case <-ctx.Done(): - return nil, ctx.Err() - case requestID := <-c.opened: - // Mark the channel as open and save the Graphsync request key - c.isOpen = true - c.requestID = &requestID - } - - return &gsReq{ - channelID: chid, - responseChan: responseChan, - errChan: errChan, - onComplete: onComplete, - }, nil -} - -func waitForCompleteHook(ctx context.Context, completed chan struct{}) error { - // Wait for the cancel to propagate through to graphsync, and for - // the graphsync request to complete - select { - case <-completed: - return nil - case <-time.After(maxGSCancelWait): - // Fail-safe: give up waiting after a certain amount of time - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -// gsReqOpened is called when graphsync makes a request to the remote peer to ask for data -func (c *dtChannel) gsReqOpened(requestID graphsync.RequestID, hookActions graphsync.OutgoingRequestHookActions) { - // Tell graphsync to store the received blocks in the registered store - if c.hasStore() { - hookActions.UsePersistenceOption("data-transfer-" + c.channelID.String()) - } - log.Infow("outgoing graphsync request", "peer", c.channelID.OtherParty(c.t.peerID), "graphsync request id", requestID, "data transfer channel id", c.channelID) - // Save a mapping from the graphsync key to the channel ID so that - // subsequent graphsync callbacks are associated with this channel - c.t.requestIDToChannelID.set(requestID, false, c.channelID) - - c.opened <- requestID -} - -// gsDataRequestRcvd is called when the transport receives an incoming request -// for data. -// Note: Must be called under the lock. -func (c *dtChannel) gsDataRequestRcvd(requestID graphsync.RequestID, hookActions graphsync.IncomingRequestHookActions) { - log.Debugf("%s: received request for data, req_id=%d", c.channelID, requestID) - - // If the requester had previously cancelled their request, send any - // message that was queued since the cancel - if c.requesterCancelled { - c.requesterCancelled = false - - extensions := c.pendingExtensions - c.pendingExtensions = nil - for _, ext := range extensions { - hookActions.SendExtensionData(ext) - } - } - - // Tell graphsync to load blocks from the registered store - if c.hasStore() { - hookActions.UsePersistenceOption("data-transfer-" + c.channelID.String()) - } - - // Save a mapping from the graphsync key to the channel ID so that - // subsequent graphsync callbacks are associated with this channel - c.requestID = &requestID - log.Infow("incoming graphsync request", "peer", c.channelID.OtherParty(c.t.peerID), "graphsync request id", requestID, "data transfer channel id", c.channelID) - c.t.requestIDToChannelID.set(requestID, true, c.channelID) - - c.isOpen = true -} - -func (c *dtChannel) pause(ctx context.Context) error { - c.lk.Lock() - defer c.lk.Unlock() - - // Check if the channel was already cancelled - if c.requestID == nil { - log.Debugf("%s: channel was cancelled so not pausing channel", c.channelID) - return nil - } - - // If the requester cancelled, bail out - if c.requesterCancelled { - log.Debugf("%s: requester has cancelled so not pausing response", c.channelID) - return nil - } - - // Pause the response - log.Debugf("%s: pausing response", c.channelID) - return c.t.gs.Pause(ctx, *c.requestID) -} - -func (c *dtChannel) resume(ctx context.Context, msg datatransfer.Message) error { - c.lk.Lock() - defer c.lk.Unlock() - - // Check if the channel was already cancelled - if c.requestID == nil { - 
log.Debugf("%s: channel was cancelled so not resuming channel", c.channelID) - return nil - } - - var extensions []graphsync.ExtensionData - if msg != nil { - var err error - extensions, err = extension.ToExtensionData(msg, c.t.supportedExtensions) - if err != nil { - return err - } - } - - // If the requester cancelled, bail out - if c.requesterCancelled { - // If there was an associated message, we still want to send it to the - // remote peer. We're not sending any message now, so instead queue up - // the message to be sent next time the peer makes a request to us. - c.pendingExtensions = append(c.pendingExtensions, extensions...) - - log.Debugf("%s: requester has cancelled so not unpausing response", c.channelID) - return nil - } - - // Record that the transfer has started - c.xferStarted = true - - log.Debugf("%s: unpausing response", c.channelID) - return c.t.gs.Unpause(ctx, *c.requestID, extensions...) -} - -func (c *dtChannel) close(ctx context.Context) error { - var errch chan error - c.lk.Lock() - { - // Check if the channel was already cancelled - if c.requestID != nil { - errch = c.cancel(ctx) - } - } - c.lk.Unlock() - - // Wait for the cancel message to complete - select { - case err := <-errch: - return err - case <-ctx.Done(): - return ctx.Err() - } -} - -// Called when the responder gets a cancel message from the requester -func (c *dtChannel) onRequesterCancelled() { - c.lk.Lock() - defer c.lk.Unlock() - - c.requesterCancelled = true -} - -func (c *dtChannel) hasStore() bool { - c.storeLk.RLock() - defer c.storeLk.RUnlock() - - return c.storeRegistered -} - -// Use the given loader and storer to get / put blocks for the data-transfer. -// Note that each data-transfer channel uses a separate blockstore. -func (c *dtChannel) useStore(lsys ipld.LinkSystem) error { - c.storeLk.Lock() - defer c.storeLk.Unlock() - - // Register the channel's store with graphsync - err := c.t.gs.RegisterPersistenceOption("data-transfer-"+c.channelID.String(), lsys) - if err != nil { - return err - } - - c.storeRegistered = true - - return nil -} - -func (c *dtChannel) cleanup() { - c.lk.Lock() - defer c.lk.Unlock() - - log.Debugf("%s: cleaning up channel", c.channelID) - - if c.hasStore() { - // Unregister the channel's store from graphsync - opt := "data-transfer-" + c.channelID.String() - err := c.t.gs.UnregisterPersistenceOption(opt) - if err != nil { - log.Errorf("failed to unregister persistence option %s: %s", opt, err) - } - } - - // Clean up mapping from gs key to channel ID - c.t.requestIDToChannelID.deleteRefs(c.channelID) -} - -func (c *dtChannel) shutdown(ctx context.Context) error { - // Cancel the graphsync request - c.lk.Lock() - errch := c.cancel(ctx) - c.lk.Unlock() - - // Wait for the cancel message to complete - select { - case err := <-errch: - return err - case <-ctx.Done(): - return ctx.Err() - } -} - -// Cancel the graphsync request. -// Note: must be called under the lock. 
-func (c *dtChannel) cancel(ctx context.Context) chan error { - errch := make(chan error, 1) - - // Check that the request has not already been cancelled - if c.requesterCancelled || c.requestID == nil { - errch <- nil - return errch - } - - // Clear the graphsync key to indicate that the request has been cancelled - requestID := c.requestID - c.requestID = nil - - go func() { - log.Debugf("%s: cancelling request", c.channelID) - err := c.t.gs.Cancel(ctx, *requestID) - - // Ignore "request not found" errors - if err != nil && !xerrors.Is(graphsync.RequestNotFoundErr{}, err) { - errch <- xerrors.Errorf("cancelling graphsync request for channel %s: %w", c.channelID, err) - } else { - errch <- nil - } - }() - - return errch -} - -type channelInfo struct { - sending bool - channelID datatransfer.ChannelID -} - -// Used in graphsync callbacks to map from graphsync request to the -// associated data-transfer channel ID. -type requestIDToChannelIDMap struct { - lk sync.RWMutex - m map[graphsync.RequestID]channelInfo -} - -func newRequestIDToChannelIDMap() *requestIDToChannelIDMap { - return &requestIDToChannelIDMap{ - m: make(map[graphsync.RequestID]channelInfo), - } -} - -// get the value for a key -func (m *requestIDToChannelIDMap) load(key graphsync.RequestID) (datatransfer.ChannelID, bool) { - m.lk.RLock() - defer m.lk.RUnlock() - - val, ok := m.m[key] - return val.channelID, ok -} - -// get the value if any of the keys exists in the map -func (m *requestIDToChannelIDMap) any(ks ...graphsync.RequestID) (datatransfer.ChannelID, bool) { - m.lk.RLock() - defer m.lk.RUnlock() - - for _, k := range ks { - val, ok := m.m[k] - if ok { - return val.channelID, ok - } - } - return datatransfer.ChannelID{}, false -} - -// set the value for a key -func (m *requestIDToChannelIDMap) set(key graphsync.RequestID, sending bool, chid datatransfer.ChannelID) { - m.lk.Lock() - defer m.lk.Unlock() - - m.m[key] = channelInfo{sending, chid} -} - -// call f for each key / value in the map -func (m *requestIDToChannelIDMap) forEach(f func(k graphsync.RequestID, isSending bool, chid datatransfer.ChannelID)) { - m.lk.RLock() - defer m.lk.RUnlock() - - for k, ch := range m.m { - f(k, ch.sending, ch.channelID) - } -} - -// delete any keys that reference this value -func (m *requestIDToChannelIDMap) deleteRefs(id datatransfer.ChannelID) { - m.lk.Lock() - defer m.lk.Unlock() - - for k, ch := range m.m { - if ch.channelID == id { - delete(m.m, k) - } - } -} +var _ datatransfer.Transport = (*Transport)(nil) diff --git a/transport/graphsync/graphsync_test.go b/transport/graphsync/graphsync_test.go deleted file mode 100644 index 82637593..00000000 --- a/transport/graphsync/graphsync_test.go +++ /dev/null @@ -1,1476 +0,0 @@ -package graphsync_test - -import ( - "context" - "errors" - "io" - "math/rand" - "testing" - "time" - - "github.com/ipfs/go-cid" - "github.com/ipfs/go-graphsync" - "github.com/ipfs/go-graphsync/donotsendfirstblocks" - "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/datamodel" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" - "github.com/ipld/go-ipld-prime/node/basicnode" - peer "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" - "github.com/stretchr/testify/require" - - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/message" - "github.com/filecoin-project/go-data-transfer/testutil" - . 
"github.com/filecoin-project/go-data-transfer/transport/graphsync" - "github.com/filecoin-project/go-data-transfer/transport/graphsync/extension" -) - -func TestManager(t *testing.T) { - testCases := map[string]struct { - requestConfig gsRequestConfig - responseConfig gsResponseConfig - updatedConfig gsRequestConfig - events fakeEvents - action func(gsData *harness) - check func(t *testing.T, events *fakeEvents, gsData *harness) - protocol protocol.ID - }{ - "gs outgoing request with recognized dt pull channel will record incoming blocks": { - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.True(t, events.OnDataReceivedCalled) - require.NoError(t, gsData.incomingBlockHookActions.TerminationError) - }, - }, - "gs outgoing request with recognized dt push channel will record incoming blocks": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.True(t, events.OnDataReceivedCalled) - require.NoError(t, gsData.incomingBlockHookActions.TerminationError) - }, - }, - "non-data-transfer gs request will not record incoming blocks and send updates": { - requestConfig: gsRequestConfig{ - dtExtensionMissing: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{}) - require.False(t, events.OnDataReceivedCalled) - require.NoError(t, gsData.incomingBlockHookActions.TerminationError) - }, - }, - "gs request unrecognized opened channel will not record incoming blocks": { - events: fakeEvents{ - OnChannelOpenedError: errors.New("Not recognized"), - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.False(t, events.OnDataReceivedCalled) - require.NoError(t, gsData.incomingBlockHookActions.TerminationError) - }, - }, - "gs incoming block with data receive error will halt request": { - events: fakeEvents{ - OnDataReceivedError: errors.New("something went wrong"), - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.True(t, events.OnDataReceivedCalled) - require.Error(t, gsData.incomingBlockHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt request can receive gs response": { - responseConfig: gsResponseConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t 
*testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Equal(t, 1, events.OnResponseReceivedCallCount) - require.NoError(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt request cannot receive gs response with dt request": { - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Error(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt response can receive gs response": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.NoError(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt response cannot receive gs response with dt response": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - responseConfig: gsResponseConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Error(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt request will error with malformed update": { - responseConfig: gsResponseConfig{ - dtExtensionMalformed: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Error(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt request will ignore non-data-transfer update": { - responseConfig: gsResponseConfig{ - dtExtensionMissing: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - 
require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.NoError(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt response can send message on update": { - events: fakeEvents{ - RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), - }, - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.NoError(t, gsData.incomingResponseHookActions.TerminationError) - assertHasOutgoingMessage(t, gsData.incomingResponseHookActions.SentExtensions, - events.RequestReceivedResponse) - }, - }, - "outgoing gs request with recognized dt response err will error": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - events: fakeEvents{ - OnRequestReceivedErrors: []error{errors.New("something went wrong")}, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Error(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request will validate gs request & send dt response": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - events: fakeEvents{ - RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Equal(t, events.RequestReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - dtRequestData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) - assertDecodesToMessage(t, dtRequestData, events.RequestReceivedRequest) - require.True(t, gsData.incomingRequestHookActions.Validated) - assertHasExtensionMessage(t, extension.ExtensionDataTransfer1_1, gsData.incomingRequestHookActions.SentExtensions, events.RequestReceivedResponse) - require.NoError(t, gsData.incomingRequestHookActions.TerminationError) - - channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) - require.Equal(t, channelsForPeer, ChannelsForPeer{ - SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ - events.RequestReceivedChannelID: { - Current: gsData.request.ID(), - }, - }, - ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, - }) - }, - }, - "incoming gs request with recognized dt response will validate gs request": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Equal(t, 1, events.OnResponseReceivedCallCount) - require.Equal(t, 
events.ResponseReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - dtResponseData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) - assertDecodesToMessage(t, dtResponseData, events.ResponseReceivedResponse) - require.True(t, gsData.incomingRequestHookActions.Validated) - require.NoError(t, gsData.incomingRequestHookActions.TerminationError) - }, - }, - "malformed data transfer extension on incoming request will terminate": { - requestConfig: gsRequestConfig{ - dtExtensionMalformed: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.False(t, gsData.incomingRequestHookActions.Validated) - require.Error(t, gsData.incomingRequestHookActions.TerminationError) - }, - }, - "unrecognized incoming dt request will terminate but send response": { - events: fakeEvents{ - RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), - OnRequestReceivedErrors: []error{errors.New("something went wrong")}, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Equal(t, events.RequestReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - dtRequestData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) - assertDecodesToMessage(t, dtRequestData, events.RequestReceivedRequest) - require.False(t, gsData.incomingRequestHookActions.Validated) - assertHasExtensionMessage(t, extension.ExtensionIncomingRequest1_1, gsData.incomingRequestHookActions.SentExtensions, events.RequestReceivedResponse) - require.Error(t, gsData.incomingRequestHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request will record outgoing blocks": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnDataQueuedCalled) - require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) - }, - }, - - "incoming gs request with recognized dt response will record outgoing blocks": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnResponseReceivedCallCount) - require.True(t, events.OnDataQueuedCalled) - require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) - }, - }, - "non-data-transfer request will not record outgoing blocks": { - requestConfig: gsRequestConfig{ - dtExtensionMissing: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.False(t, events.OnDataQueuedCalled) - }, - }, - "outgoing data queued error will terminate request": { - events: fakeEvents{ - OnDataQueuedError: errors.New("something went wrong"), - }, - action: func(gsData *harness) { 
- gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnDataQueuedCalled) - require.Error(t, gsData.outgoingBlockHookActions.TerminationError) - }, - }, - "outgoing data queued error == pause will pause request": { - events: fakeEvents{ - OnDataQueuedError: datatransfer.ErrPause, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnDataQueuedCalled) - require.True(t, gsData.outgoingBlockHookActions.Paused) - require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request will send updates": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - events: fakeEvents{ - OnDataQueuedMessage: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnDataQueuedCalled) - require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) - assertHasExtensionMessage(t, extension.ExtensionOutgoingBlock1_1, gsData.outgoingBlockHookActions.SentExtensions, - events.OnDataQueuedMessage) - }, - }, - "incoming gs request with recognized dt request can receive update": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 2, events.OnRequestReceivedCallCount) - require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request cannot receive update with dt response": { - updatedConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Error(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt response can receive update": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - updatedConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 2, events.OnResponseReceivedCallCount) - require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt response cannot receive update with dt request": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnResponseReceivedCallCount) - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Error(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request will error with malformed update": { - updatedConfig: 
gsRequestConfig{ - dtExtensionMalformed: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Error(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request will ignore non-data-transfer update": { - updatedConfig: gsRequestConfig{ - dtExtensionMissing: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request can send message on update": { - events: fakeEvents{ - RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 2, events.OnRequestReceivedCallCount) - require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) - assertHasOutgoingMessage(t, gsData.requestUpdatedHookActions.SentExtensions, - events.RequestReceivedResponse) - }, - }, - "recognized incoming request will record successful request completion": { - responseConfig: gsResponseConfig{ - status: graphsync.RequestCompletedFull, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.responseCompletedListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnChannelCompletedCalled) - require.True(t, events.ChannelCompletedSuccess) - }, - }, - - "recognized incoming request will record unsuccessful request completion": { - responseConfig: gsResponseConfig{ - status: graphsync.RequestCompletedPartial, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.responseCompletedListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnChannelCompletedCalled) - require.False(t, events.ChannelCompletedSuccess) - }, - }, - "recognized incoming request will not record request cancellation": { - responseConfig: gsResponseConfig{ - status: graphsync.RequestCancelled, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.responseCompletedListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.False(t, events.OnChannelCompletedCalled) - }, - }, - "non-data-transfer request will not record request completed": { - requestConfig: gsRequestConfig{ - dtExtensionMissing: true, - }, - responseConfig: gsResponseConfig{ - status: graphsync.RequestCompletedPartial, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.responseCompletedListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.False(t, events.OnChannelCompletedCalled) - }, - }, - "recognized incoming request can be closed": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t 
*testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertCancelReceived(gsData.ctx, t) - }, - }, - "unrecognized request cannot be closed": { - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.Error(t, err) - }, - }, - "recognized incoming request that requestor cancelled will not close via graphsync": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestorCancelledListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertNoCancelReceived(t) - }, - }, - "recognized incoming request can be paused": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertPauseReceived(gsData.ctx, t) - }, - }, - "unrecognized request cannot be paused": { - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.Error(t, err) - }, - }, - "recognized incoming request that requestor cancelled will not pause via graphsync": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestorCancelledListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertNoPauseReceived(t) - }, - }, - - "incoming request can be queued": { - action: func(gsData *harness) { - gsData.incomingRequestQueuedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.True(t, events.TransferQueuedCalled) - require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, - events.TransferQueuedChannelID) - }, - }, - - "incoming request with dtResponse can be queued": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - responseConfig: gsResponseConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestQueuedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.True(t, events.TransferQueuedCalled) - require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - events.TransferQueuedChannelID) - }, - }, - - "recognized incoming request can be resumed": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t *testing.T, 
events *fakeEvents, gsData *harness) { - err := gsData.transport.ResumeChannel(gsData.ctx, - gsData.incoming, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, - ) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertResumeReceived(gsData.ctx, t) - }, - }, - - "unrecognized request cannot be resumed": { - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.ResumeChannel(gsData.ctx, - gsData.incoming, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, - ) - require.Error(t, err) - }, - }, - "recognized incoming request that requestor cancelled will not resume via graphsync but will resume otherwise": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestorCancelledListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.ResumeChannel(gsData.ctx, - gsData.incoming, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, - ) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertNoResumeReceived(t) - gsData.incomingRequestHook() - assertHasOutgoingMessage(t, gsData.incomingRequestHookActions.SentExtensions, gsData.incoming) - }, - }, - "recognized incoming request will record network send error": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.networkErrorListener(errors.New("something went wrong")) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnSendDataErrorCalled) - }, - }, - "recognized outgoing request will record network send error": { - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.networkErrorListener(errors.New("something went wrong")) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.True(t, events.OnSendDataErrorCalled) - }, - }, - "recognized incoming request will record network receive error": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.receiverNetworkErrorListener(errors.New("something went wrong")) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnReceiveDataErrorCalled) - }, - }, - "recognized outgoing request will record network receive error": { - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.receiverNetworkErrorListener(errors.New("something went wrong")) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.True(t, events.OnReceiveDataErrorCalled) - }, - }, - "open channel adds block count to the DoNotSendFirstBlocks extension for v1.2 protocol": { - action: func(gsData *harness) { - cids := testutil.GenerateCids(2) - channel := &mockChannelState{receivedCids: cids} - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - channel, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - - ext := 
requestReceived.Extensions - require.Len(t, ext, 2) - doNotSend := ext[1] - - name := doNotSend.Name - require.Equal(t, graphsync.ExtensionsDoNotSendFirstBlocks, name) - data := doNotSend.Data - blockCount, err := donotsendfirstblocks.DecodeDoNotSendFirstBlocks(data) - require.NoError(t, err) - require.EqualValues(t, blockCount, 2) - }, - }, - "ChannelsForPeer when request is open": { - action: func(gsData *harness) { - cids := testutil.GenerateCids(2) - channel := &mockChannelState{receivedCids: cids} - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - channel, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - gsData.fgs.AssertRequestReceived(gsData.ctx, t) - - channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) - require.Equal(t, channelsForPeer, ChannelsForPeer{ - ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ - events.ChannelOpenedChannelID: { - Current: gsData.request.ID(), - }, - }, - SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, - }) - }, - }, - "open channel cancels an existing request with the same channel ID": { - action: func(gsData *harness) { - cids := testutil.GenerateCids(2) - channel := &mockChannelState{receivedCids: cids} - stor, _ := gsData.outgoing.Selector() - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - channel, - gsData.outgoing) - - go gsData.altOutgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - channel, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - gsData.fgs.AssertRequestReceived(gsData.ctx, t) - gsData.fgs.AssertRequestReceived(gsData.ctx, t) - - ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - gsData.fgs.AssertCancelReceived(ctxt, t) - - channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) - require.Equal(t, channelsForPeer, ChannelsForPeer{ - ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ - events.ChannelOpenedChannelID: { - Current: gsData.altRequest.ID(), - Previous: []graphsync.RequestID{gsData.request.ID()}, - }, - }, - SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, - }) - }, - }, - "OnChannelCompleted called when outgoing request completes successfully": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - close(requestReceived.ResponseChan) - close(requestReceived.ResponseErrChan) - - 
require.Eventually(t, func() bool { - return events.OnChannelCompletedCalled == true - }, 2*time.Second, 100*time.Millisecond) - require.True(t, events.ChannelCompletedSuccess) - }, - }, - "OnChannelCompleted called when outgoing request completes with error": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - close(requestReceived.ResponseChan) - requestReceived.ResponseErrChan <- graphsync.RequestFailedUnknownErr{} - close(requestReceived.ResponseErrChan) - - require.Eventually(t, func() bool { - return events.OnChannelCompletedCalled == true - }, 2*time.Second, 100*time.Millisecond) - require.False(t, events.ChannelCompletedSuccess) - }, - }, - "OnChannelComplete when outgoing request cancelled by caller": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - extensions := make(map[graphsync.ExtensionName]datamodel.Node) - for _, ext := range requestReceived.Extensions { - extensions[ext.Name] = ext.Data - } - request := testutil.NewFakeRequest(graphsync.NewRequestID(), extensions) - gsData.fgs.OutgoingRequestHook(gsData.other, request, gsData.outgoingRequestHookActions) - _ = gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - gsData.fgs.AssertCancelReceived(ctxt, t) - }, - }, - "request times out if we get request context cancelled error": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - close(requestReceived.ResponseChan) - requestReceived.ResponseErrChan <- graphsync.RequestClientCancelledErr{} - close(requestReceived.ResponseErrChan) - - require.Eventually(t, func() bool { - return events.OnRequestCancelledCalled == true - }, 2*time.Second, 100*time.Millisecond) - require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, events.OnRequestCancelledChannelId) - }, - }, - "request cancelled out if transport shuts down": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go 
gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - gsData.fgs.AssertRequestReceived(gsData.ctx, t) - - gsData.transport.Shutdown(gsData.ctx) - - ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - gsData.fgs.AssertCancelReceived(ctxt, t) - - require.Nil(t, gsData.fgs.IncomingRequestHook) - require.Nil(t, gsData.fgs.CompletedResponseListener) - require.Nil(t, gsData.fgs.IncomingBlockHook) - require.Nil(t, gsData.fgs.OutgoingBlockHook) - require.Nil(t, gsData.fgs.BlockSentListener) - require.Nil(t, gsData.fgs.OutgoingRequestHook) - require.Nil(t, gsData.fgs.IncomingResponseHook) - require.Nil(t, gsData.fgs.RequestUpdatedHook) - require.Nil(t, gsData.fgs.RequestorCancelledListener) - require.Nil(t, gsData.fgs.NetworkErrorListener) - }, - }, - "request pause works even if called when request is still pending": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - assertHasOutgoingMessage(t, requestReceived.Extensions, gsData.outgoing) - completed := make(chan struct{}) - go func() { - err := gsData.transport.PauseChannel(context.Background(), datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.NoError(t, err) - close(completed) - }() - time.Sleep(100 * time.Millisecond) - extensions := make(map[graphsync.ExtensionName]datamodel.Node) - for _, ext := range requestReceived.Extensions { - extensions[ext.Name] = ext.Data - } - request := testutil.NewFakeRequest(graphsync.NewRequestID(), extensions) - gsData.fgs.OutgoingRequestHook(gsData.other, request, gsData.outgoingRequestHookActions) - select { - case <-gsData.ctx.Done(): - t.Fatal("never paused channel") - case <-completed: - } - }, - }, - "UseStore can change store used for outgoing requests": { - action: func(gsData *harness) { - lsys := cidlink.DefaultLinkSystem() - lsys.StorageReadOpener = func(ipld.LinkContext, ipld.Link) (io.Reader, error) { - return nil, nil - } - lsys.StorageWriteOpener = func(ipld.LinkContext) (io.Writer, ipld.BlockWriteCommitter, error) { - return nil, nil, nil - } - _ = gsData.transport.UseStore(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, lsys) - gsData.outgoingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - expectedChannel := "data-transfer-" + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}.String() - gsData.fgs.AssertHasPersistenceOption(t, expectedChannel) - require.Equal(t, expectedChannel, gsData.outgoingRequestHookActions.PersistenceOption) - gsData.transport.CleanupChannel(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - 
gsData.fgs.AssertDoesNotHavePersistenceOption(t, expectedChannel) - }, - }, - "UseStore can change store used for incoming requests": { - action: func(gsData *harness) { - lsys := cidlink.DefaultLinkSystem() - lsys.StorageReadOpener = func(ipld.LinkContext, ipld.Link) (io.Reader, error) { - return nil, nil - } - lsys.StorageWriteOpener = func(ipld.LinkContext) (io.Writer, ipld.BlockWriteCommitter, error) { - return nil, nil, nil - } - _ = gsData.transport.UseStore(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, lsys) - gsData.incomingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - expectedChannel := "data-transfer-" + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}.String() - gsData.fgs.AssertHasPersistenceOption(t, expectedChannel) - require.Equal(t, expectedChannel, gsData.incomingRequestHookActions.PersistenceOption) - gsData.transport.CleanupChannel(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - gsData.fgs.AssertDoesNotHavePersistenceOption(t, expectedChannel) - }, - }, - } - - ctx := context.Background() - for testCase, data := range testCases { - t.Run(testCase, func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - peers := testutil.GeneratePeers(2) - transferID := datatransfer.TransferID(rand.Uint32()) - requestID := graphsync.NewRequestID() - request := data.requestConfig.makeRequest(t, transferID, requestID) - altRequest := data.requestConfig.makeRequest(t, transferID, graphsync.NewRequestID()) - response := data.responseConfig.makeResponse(t, transferID, requestID) - updatedRequest := data.updatedConfig.makeRequest(t, transferID, requestID) - block := testutil.NewFakeBlockData() - fgs := testutil.NewFakeGraphSync() - outgoing := testutil.NewDTRequest(t, transferID) - incoming := testutil.NewDTResponse(t, transferID) - transport := NewTransport(peers[0], fgs) - gsData := &harness{ - ctx: ctx, - outgoing: outgoing, - incoming: incoming, - transport: transport, - fgs: fgs, - self: peers[0], - transferID: transferID, - other: peers[1], - altRequest: altRequest, - request: request, - response: response, - updatedRequest: updatedRequest, - block: block, - outgoingRequestHookActions: &testutil.FakeOutgoingRequestHookActions{}, - outgoingBlockHookActions: &testutil.FakeOutgoingBlockHookActions{}, - incomingBlockHookActions: &testutil.FakeIncomingBlockHookActions{}, - incomingRequestHookActions: &testutil.FakeIncomingRequestHookActions{}, - requestUpdatedHookActions: &testutil.FakeRequestUpdatedActions{}, - incomingResponseHookActions: &testutil.FakeIncomingResponseHookActions{}, - requestQueuedHookActions: &testutil.FakeRequestQueuedHookActions{}, - } - require.NoError(t, transport.SetEventHandler(&data.events)) - if data.action != nil { - data.action(gsData) - } - data.check(t, &data.events, gsData) - }) - } -} - -type fakeEvents struct { - ChannelOpenedChannelID datatransfer.ChannelID - RequestReceivedChannelID datatransfer.ChannelID - ResponseReceivedChannelID datatransfer.ChannelID - OnChannelOpenedError error - OnDataReceivedCalled bool - OnDataReceivedError error - OnDataSentCalled bool - OnRequestReceivedCallCount int - OnRequestReceivedErrors []error - OnResponseReceivedCallCount int - OnResponseReceivedErrors []error - OnChannelCompletedCalled bool - OnChannelCompletedErr error - OnDataQueuedCalled bool - OnDataQueuedMessage datatransfer.Message - 
OnDataQueuedError error - - OnRequestCancelledCalled bool - OnRequestCancelledChannelId datatransfer.ChannelID - OnSendDataErrorCalled bool - OnSendDataErrorChannelID datatransfer.ChannelID - OnReceiveDataErrorCalled bool - OnReceiveDataErrorChannelID datatransfer.ChannelID - OnContextAugmentFunc func(context.Context) context.Context - TransferQueuedCalled bool - TransferQueuedChannelID datatransfer.ChannelID - - ChannelCompletedSuccess bool - RequestReceivedRequest datatransfer.Request - RequestReceivedResponse datatransfer.Response - ResponseReceivedResponse datatransfer.Response -} - -func (fe *fakeEvents) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) (datatransfer.Message, error) { - fe.OnDataQueuedCalled = true - - return fe.OnDataQueuedMessage, fe.OnDataQueuedError -} - -func (fe *fakeEvents) OnRequestCancelled(chid datatransfer.ChannelID, err error) error { - fe.OnRequestCancelledCalled = true - fe.OnRequestCancelledChannelId = chid - - return nil -} - -func (fe *fakeEvents) OnTransferQueued(chid datatransfer.ChannelID) { - fe.TransferQueuedCalled = true - fe.TransferQueuedChannelID = chid -} - -func (fe *fakeEvents) OnRequestDisconnected(chid datatransfer.ChannelID, err error) error { - return nil -} - -func (fe *fakeEvents) OnSendDataError(chid datatransfer.ChannelID, err error) error { - fe.OnSendDataErrorCalled = true - fe.OnSendDataErrorChannelID = chid - return nil -} - -func (fe *fakeEvents) OnReceiveDataError(chid datatransfer.ChannelID, err error) error { - fe.OnReceiveDataErrorCalled = true - fe.OnReceiveDataErrorChannelID = chid - return nil -} - -func (fe *fakeEvents) OnChannelOpened(chid datatransfer.ChannelID) error { - fe.ChannelOpenedChannelID = chid - return fe.OnChannelOpenedError -} - -func (fe *fakeEvents) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { - fe.OnDataReceivedCalled = true - return fe.OnDataReceivedError -} - -func (fe *fakeEvents) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { - fe.OnDataSentCalled = true - return nil -} - -func (fe *fakeEvents) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) { - fe.OnRequestReceivedCallCount++ - fe.RequestReceivedChannelID = chid - fe.RequestReceivedRequest = request - var err error - if len(fe.OnRequestReceivedErrors) > 0 { - err, fe.OnRequestReceivedErrors = fe.OnRequestReceivedErrors[0], fe.OnRequestReceivedErrors[1:] - } - return fe.RequestReceivedResponse, err -} - -func (fe *fakeEvents) OnResponseReceived(chid datatransfer.ChannelID, response datatransfer.Response) error { - fe.OnResponseReceivedCallCount++ - fe.ResponseReceivedResponse = response - fe.ResponseReceivedChannelID = chid - var err error - if len(fe.OnResponseReceivedErrors) > 0 { - err, fe.OnResponseReceivedErrors = fe.OnResponseReceivedErrors[0], fe.OnResponseReceivedErrors[1:] - } - return err -} - -func (fe *fakeEvents) OnChannelCompleted(chid datatransfer.ChannelID, completeErr error) error { - fe.OnChannelCompletedCalled = true - fe.ChannelCompletedSuccess = completeErr == nil - return fe.OnChannelCompletedErr -} - -func (fe *fakeEvents) OnContextAugment(chid datatransfer.ChannelID) func(context.Context) context.Context { - return fe.OnContextAugmentFunc -} - -type harness struct { - outgoing datatransfer.Request - incoming datatransfer.Response - ctx context.Context - transport *Transport - fgs *testutil.FakeGraphSync 
- transferID datatransfer.TransferID - self peer.ID - other peer.ID - block graphsync.BlockData - request graphsync.RequestData - altRequest graphsync.RequestData - response graphsync.ResponseData - updatedRequest graphsync.RequestData - outgoingRequestHookActions *testutil.FakeOutgoingRequestHookActions - incomingBlockHookActions *testutil.FakeIncomingBlockHookActions - outgoingBlockHookActions *testutil.FakeOutgoingBlockHookActions - incomingRequestHookActions *testutil.FakeIncomingRequestHookActions - requestUpdatedHookActions *testutil.FakeRequestUpdatedActions - incomingResponseHookActions *testutil.FakeIncomingResponseHookActions - requestQueuedHookActions *testutil.FakeRequestQueuedHookActions -} - -func (ha *harness) outgoingRequestHook() { - ha.fgs.OutgoingRequestHook(ha.other, ha.request, ha.outgoingRequestHookActions) -} - -func (ha *harness) altOutgoingRequestHook() { - ha.fgs.OutgoingRequestHook(ha.other, ha.altRequest, ha.outgoingRequestHookActions) -} - -func (ha *harness) incomingBlockHook() { - ha.fgs.IncomingBlockHook(ha.other, ha.response, ha.block, ha.incomingBlockHookActions) -} -func (ha *harness) outgoingBlockHook() { - ha.fgs.OutgoingBlockHook(ha.other, ha.request, ha.block, ha.outgoingBlockHookActions) -} - -func (ha *harness) incomingRequestHook() { - ha.fgs.IncomingRequestHook(ha.other, ha.request, ha.incomingRequestHookActions) -} - -func (ha *harness) incomingRequestQueuedHook() { - ha.fgs.IncomingRequestQueuedHook(ha.other, ha.request, ha.requestQueuedHookActions) -} - -func (ha *harness) requestUpdatedHook() { - ha.fgs.RequestUpdatedHook(ha.other, ha.request, ha.updatedRequest, ha.requestUpdatedHookActions) -} -func (ha *harness) incomingResponseHOok() { - ha.fgs.IncomingResponseHook(ha.other, ha.response, ha.incomingResponseHookActions) -} -func (ha *harness) responseCompletedListener() { - ha.fgs.CompletedResponseListener(ha.other, ha.request, ha.response.Status()) -} -func (ha *harness) requestorCancelledListener() { - ha.fgs.RequestorCancelledListener(ha.other, ha.request) -} -func (ha *harness) networkErrorListener(err error) { - ha.fgs.NetworkErrorListener(ha.other, ha.request, err) -} -func (ha *harness) receiverNetworkErrorListener(err error) { - ha.fgs.ReceiverNetworkErrorListener(ha.other, err) -} - -type dtConfig struct { - dtExtensionMissing bool - dtIsResponse bool - dtExtensionMalformed bool -} - -func (dtc *dtConfig) extensions(t *testing.T, transferID datatransfer.TransferID, extName graphsync.ExtensionName) map[graphsync.ExtensionName]datamodel.Node { - extensions := make(map[graphsync.ExtensionName]datamodel.Node) - if !dtc.dtExtensionMissing { - if dtc.dtExtensionMalformed { - extensions[extName] = basicnode.NewInt(10) - } else { - var msg datatransfer.Message - if dtc.dtIsResponse { - msg = testutil.NewDTResponse(t, transferID) - } else { - msg = testutil.NewDTRequest(t, transferID) - } - nd, err := msg.ToIPLD() - require.NoError(t, err) - extensions[extName] = nd - } - } - return extensions -} - -type gsRequestConfig struct { - dtExtensionMissing bool - dtIsResponse bool - dtExtensionMalformed bool -} - -func (grc *gsRequestConfig) makeRequest(t *testing.T, transferID datatransfer.TransferID, requestID graphsync.RequestID) graphsync.RequestData { - dtConfig := dtConfig{ - dtExtensionMissing: grc.dtExtensionMissing, - dtIsResponse: grc.dtIsResponse, - dtExtensionMalformed: grc.dtExtensionMalformed, - } - extensions := dtConfig.extensions(t, transferID, extension.ExtensionDataTransfer1_1) - return testutil.NewFakeRequest(requestID, 
extensions) -} - -type gsResponseConfig struct { - dtExtensionMissing bool - dtIsResponse bool - dtExtensionMalformed bool - status graphsync.ResponseStatusCode -} - -func (grc *gsResponseConfig) makeResponse(t *testing.T, transferID datatransfer.TransferID, requestID graphsync.RequestID) graphsync.ResponseData { - dtConfig := dtConfig{ - dtExtensionMissing: grc.dtExtensionMissing, - dtIsResponse: grc.dtIsResponse, - dtExtensionMalformed: grc.dtExtensionMalformed, - } - extensions := dtConfig.extensions(t, transferID, extension.ExtensionDataTransfer1_1) - return testutil.NewFakeResponse(requestID, extensions, grc.status) -} - -func assertDecodesToMessage(t *testing.T, data datamodel.Node, expected datatransfer.Message) { - actual, err := message.FromIPLD(data) - require.NoError(t, err) - require.Equal(t, expected, actual) -} - -func assertHasOutgoingMessage(t *testing.T, extensions []graphsync.ExtensionData, expected datatransfer.Message) { - nd, err := expected.ToIPLD() - require.NoError(t, err) - found := false - for _, e := range extensions { - if e.Name == extension.ExtensionDataTransfer1_1 { - require.True(t, ipld.DeepEqual(nd, e.Data), "data matches") - found = true - } - } - if !found { - require.Fail(t, "extension not found") - } -} - -func assertHasExtensionMessage(t *testing.T, name graphsync.ExtensionName, extensions []graphsync.ExtensionData, expected datatransfer.Message) { - nd, err := expected.ToIPLD() - require.NoError(t, err) - found := false - for _, e := range extensions { - if e.Name == name { - require.True(t, ipld.DeepEqual(nd, e.Data), "data matches") - found = true - } - } - if !found { - require.Fail(t, "extension not found") - } -} - -type mockChannelState struct { - receivedCids []cid.Cid -} - -var _ datatransfer.ChannelState = (*mockChannelState)(nil) - -func (m *mockChannelState) ReceivedCids() []cid.Cid { - return m.receivedCids -} - -func (m *mockChannelState) ReceivedCidsLen() int { - return len(m.receivedCids) -} - -func (m *mockChannelState) ReceivedCidsTotal() int64 { - return (int64)(len(m.receivedCids)) -} - -func (m *mockChannelState) QueuedCidsTotal() int64 { - panic("implement me") -} - -func (m *mockChannelState) SentCidsTotal() int64 { - panic("implement me") -} - -func (m *mockChannelState) Queued() uint64 { - panic("implement me") -} - -func (m *mockChannelState) Sent() uint64 { - panic("implement me") -} - -func (m *mockChannelState) Received() uint64 { - panic("implement me") -} - -func (m *mockChannelState) ChannelID() datatransfer.ChannelID { - panic("implement me") -} - -func (m *mockChannelState) Status() datatransfer.Status { - panic("implement me") -} - -func (m *mockChannelState) TransferID() datatransfer.TransferID { - panic("implement me") -} - -func (m *mockChannelState) BaseCID() cid.Cid { - panic("implement me") -} - -func (m *mockChannelState) Selector() ipld.Node { - panic("implement me") -} - -func (m *mockChannelState) Voucher() datatransfer.Voucher { - panic("implement me") -} - -func (m *mockChannelState) Sender() peer.ID { - panic("implement me") -} - -func (m *mockChannelState) Recipient() peer.ID { - panic("implement me") -} - -func (m *mockChannelState) TotalSize() uint64 { - panic("implement me") -} - -func (m *mockChannelState) IsPull() bool { - panic("implement me") -} - -func (m *mockChannelState) OtherPeer() peer.ID { - panic("implement me") -} - -func (m *mockChannelState) SelfPeer() peer.ID { - panic("implement me") -} - -func (m *mockChannelState) Message() string { - panic("implement me") -} - -func (m 
*mockChannelState) Vouchers() []datatransfer.Voucher { - panic("implement me") -} - -func (m *mockChannelState) VoucherResults() []datatransfer.VoucherResult { - panic("implement me") -} - -func (m *mockChannelState) LastVoucher() datatransfer.Voucher { - panic("implement me") -} - -func (m *mockChannelState) LastVoucherResult() datatransfer.VoucherResult { - panic("implement me") -} - -func (m *mockChannelState) Stages() *datatransfer.ChannelStages { - panic("implement me") -} diff --git a/transport/graphsync/gskeychidmap_test.go b/transport/graphsync/gskeychidmap_test.go index 0962eb7d..61b5fbd7 100644 --- a/transport/graphsync/gskeychidmap_test.go +++ b/transport/graphsync/gskeychidmap_test.go @@ -6,7 +6,7 @@ import ( "github.com/ipfs/go-graphsync" "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" ) func TestRequestIDToChannelIDMap(t *testing.T) { diff --git a/transport/graphsync/hooks.go b/transport/graphsync/hooks.go new file mode 100644 index 00000000..1e90301c --- /dev/null +++ b/transport/graphsync/hooks.go @@ -0,0 +1,352 @@ +package graphsync + +import ( + "context" + "errors" + "fmt" + + "github.com/ipfs/go-graphsync" + basicnode "github.com/ipld/go-ipld-prime/node/basic" + peer "github.com/libp2p/go-libp2p-core/peer" + "golang.org/x/xerrors" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/dtchannel" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" +) + +// gsOutgoingRequestHook is called when a graphsync request is made +func (t *Transport) gsOutgoingRequestHook(p peer.ID, request graphsync.RequestData, hookActions graphsync.OutgoingRequestHookActions) { + + chid, ok := t.requestIDToChannelID.load(request.ID()) + if !ok { + return + } + + // Start tracking the channel if we're not already + ch, err := t.getDTChannel(chid) + + if err != nil { + // There was an error opening the channel, bail out + log.Errorf("processing OnChannelOpened for %s: %s", chid, err) + t.CleanupChannel(chid) + return + } + + // A data transfer channel was opened + t.events.OnTransportEvent(chid, datatransfer.TransportOpenedChannel{}) + + // Signal that the channel has been opened + ch.GsReqOpened(p, request.ID(), hookActions) +} + +// gsIncomingBlockHook is called when a block is received +func (t *Transport) gsIncomingBlockHook(p peer.ID, response graphsync.ResponseData, block graphsync.BlockData, hookActions graphsync.IncomingBlockHookActions) { + chid, ok := t.requestIDToChannelID.load(response.RequestID()) + if !ok { + return + } + + ch, err := t.getDTChannel(chid) + if err != nil { + hookActions.TerminateWithError(err) + return + } + + if ch.UpdateReceivedIndexIfGreater(block.Index()) && block.BlockSizeOnWire() != 0 { + + t.events.OnTransportEvent(chid, datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + + if ch.UpdateProgress(block.BlockSizeOnWire()) { + t.events.OnTransportEvent(chid, datatransfer.TransportReachedDataLimit{}) + hookActions.PauseRequest() + } + } + +} + +func (t *Transport) gsBlockSentListener(p peer.ID, request graphsync.RequestData, block graphsync.BlockData) { + chid, ok := t.requestIDToChannelID.load(request.ID()) + if !ok { + return + } + + ch, err := t.getDTChannel(chid) + if err != nil { + log.Errorf("sent hook error: %s, for channel %s", err, chid) + return + } + + if 
ch.UpdateSentIndexIfGreater(block.Index()) && block.BlockSizeOnWire() != 0 { + t.events.OnTransportEvent(chid, datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + } +} + +func (t *Transport) gsOutgoingBlockHook(p peer.ID, request graphsync.RequestData, block graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + + chid, ok := t.requestIDToChannelID.load(request.ID()) + if !ok { + return + } + + ch, err := t.getDTChannel(chid) + if err != nil { + hookActions.TerminateWithError(err) + return + } + + if ch.UpdateQueuedIndexIfGreater(block.Index()) && block.BlockSizeOnWire() != 0 { + t.events.OnTransportEvent(chid, datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + + if ch.UpdateProgress(block.BlockSizeOnWire()) { + t.events.OnTransportEvent(chid, datatransfer.TransportReachedDataLimit{}) + hookActions.PauseResponse() + } + } +} + +// gsRequestProcessingListener is called when graphsync begins processing an incoming request for data +func (t *Transport) gsRequestProcessingListener(p peer.ID, request graphsync.RequestData, requestCount int) { + + chid, ok := t.requestIDToChannelID.load(request.ID()) + if !ok { + return + } + + t.events.OnTransportEvent(chid, datatransfer.TransportInitiatedTransfer{}) +} + +// gsReqRecdHook is called when graphsync receives an incoming request for data +func (t *Transport) gsReqRecdHook(p peer.ID, request graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + // if this is a push request the sender is us. + msg, err := extension.GetTransferData(request, t.supportedExtensions) + if err != nil { + hookActions.TerminateWithError(err) + return + } + + // extension not found; probably not our request. + if msg == nil { + return + } + + // An incoming graphsync request for data is received when either + // - The remote peer opened a data-transfer pull channel, so the local node + // receives a graphsync request for the data + // - The local node opened a data-transfer push channel, and in response + // the remote peer sent a graphsync request for the data, and now the + // local node receives that request for data + var chid datatransfer.ChannelID + var responseMessage datatransfer.Message + if msg.IsRequest() { + // when a data transfer request comes in on graphsync, the remote peer + // initiated a pull + chid = datatransfer.ChannelID{ID: msg.TransferID(), Initiator: p, Responder: t.peerID} + + log.Debugf("%s: received request for data (pull), req_id=%d", chid, request.ID()) + + request := msg.(datatransfer.Request) + + // graphsync never receives dt push requests as new graphsync requests -- if so, we should error + isNewOrRestart := (request.IsNew() || request.IsRestart()) + if isNewOrRestart && !request.IsPull() { + hookActions.TerminateWithError(datatransfer.ErrUnsupported) + return + } + + responseMessage, err = t.events.OnRequestReceived(chid, request) + + // if we're going to accept this new/restart request, protect connection + if isNewOrRestart && err == nil { + t.dtNet.Protect(p, chid.String()) + } + + } else { + // when a data transfer response comes in on graphsync, this node + // initiated a push, and the remote peer responded with a request + // for data + chid = datatransfer.ChannelID{ID: msg.TransferID(), Initiator: t.peerID, Responder: p} + + log.Debugf("%s: received request for data (push), req_id=%d", chid, request.ID()) + + response := msg.(datatransfer.Response) + err = t.events.OnResponseReceived(chid, response) + } + + // If we 
need to send a response, add the response message as an extension + if responseMessage != nil { + // gsReqRecdHook uses a unique extension name so it can be attached with data from a different hook + // incomingReqExtensions also includes the default extension name so it remains compatible with previous data-transfer + // protocol versions out there. + extensions, extensionErr := extension.ToExtensionData(responseMessage, incomingReqExtensions) + if extensionErr != nil { + hookActions.TerminateWithError(extensionErr) + return + } + for _, extension := range extensions { + hookActions.SendExtensionData(extension) + } + } + + if err != nil { + hookActions.TerminateWithError(err) + return + } + + hookActions.AugmentContext(t.events.OnContextAugment(chid)) + + chst, err := t.events.ChannelState(context.TODO(), chid) + if err != nil { + hookActions.TerminateWithError(err) + return + } + + var ch *dtchannel.Channel + if msg.IsRequest() { + ch = t.trackDTChannel(chid) + } else { + ch, err = t.getDTChannel(chid) + if err != nil { + hookActions.TerminateWithError(err) + return + } + } + t.requestIDToChannelID.set(request.ID(), true, chid) + ch.GsDataRequestRcvd(p, request.ID(), chst, hookActions) +} + +// gsCompletedResponseListener is a graphsync.OnCompletedResponseListener. We use it to learn when the data transfer is complete +// for the side that is responding to a graphsync request +func (t *Transport) gsCompletedResponseListener(p peer.ID, request graphsync.RequestData, status graphsync.ResponseStatusCode) { + chid, ok := t.requestIDToChannelID.load(request.ID()) + if !ok { + return + } + + if status == graphsync.RequestCancelled { + return + } + + ch, err := t.getDTChannel(chid) + if err != nil { + return + } + ch.MarkTransferComplete() + + var completeEvent datatransfer.TransportCompletedTransfer + if status == graphsync.RequestCompletedFull { + completeEvent.Success = true + } else { + completeEvent.ErrorMessage = fmt.Sprintf("graphsync response to peer %s did not complete: response status code %s", p, status.String()) + } + + // Used by the tests to listen for when a response completes + if t.completedResponseListener != nil { + t.completedResponseListener(chid) + } + + t.events.OnTransportEvent(chid, completeEvent) + +} + +// gsIncomingResponseHook is a graphsync.OnIncomingResponseHook. We use it to pass on responses +func (t *Transport) gsIncomingResponseHook(p peer.ID, response graphsync.ResponseData, hookActions graphsync.IncomingResponseHookActions) { + chid, ok := t.requestIDToChannelID.load(response.RequestID()) + if !ok { + return + } + responseMessage, err := t.processExtension(chid, response, p, incomingReqExtensions) + + if responseMessage != nil { + t.dtNet.SendMessage(context.TODO(), p, transportID, responseMessage) + } + + if err != nil { + hookActions.TerminateWithError(err) + } + + // In a case where the transfer sends blocks immediately, this extension may contain both a + // response message and a revalidation request, so we trigger OnResponseReceived again for this + // specific extension name + responseMessage, err = t.processExtension(chid, response, p, outgoingBlkExtensions) + + if responseMessage != nil { + t.dtNet.SendMessage(context.TODO(), p, transportID, responseMessage) + } + + if err != nil { + hookActions.TerminateWithError(err) + } +} + +func (t *Transport) processExtension(chid datatransfer.ChannelID, gsMsg extension.GsExtended, p peer.ID, exts []graphsync.ExtensionName) (datatransfer.Message, error) { + + // if this is a push request the sender is us. 
+ msg, err := extension.GetTransferData(gsMsg, exts) + if err != nil { + return nil, err + } + // extension not found; probably not our request. + if msg == nil { + return nil, nil + } + + if msg.IsRequest() { + // only accept request message updates when original message was also request + if (chid != datatransfer.ChannelID{ID: msg.TransferID(), Initiator: p, Responder: t.peerID}) { + return nil, errors.New("received request on response channel") + } + dtRequest := msg.(datatransfer.Request) + return t.events.OnRequestReceived(chid, dtRequest) + } + + // only accept response message updates when original message was also response + if (chid != datatransfer.ChannelID{ID: msg.TransferID(), Initiator: t.peerID, Responder: p}) { + return nil, errors.New("received response on request channel") + } + + dtResponse := msg.(datatransfer.Response) + + return nil, t.events.OnResponseReceived(chid, dtResponse) +} + +func (t *Transport) gsRequestorCancelledListener(p peer.ID, request graphsync.RequestData) { + chid, ok := t.requestIDToChannelID.load(request.ID()) + if !ok { + return + } + + ch, err := t.getDTChannel(chid) + if err != nil { + if !xerrors.Is(err, datatransfer.ErrChannelNotFound) { + log.Errorf("requestor cancelled: getting channel %s: %s", chid, err) + } + return + } + + log.Debugf("%s: requester cancelled data-transfer", chid) + ch.OnRequesterCancelled() +} + +// Called when there is a graphsync error sending data +func (t *Transport) gsNetworkSendErrorListener(p peer.ID, request graphsync.RequestData, gserr error) { + // Fire an error if the graphsync request was made by this node or the remote peer + chid, ok := t.requestIDToChannelID.load(request.ID()) + if !ok { + return + } + + t.events.OnTransportEvent(chid, datatransfer.TransportErrorSendingData{ErrorMessage: gserr.Error()}) +} + +// Called when there is a graphsync error receiving data +func (t *Transport) gsNetworkReceiveErrorListener(p peer.ID, gserr error) { + // Fire a receive data error on all ongoing graphsync transfers with that + // peer + t.requestIDToChannelID.forEach(func(k graphsync.RequestID, sending bool, chid datatransfer.ChannelID) { + if chid.Initiator != p && chid.Responder != p { + return + } + + t.events.OnTransportEvent(chid, datatransfer.TransportErrorReceivingData{ErrorMessage: gserr.Error()}) + }) +} diff --git a/transport/graphsync/initiating_test.go b/transport/graphsync/initiating_test.go new file mode 100644 index 00000000..8b0d1e53 --- /dev/null +++ b/transport/graphsync/initiating_test.go @@ -0,0 +1,1325 @@ +package graphsync_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/ipfs/go-graphsync" + "github.com/ipld/go-ipld-prime/datamodel" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/testharness" +) + +func TestInitiatingPullRequestSuccessFlow(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + th := testharness.SetupHarness(ctx, testharness.PullRequest()) + var receivedRequest testharness.ReceivedGraphSyncRequest + var request graphsync.RequestData + t.Run("opens successfully", func(t *testing.T) { + err := th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + require.NoError(t, 
err) + require.Len(t, th.DtNet.ProtectedPeers, 1) + require.Equal(t, th.DtNet.ProtectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Len(t, th.Fgs.ReceivedRequests, 1) + receivedRequest = th.Fgs.ReceivedRequests[0] + request = receivedRequest.ToRequestData(t) + msg, err := extension.GetTransferData(request, []graphsync.ExtensionName{ + extension.ExtensionDataTransfer1_1, + }) + require.NoError(t, err) + require.Equal(t, th.NewRequest(t), msg) + }) + t.Run("configures persistence", func(t *testing.T) { + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String())) + }) + t.Run("receives outgoing request hook", func(t *testing.T) { + th.OutgoingRequestHook(request) + require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.OutgoingRequestHookActions.PersistenceOption) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportOpenedChannel{}) + }) + t.Run("receives outgoing processing listener", func(t *testing.T) { + th.OutgoingRequestProcessingListener(request) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{}) + }) + dtResponse := th.Response() + response := receivedRequest.Response(t, dtResponse, nil, graphsync.PartialResponse) + t.Run("receives response", func(t *testing.T) { + th.IncomingResponseHook(response) + require.Equal(t, th.Events.ReceivedResponse, dtResponse) + }) + + t.Run("received block", func(t *testing.T) { + block := testharness.NewFakeBlockData(12345, 1, true) + th.IncomingBlockHook(response, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + block = testharness.NewFakeBlockData(12345, 2, true) + th.IncomingBlockHook(response, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block not on wire has no effect + block = testharness.NewFakeBlockData(12345, 3, false) + th.IncomingBlockHook(response, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block with lower index has no effect + block = testharness.NewFakeBlockData(67890, 1, true) + th.IncomingBlockHook(response, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + }) + + t.Run("receive pause", func(t *testing.T) { + dtPauseResponse := th.UpdateResponse(true) + pauseResponse := receivedRequest.Response(t, nil, dtPauseResponse, graphsync.RequestPaused) + th.IncomingResponseHook(pauseResponse) + require.Equal(t, th.Events.ReceivedResponse, dtPauseResponse) + }) + + t.Run("send update", func(t *testing.T) { + vRequest := th.VoucherRequest() + th.Transport.SendMessage(ctx, th.Channel.ChannelID(), vRequest) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: vRequest}) + }) + + t.Run("receive resume", func(t *testing.T) { + dtResumeResponse := th.UpdateResponse(false) + pauseResponse := receivedRequest.Response(t, nil, dtResumeResponse, graphsync.PartialResponse) + th.IncomingResponseHook(pauseResponse) 
+ require.Equal(t, th.Events.ReceivedResponse, dtResumeResponse) + }) + + t.Run("pause", func(t *testing.T) { + th.Channel.SetInitiatorPaused(true) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(true)) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(true)}) + require.Len(t, th.Fgs.Pauses, 1) + require.Equal(t, th.Fgs.Pauses[0], request.ID()) + }) + t.Run("pause again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(true)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(true)}) + // should not pause again + require.Len(t, th.Fgs.Pauses, 1) + }) + t.Run("resume", func(t *testing.T) { + th.Channel.SetInitiatorPaused(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(false)) + require.Len(t, th.Fgs.Resumes, 1) + resume := th.Fgs.Resumes[0] + require.Equal(t, request.ID(), resume.RequestID) + msg := resume.DTMessage(t) + require.Equal(t, msg, th.UpdateRequest(false)) + }) + t.Run("resume again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(false)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(true)}) + // should not resume again + require.Len(t, th.Fgs.Resumes, 1) + }) + + t.Run("restart request", func(t *testing.T) { + restartIndex := int64(5) + th.Channel.SetReceivedIndex(basicnode.NewInt(restartIndex)) + err := th.Transport.RestartChannel(ctx, th.Channel, th.RestartRequest(t)) + require.NoError(t, err) + require.Len(t, th.DtNet.ProtectedPeers, 2) + require.Equal(t, th.DtNet.ProtectedPeers[1], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Len(t, th.DtNet.ConnectWithRetryAttempts, 1) + require.Equal(t, th.DtNet.ConnectWithRetryAttempts[0], testharness.ConnectWithRetryAttempt{th.Channel.OtherPeer(), "graphsync"}) + require.Len(t, th.Fgs.Cancels, 1) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportTransferCancelled{ErrorMessage: "graphsync request cancelled"}) + require.Equal(t, request.ID(), th.Fgs.Cancels[0]) + require.Len(t, th.Fgs.ReceivedRequests, 2) + receivedRequest = th.Fgs.ReceivedRequests[1] + request = receivedRequest.ToRequestData(t) + msg, err := extension.GetTransferData(request, []graphsync.ExtensionName{ + extension.ExtensionDataTransfer1_1, + }) + require.NoError(t, err) + require.Equal(t, th.RestartRequest(t), msg) + nd, has := request.Extension(graphsync.ExtensionsDoNotSendFirstBlocks) + require.True(t, has) + val, err := nd.AsInt() + require.NoError(t, err) + require.Equal(t, restartIndex, val) + }) + + t.Run("complete request", func(t *testing.T) { + close(receivedRequest.ResponseChan) + close(receivedRequest.ResponseErrChan) + select { + case <-th.CompletedRequests: + case <-ctx.Done(): + t.Fatalf("did not complete request") + } + th.Events.AssertTransportEventEventually(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: true}) + }) + + t.Run("cleanup request", func(t *testing.T) { + th.Transport.CleanupChannel(th.Channel.ChannelID()) + require.Len(t, th.DtNet.UnprotectedPeers, 1) + require.Equal(t, th.DtNet.UnprotectedPeers[0], 
testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + }) +} + +type ctxKey struct{} + +func TestInitiatingPushRequestSuccessFlow(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + th := testharness.SetupHarness(ctx) + t.Run("opens successfully", func(t *testing.T) { + err := th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + require.NoError(t, err) + require.Len(t, th.DtNet.ProtectedPeers, 1) + require.Equal(t, th.DtNet.ProtectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.NewRequest(t)}) + }) + t.Run("configures persistence", func(t *testing.T) { + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String())) + }) + dtResponse := th.Response() + requestID := graphsync.NewRequestID() + request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtResponse.ToIPLD()}, graphsync.RequestTypeNew) + //response := receivedRequest.Response(t, dtResponse, nil, graphsync.PartialResponse) + t.Run("receives incoming request hook", func(t *testing.T) { + th.Events.ReturnedOnContextAugmentFunc = func(ctx context.Context) context.Context { + return context.WithValue(ctx, ctxKey{}, "applesauce") + } + th.IncomingRequestHook(request) + require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.IncomingRequestHookActions.PersistenceOption) + require.True(t, th.IncomingRequestHookActions.Validated) + require.False(t, th.IncomingBlockHookActions.Paused) + require.NoError(t, th.IncomingRequestHookActions.TerminationError) + th.IncomingRequestHookActions.AssertAugmentedContextKey(t, ctxKey{}, "applesauce") + require.Equal(t, th.Events.ReceivedResponse, dtResponse) + }) + + t.Run("receives incoming processing listener", func(t *testing.T) { + th.IncomingRequestProcessingListener(request) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{}) + }) + + t.Run("queued block", func(t *testing.T) { + block := testharness.NewFakeBlockData(12345, 1, true) + th.OutgoingBlockHook(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + block = testharness.NewFakeBlockData(12345, 2, true) + th.OutgoingBlockHook(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block not on wire has no effect + block = testharness.NewFakeBlockData(12345, 3, false) + th.OutgoingBlockHook(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block with lower index has no effect + block = testharness.NewFakeBlockData(67890, 1, true) + th.OutgoingBlockHook(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + }) + + t.Run("sent block", func(t *testing.T) { + block := testharness.NewFakeBlockData(12345, 1, true) + 
th.BlockSentListener(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + block = testharness.NewFakeBlockData(12345, 2, true) + th.BlockSentListener(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block not on wire has no effect + block = testharness.NewFakeBlockData(12345, 3, false) + th.BlockSentListener(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block with lower index has no effect + block = testharness.NewFakeBlockData(67890, 1, true) + th.BlockSentListener(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + }) + + t.Run("receive pause", func(t *testing.T) { + th.RequestorCancelledListener(request) + dtPauseResponse := th.UpdateResponse(true) + th.DtNet.Delegates[0].Receiver.ReceiveResponse(ctx, th.Channel.OtherPeer(), dtPauseResponse) + require.Equal(t, th.Events.ReceivedResponse, dtPauseResponse) + }) + + t.Run("send update", func(t *testing.T) { + vRequest := th.VoucherRequest() + th.Transport.SendMessage(ctx, th.Channel.ChannelID(), vRequest) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: vRequest}) + }) + + t.Run("receive resume", func(t *testing.T) { + dtResumeResponse := th.UpdateResponse(false) + request = testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtResumeResponse.ToIPLD()}, graphsync.RequestTypeNew) + // reset hook behavior + th.IncomingRequestHookActions = &testharness.FakeIncomingRequestHookActions{} + th.IncomingRequestHook(request) + require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.IncomingRequestHookActions.PersistenceOption) + require.True(t, th.IncomingRequestHookActions.Validated) + require.False(t, th.IncomingBlockHookActions.Paused) + require.NoError(t, th.IncomingRequestHookActions.TerminationError) + th.IncomingRequestHookActions.AssertAugmentedContextKey(t, ctxKey{}, "applesauce") + require.Equal(t, th.Events.ReceivedResponse, dtResumeResponse) + }) + + t.Run("pause", func(t *testing.T) { + th.Channel.SetInitiatorPaused(true) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(true)) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(true)}) + require.Len(t, th.Fgs.Pauses, 1) + require.Equal(t, th.Fgs.Pauses[0], request.ID()) + }) + t.Run("pause again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(true)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(true)}) + // should not pause again + require.Len(t, th.Fgs.Pauses, 1) + }) + t.Run("resume", func(t *testing.T) { + th.Channel.SetInitiatorPaused(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(false)) + require.Len(t, th.Fgs.Resumes, 1) + resume := th.Fgs.Resumes[0] + require.Equal(t, request.ID(), 
resume.RequestID) + msg := resume.DTMessage(t) + require.Equal(t, msg, th.UpdateRequest(false)) + }) + t.Run("resume again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(false)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(true)}) + // should not resume again + require.Len(t, th.Fgs.Resumes, 1) + }) + + t.Run("restart request", func(t *testing.T) { + err := th.Transport.RestartChannel(ctx, th.Channel, th.RestartRequest(t)) + require.NoError(t, err) + require.Len(t, th.DtNet.ProtectedPeers, 2) + require.Equal(t, th.DtNet.ProtectedPeers[1], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Len(t, th.DtNet.ConnectWithRetryAttempts, 1) + require.Equal(t, th.DtNet.ConnectWithRetryAttempts[0], testharness.ConnectWithRetryAttempt{th.Channel.OtherPeer(), "graphsync"}) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.NewRequest(t)}) + }) + + t.Run("complete request", func(t *testing.T) { + th.ResponseCompletedListener(request, graphsync.RequestCompletedFull) + select { + case <-th.CompletedResponses: + case <-ctx.Done(): + t.Fatalf("did not complete request") + } + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: true}) + }) + + t.Run("cleanup request", func(t *testing.T) { + th.Transport.CleanupChannel(th.Channel.ChannelID()) + require.Len(t, th.DtNet.UnprotectedPeers, 1) + require.Equal(t, th.DtNet.UnprotectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + }) +} + +/* "gs outgoing request with recognized dt push channel will record incoming blocks": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.True(t, events.OnDataReceivedCalled) + require.NoError(t, gsData.incomingBlockHookActions.TerminationError) + }, + }, + "non-data-transfer gs request will not record incoming blocks and send updates": { + requestConfig: gsRequestConfig{ + dtExtensionMissing: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{}) + require.False(t, events.OnDataReceivedCalled) + require.NoError(t, gsData.incomingBlockHookActions.TerminationError) + }, + }, + "gs request unrecognized opened channel will not record incoming blocks": { + events: fakeEvents{ + OnChannelOpenedError: errors.New("Not recognized"), + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.False(t, events.OnDataReceivedCalled) + require.NoError(t, gsData.incomingBlockHookActions.TerminationError) + }, + }, + "gs incoming block with data receive error will halt 
request": { + events: fakeEvents{ + OnDataReceivedError: errors.New("something went wrong"), + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.True(t, events.OnDataReceivedCalled) + require.Error(t, gsData.incomingBlockHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt request can receive gs response": { + responseConfig: gsResponseConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Equal(t, 1, events.OnResponseReceivedCallCount) + require.NoError(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt request cannot receive gs response with dt request": { + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Error(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt response can receive gs response": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.NoError(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt response cannot receive gs response with dt response": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + responseConfig: gsResponseConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Error(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt request will error with malformed update": { + responseConfig: gsResponseConfig{ + dtExtensionMalformed: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 
events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Error(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt request will ignore non-data-transfer update": { + responseConfig: gsResponseConfig{ + dtExtensionMissing: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.NoError(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt response can send message on update": { + events: fakeEvents{ + RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), + }, + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.NoError(t, gsData.incomingResponseHookActions.TerminationError) + assertHasOutgoingMessage(t, gsData.incomingResponseHookActions.SentExtensions, + events.RequestReceivedResponse) + }, + }, + "outgoing gs request with recognized dt response err will error": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + events: fakeEvents{ + OnRequestReceivedErrors: []error{errors.New("something went wrong")}, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Error(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request will validate gs request & send dt response": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + events: fakeEvents{ + RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Equal(t, events.RequestReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + dtRequestData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) + assertDecodesToMessage(t, dtRequestData, events.RequestReceivedRequest) + require.True(t, gsData.incomingRequestHookActions.Validated) + assertHasExtensionMessage(t, extension.ExtensionDataTransfer1_1, gsData.incomingRequestHookActions.SentExtensions, events.RequestReceivedResponse) + require.NoError(t, gsData.incomingRequestHookActions.TerminationError) + + 
channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) + require.Equal(t, channelsForPeer, ChannelsForPeer{ + SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ + events.RequestReceivedChannelID: { + Current: gsData.request.ID(), + }, + }, + ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, + }) + }, + }, + "incoming gs request with recognized dt response will validate gs request": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Equal(t, 1, events.OnResponseReceivedCallCount) + require.Equal(t, events.ResponseReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + dtResponseData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) + assertDecodesToMessage(t, dtResponseData, events.ResponseReceivedResponse) + require.True(t, gsData.incomingRequestHookActions.Validated) + require.NoError(t, gsData.incomingRequestHookActions.TerminationError) + }, + }, + "malformed data transfer extension on incoming request will terminate": { + requestConfig: gsRequestConfig{ + dtExtensionMalformed: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.False(t, gsData.incomingRequestHookActions.Validated) + require.Error(t, gsData.incomingRequestHookActions.TerminationError) + }, + }, + "unrecognized incoming dt request will terminate but send response": { + events: fakeEvents{ + RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), + OnRequestReceivedErrors: []error{errors.New("something went wrong")}, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Equal(t, events.RequestReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + dtRequestData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) + assertDecodesToMessage(t, dtRequestData, events.RequestReceivedRequest) + require.False(t, gsData.incomingRequestHookActions.Validated) + assertHasExtensionMessage(t, extension.ExtensionIncomingRequest1_1, gsData.incomingRequestHookActions.SentExtensions, events.RequestReceivedResponse) + require.Error(t, gsData.incomingRequestHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request will record outgoing blocks": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnDataQueuedCalled) + require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) + }, + }, + + "incoming gs request with recognized dt response will record outgoing blocks": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) 
{ + require.Equal(t, 1, events.OnResponseReceivedCallCount) + require.True(t, events.OnDataQueuedCalled) + require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) + }, + }, + "non-data-transfer request will not record outgoing blocks": { + requestConfig: gsRequestConfig{ + dtExtensionMissing: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.False(t, events.OnDataQueuedCalled) + }, + }, + "outgoing data queued error will terminate request": { + events: fakeEvents{ + OnDataQueuedError: errors.New("something went wrong"), + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnDataQueuedCalled) + require.Error(t, gsData.outgoingBlockHookActions.TerminationError) + }, + }, + "outgoing data queued error == pause will pause request": { + events: fakeEvents{ + OnDataQueuedError: datatransfer.ErrPause, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnDataQueuedCalled) + require.True(t, gsData.outgoingBlockHookActions.Paused) + require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request will send updates": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + events: fakeEvents{ + OnDataQueuedMessage: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnDataQueuedCalled) + require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) + assertHasExtensionMessage(t, extension.ExtensionOutgoingBlock1_1, gsData.outgoingBlockHookActions.SentExtensions, + events.OnDataQueuedMessage) + }, + }, + "incoming gs request with recognized dt request can receive update": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 2, events.OnRequestReceivedCallCount) + require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request cannot receive update with dt response": { + updatedConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Error(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt response can receive update": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + updatedConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData 
*harness) { + require.Equal(t, 2, events.OnResponseReceivedCallCount) + require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt response cannot receive update with dt request": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnResponseReceivedCallCount) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Error(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request will error with malformed update": { + updatedConfig: gsRequestConfig{ + dtExtensionMalformed: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Error(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request will ignore non-data-transfer update": { + updatedConfig: gsRequestConfig{ + dtExtensionMissing: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request can send message on update": { + events: fakeEvents{ + RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 2, events.OnRequestReceivedCallCount) + require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) + assertHasOutgoingMessage(t, gsData.requestUpdatedHookActions.SentExtensions, + events.RequestReceivedResponse) + }, + }, + "recognized incoming request will record successful request completion": { + responseConfig: gsResponseConfig{ + status: graphsync.RequestCompletedFull, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.responseCompletedListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnChannelCompletedCalled) + require.True(t, events.ChannelCompletedSuccess) + }, + }, + + "recognized incoming request will record unsuccessful request completion": { + responseConfig: gsResponseConfig{ + status: graphsync.RequestCompletedPartial, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.responseCompletedListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnChannelCompletedCalled) + require.False(t, events.ChannelCompletedSuccess) + }, + }, + "recognized incoming request will not record request cancellation": { + responseConfig: gsResponseConfig{ + status: graphsync.RequestCancelled, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.responseCompletedListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + 
require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.False(t, events.OnChannelCompletedCalled) + }, + }, + "non-data-transfer request will not record request completed": { + requestConfig: gsRequestConfig{ + dtExtensionMissing: true, + }, + responseConfig: gsResponseConfig{ + status: graphsync.RequestCompletedPartial, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.responseCompletedListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.False(t, events.OnChannelCompletedCalled) + }, + }, + "recognized incoming request can be closed": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.NoError(t, err) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertCancelReceived(gsData.ctx, t) + }, + }, + "unrecognized request cannot be closed": { + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.Error(t, err) + }, + }, + "recognized incoming request that requestor cancelled will not close via graphsync": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestorCancelledListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.NoError(t, err) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertNoCancelReceived(t) + }, + }, + "recognized incoming request can be paused": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.NoError(t, err) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertPauseReceived(gsData.ctx, t) + }, + }, + "unrecognized request cannot be paused": { + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.Error(t, err) + }, + }, + "recognized incoming request that requestor cancelled will not pause via graphsync": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestorCancelledListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.NoError(t, err) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertNoPauseReceived(t) + }, + }, + + "incoming request can be queued": { + action: func(gsData *harness) { + gsData.incomingRequestQueuedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.True(t, events.TransferQueuedCalled) + require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, 
Responder: gsData.self, Initiator: gsData.other}, + events.TransferQueuedChannelID) + }, + }, + + "incoming request with dtResponse can be queued": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + responseConfig: gsResponseConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestQueuedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.True(t, events.TransferQueuedCalled) + require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + events.TransferQueuedChannelID) + }, + }, + + "recognized incoming request can be resumed": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.ResumeChannel(gsData.ctx, + gsData.incoming, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, + ) + require.NoError(t, err) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertResumeReceived(gsData.ctx, t) + }, + }, + + "unrecognized request cannot be resumed": { + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.ResumeChannel(gsData.ctx, + gsData.incoming, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, + ) + require.Error(t, err) + }, + }, + "recognized incoming request that requestor cancelled will not resume via graphsync but will resume otherwise": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestorCancelledListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.ResumeChannel(gsData.ctx, + gsData.incoming, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, + ) + require.NoError(t, err) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertNoResumeReceived(t) + gsData.incomingRequestHook() + assertHasOutgoingMessage(t, gsData.incomingRequestHookActions.SentExtensions, gsData.incoming) + }, + }, + "recognized incoming request will record network send error": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.networkErrorListener(errors.New("something went wrong")) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnSendDataErrorCalled) + }, + }, + "recognized outgoing request will record network send error": { + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.networkErrorListener(errors.New("something went wrong")) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.True(t, events.OnSendDataErrorCalled) + }, + }, + "recognized incoming request will record network receive error": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.receiverNetworkErrorListener(errors.New("something went wrong")) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnReceiveDataErrorCalled) + }, + }, + "recognized outgoing request will record network receive error": { + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.receiverNetworkErrorListener(errors.New("something went wrong")) + }, + check: func(t *testing.T, events *fakeEvents, gsData 
*harness) { + require.True(t, events.OnReceiveDataErrorCalled) + }, + }, + "open channel adds block count to the DoNotSendFirstBlocks extension for v1.2 protocol": { + action: func(gsData *harness) { + cids := testutil.GenerateCids(2) + channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ReceivedCids: cids}) + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + channel, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + + ext := requestReceived.Extensions + require.Len(t, ext, 2) + doNotSend := ext[1] + + name := doNotSend.Name + require.Equal(t, graphsync.ExtensionsDoNotSendFirstBlocks, name) + data := doNotSend.Data + blockCount, err := donotsendfirstblocks.DecodeDoNotSendFirstBlocks(data) + require.NoError(t, err) + require.EqualValues(t, blockCount, 2) + }, + }, + "ChannelsForPeer when request is open": { + action: func(gsData *harness) { + cids := testutil.GenerateCids(2) + channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ReceivedCids: cids}) + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + channel, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + gsData.fgs.AssertRequestReceived(gsData.ctx, t) + + channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) + require.Equal(t, channelsForPeer, ChannelsForPeer{ + ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ + events.ChannelOpenedChannelID: { + Current: gsData.request.ID(), + }, + }, + SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, + }) + }, + }, + "open channel cancels an existing request with the same channel ID": { + action: func(gsData *harness) { + cids := testutil.GenerateCids(2) + channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ReceivedCids: cids}) + stor, _ := gsData.outgoing.Selector() + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + channel, + gsData.outgoing) + + go gsData.altOutgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + channel, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + gsData.fgs.AssertRequestReceived(gsData.ctx, t) + gsData.fgs.AssertRequestReceived(gsData.ctx, t) + + ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + gsData.fgs.AssertCancelReceived(ctxt, t) + + channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) + require.Equal(t, channelsForPeer, ChannelsForPeer{ + ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ + 
events.ChannelOpenedChannelID: { + Current: gsData.altRequest.ID(), + Previous: []graphsync.RequestID{gsData.request.ID()}, + }, + }, + SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, + }) + }, + }, + "OnChannelCompleted called when outgoing request completes successfully": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + close(requestReceived.ResponseChan) + close(requestReceived.ResponseErrChan) + + require.Eventually(t, func() bool { + return events.OnChannelCompletedCalled == true + }, 2*time.Second, 100*time.Millisecond) + require.True(t, events.ChannelCompletedSuccess) + }, + }, + "OnChannelCompleted called when outgoing request completes with error": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + close(requestReceived.ResponseChan) + requestReceived.ResponseErrChan <- graphsync.RequestFailedUnknownErr{} + close(requestReceived.ResponseErrChan) + + require.Eventually(t, func() bool { + return events.OnChannelCompletedCalled == true + }, 2*time.Second, 100*time.Millisecond) + require.False(t, events.ChannelCompletedSuccess) + }, + }, + "OnChannelComplete when outgoing request cancelled by caller": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + for _, ext := range requestReceived.Extensions { + extensions[ext.Name] = ext.Data + } + request := testutil.NewFakeRequest(graphsync.NewRequestID(), extensions) + gsData.fgs.OutgoingRequestHook(gsData.other, request, gsData.outgoingRequestHookActions) + _ = gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + gsData.fgs.AssertCancelReceived(ctxt, t) + }, + }, + "request times out if we get request context cancelled error": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + 
datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + close(requestReceived.ResponseChan) + requestReceived.ResponseErrChan <- graphsync.RequestClientCancelledErr{} + close(requestReceived.ResponseErrChan) + + require.Eventually(t, func() bool { + return events.OnRequestCancelledCalled == true + }, 2*time.Second, 100*time.Millisecond) + require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, events.OnRequestCancelledChannelId) + }, + }, + "request cancelled out if transport shuts down": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + gsData.fgs.AssertRequestReceived(gsData.ctx, t) + + gsData.transport.Shutdown(gsData.ctx) + + ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + gsData.fgs.AssertCancelReceived(ctxt, t) + + require.Nil(t, gsData.fgs.IncomingRequestHook) + require.Nil(t, gsData.fgs.CompletedResponseListener) + require.Nil(t, gsData.fgs.IncomingBlockHook) + require.Nil(t, gsData.fgs.OutgoingBlockHook) + require.Nil(t, gsData.fgs.BlockSentListener) + require.Nil(t, gsData.fgs.OutgoingRequestHook) + require.Nil(t, gsData.fgs.IncomingResponseHook) + require.Nil(t, gsData.fgs.RequestUpdatedHook) + require.Nil(t, gsData.fgs.RequestorCancelledListener) + require.Nil(t, gsData.fgs.NetworkErrorListener) + }, + }, + "request pause works even if called when request is still pending": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + assertHasOutgoingMessage(t, requestReceived.Extensions, gsData.outgoing) + completed := make(chan struct{}) + go func() { + err := gsData.transport.PauseChannel(context.Background(), datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.NoError(t, err) + close(completed) + }() + time.Sleep(100 * time.Millisecond) + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + for _, ext := range requestReceived.Extensions { + extensions[ext.Name] = ext.Data + } + request := testutil.NewFakeRequest(graphsync.NewRequestID(), extensions) + gsData.fgs.OutgoingRequestHook(gsData.other, request, gsData.outgoingRequestHookActions) + select { + case <-gsData.ctx.Done(): + t.Fatal("never paused channel") + case <-completed: + } + }, + }, + "UseStore can change store used for outgoing requests": { + action: func(gsData *harness) { + lsys := cidlink.DefaultLinkSystem() + 
lsys.StorageReadOpener = func(ipld.LinkContext, ipld.Link) (io.Reader, error) { + return nil, nil + } + lsys.StorageWriteOpener = func(ipld.LinkContext) (io.Writer, ipld.BlockWriteCommitter, error) { + return nil, nil, nil + } + _ = gsData.transport.UseStore(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, lsys) + gsData.outgoingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + expectedChannel := "data-transfer-" + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}.String() + gsData.fgs.AssertHasPersistenceOption(t, expectedChannel) + require.Equal(t, expectedChannel, gsData.outgoingRequestHookActions.PersistenceOption) + gsData.transport.CleanupChannel(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + gsData.fgs.AssertDoesNotHavePersistenceOption(t, expectedChannel) + }, + }, + "UseStore can change store used for incoming requests": { + action: func(gsData *harness) { + lsys := cidlink.DefaultLinkSystem() + lsys.StorageReadOpener = func(ipld.LinkContext, ipld.Link) (io.Reader, error) { + return nil, nil + } + lsys.StorageWriteOpener = func(ipld.LinkContext) (io.Writer, ipld.BlockWriteCommitter, error) { + return nil, nil, nil + } + _ = gsData.transport.UseStore(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, lsys) + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + expectedChannel := "data-transfer-" + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}.String() + gsData.fgs.AssertHasPersistenceOption(t, expectedChannel) + require.Equal(t, expectedChannel, gsData.incomingRequestHookActions.PersistenceOption) + gsData.transport.CleanupChannel(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + gsData.fgs.AssertDoesNotHavePersistenceOption(t, expectedChannel) + }, + },*/ diff --git a/transport/graphsync/receiver.go b/transport/graphsync/receiver.go new file mode 100644 index 00000000..2391fe1f --- /dev/null +++ b/transport/graphsync/receiver.go @@ -0,0 +1,180 @@ +package graphsync + +import ( + "context" + + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/libp2p/go-libp2p-core/peer" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" +) + +type receiver struct { + transport *Transport +} + +// ReceiveRequest takes an incoming data transfer request, validates the voucher and +// processes the message. 
+func (r *receiver) ReceiveRequest( + ctx context.Context, + initiator peer.ID, + incoming datatransfer.Request) { + err := r.receiveRequest(ctx, initiator, incoming) + if err != nil { + log.Warnf("error processing request from %s: %s", initiator, err) + } +} + +func (r *receiver) receiveRequest(ctx context.Context, initiator peer.ID, incoming datatransfer.Request) error { + chid := datatransfer.ChannelID{Initiator: initiator, Responder: r.transport.peerID, ID: incoming.TransferID()} + ctxAugment := r.transport.events.OnContextAugment(chid) + if ctxAugment != nil { + ctx = ctxAugment(ctx) + } + ctx, span := otel.Tracer("gs-data-transfer").Start(ctx, "receiveRequest", trace.WithAttributes( + attribute.String("channelID", chid.String()), + attribute.String("baseCid", incoming.BaseCid().String()), + attribute.Bool("isNew", incoming.IsNew()), + attribute.Bool("isRestart", incoming.IsRestart()), + attribute.Bool("isUpdate", incoming.IsUpdate()), + attribute.Bool("isCancel", incoming.IsCancel()), + attribute.Bool("isPaused", incoming.IsPaused()), + )) + defer span.End() + isNewOrRestart := incoming.IsNew() || incoming.IsRestart() + // a graphsync pull request MUST come in via graphsync + if isNewOrRestart && incoming.IsPull() { + return datatransfer.ErrUnsupported + } + response, receiveErr := r.transport.events.OnRequestReceived(chid, incoming) + initiateGraphsyncRequest := isNewOrRestart && response != nil && receiveErr == nil + ch, err := r.transport.getDTChannel(chid) + if err != nil { + if !initiateGraphsyncRequest || receiveErr != nil { + if response != nil { + if sendErr := r.transport.dtNet.SendMessage(ctx, initiator, transportID, response); sendErr != nil { + return sendErr + } + return receiveErr + } + return receiveErr + } + ch = r.transport.trackDTChannel(chid) + } + + if receiveErr != nil { + if response != nil { + if err := r.transport.dtNet.SendMessage(ctx, initiator, transportID, response); err != nil { + return err + } + _ = ch.Close(ctx) + return receiveErr + } + } + + if isNewOrRestart { + r.transport.dtNet.Protect(initiator, chid.String()) + } + chst, err := r.transport.events.ChannelState(ctx, chid) + if err != nil { + return err + } + + err = ch.UpdateFromChannelState(chst) + if err != nil { + return err + } + + if initiateGraphsyncRequest { + stor, _ := incoming.Selector() + if err := r.transport.openRequest(ctx, initiator, chid, cidlink.Link{Cid: incoming.BaseCid()}, stor, response); err != nil { + return err + } + response = nil + } + + action := ch.ActionFromChannelState(chst) + return r.transport.processAction(ctx, chid, ch, action, response) +} + +// ReceiveResponse handles responses to our Push or Pull data transfer request. +// It schedules a transfer only if our Pull Request is accepted. 
+func (r *receiver) ReceiveResponse( + ctx context.Context, + sender peer.ID, + incoming datatransfer.Response) { + err := r.receiveResponse(ctx, sender, incoming) + if err != nil { + log.Error(err) + } +} +func (r *receiver) receiveResponse( + ctx context.Context, + sender peer.ID, + incoming datatransfer.Response) error { + chid := datatransfer.ChannelID{Initiator: r.transport.peerID, Responder: sender, ID: incoming.TransferID()} + ctx = r.transport.events.OnContextAugment(chid)(ctx) + ctx, span := otel.Tracer("gs-data-transfer").Start(ctx, "receiveResponse", trace.WithAttributes( + attribute.String("channelID", chid.String()), + attribute.Bool("accepted", incoming.Accepted()), + attribute.Bool("isComplete", incoming.IsComplete()), + attribute.Bool("isNew", incoming.IsNew()), + attribute.Bool("isRestart", incoming.IsRestart()), + attribute.Bool("isUpdate", incoming.IsUpdate()), + attribute.Bool("isCancel", incoming.IsCancel()), + attribute.Bool("isPaused", incoming.IsPaused()), + )) + defer span.End() + receiveErr := r.transport.events.OnResponseReceived(chid, incoming) + ch, err := r.transport.getDTChannel(chid) + if err != nil { + return err + } + if receiveErr != nil { + log.Warnf("closing channel %s after getting error processing response from %s: %s", + chid, sender, receiveErr) + + _ = ch.Close(ctx) + return receiveErr + } + return nil +} + +func (r *receiver) ReceiveError(err error) { + log.Errorf("received error message on data transfer: %s", err.Error()) +} + +func (r *receiver) ReceiveRestartExistingChannelRequest(ctx context.Context, + sender peer.ID, + incoming datatransfer.Request) { + + ch, err := incoming.RestartChannelId() + if err != nil { + log.Errorf("cannot restart channel: failed to fetch channel Id: %s", err) + return + } + + ctx = r.transport.events.OnContextAugment(ch)(ctx) + ctx, span := otel.Tracer("gs-data-transfer").Start(ctx, "receiveRequest", trace.WithAttributes( + attribute.String("channelID", ch.String()), + )) + defer span.End() + log.Infof("channel %s: received restart existing channel request from %s", ch, sender) + + // initiator should be me + if ch.Initiator != r.transport.peerID { + log.Errorf("cannot restart channel %s: channel initiator is not the manager peer", ch) + return + } + + if ch.Responder != sender { + log.Errorf("cannot restart channel %s: channel counterparty is not the sender peer", ch) + return + } + + r.transport.events.OnTransportEvent(ch, datatransfer.TransportReceivedRestartExistingChannelRequest{}) + return +} diff --git a/transport/graphsync/responding_test.go b/transport/graphsync/responding_test.go new file mode 100644 index 00000000..93343574 --- /dev/null +++ b/transport/graphsync/responding_test.go @@ -0,0 +1,413 @@ +package graphsync_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/ipfs/go-graphsync" + "github.com/ipld/go-ipld-prime/datamodel" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/testharness" +) + +func TestRespondingPullSuccessFlow(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + th := testharness.SetupHarness(ctx, testharness.PullRequest(), testharness.Responder()) + + // this actually happens in the 
request received event handler itself in a real life case, but here we just run it before + t.Run("configures persistence", func(t *testing.T) { + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String())) + }) + requestID := graphsync.NewRequestID() + dtRequest := th.NewRequest(t) + request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtRequest.ToIPLD()}, graphsync.RequestTypeNew) + + // this is the actual start of request processing + t.Run("received and responds successfully", func(t *testing.T) { + dtResponse := th.Response() + th.Events.ReturnedRequestReceivedResponse = dtResponse + th.Channel.SetResponderPaused(true) + th.Channel.SetDataLimit(10000) + th.Events.ReturnedOnContextAugmentFunc = func(ctx context.Context) context.Context { + return context.WithValue(ctx, ctxKey{}, "applesauce") + } + th.IncomingRequestHook(request) + require.Equal(t, dtRequest, th.Events.ReceivedRequest) + require.Len(t, th.DtNet.ProtectedPeers, 1) + require.Equal(t, th.DtNet.ProtectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.IncomingRequestHookActions.PersistenceOption) + require.True(t, th.IncomingRequestHookActions.Validated) + require.True(t, th.IncomingRequestHookActions.Paused) + require.NoError(t, th.IncomingRequestHookActions.TerminationError) + sentResponse := th.IncomingRequestHookActions.DTMessage(t) + require.Equal(t, dtResponse, sentResponse) + th.IncomingRequestHookActions.AssertAugmentedContextKey(t, ctxKey{}, "applesauce") + }) + + t.Run("receives incoming processing listener", func(t *testing.T) { + th.IncomingRequestProcessingListener(request) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{}) + }) + + t.Run("unpause request", func(t *testing.T) { + th.Channel.SetResponderPaused(false) + dtValidationResponse := th.ValidationResultResponse(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), dtValidationResponse) + require.Len(t, th.Fgs.Resumes, 1) + require.Equal(t, dtValidationResponse, th.Fgs.Resumes[0].DTMessage(t)) + }) + + t.Run("queued block / data limits", func(t *testing.T) { + // consume first block + block := testharness.NewFakeBlockData(8000, 1, true) + th.OutgoingBlockHook(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + + // consume second block -- should hit data limit + block = testharness.NewFakeBlockData(3000, 2, true) + th.OutgoingBlockHook(request, block) + require.True(t, th.OutgoingBlockHookActions.Paused) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReachedDataLimit{}) + + // reset data limit + th.Channel.SetResponderPaused(false) + th.Channel.SetDataLimit(20000) + dtValidationResponse := th.ValidationResultResponse(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), dtValidationResponse) + require.Len(t, th.Fgs.Resumes, 2) + require.Equal(t, dtValidationResponse, th.Fgs.Resumes[1].DTMessage(t)) + + // block not on wire has no effect + block = 
testharness.NewFakeBlockData(12345, 3, false) + th.OutgoingBlockHook(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block with lower index has no effect + block = testharness.NewFakeBlockData(67890, 1, true) + th.OutgoingBlockHook(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + + // consume third block + block = testharness.NewFakeBlockData(5000, 4, true) + th.OutgoingBlockHook(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + + // consume fourth block should hit data limit again + block = testharness.NewFakeBlockData(5000, 5, true) + th.OutgoingBlockHook(request, block) + require.True(t, th.OutgoingBlockHookActions.Paused) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReachedDataLimit{}) + + }) + + t.Run("sent block", func(t *testing.T) { + block := testharness.NewFakeBlockData(12345, 1, true) + th.BlockSentListener(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + block = testharness.NewFakeBlockData(12345, 2, true) + th.BlockSentListener(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block not on wire has no effect + block = testharness.NewFakeBlockData(12345, 3, false) + th.BlockSentListener(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block with lower index has no effect + block = testharness.NewFakeBlockData(67890, 1, true) + th.BlockSentListener(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + }) + + t.Run("receive pause", func(t *testing.T) { + th.RequestorCancelledListener(request) + dtPauseRequest := th.UpdateRequest(true) + th.Events.ReturnedRequestReceivedResponse = nil + th.DtNet.Delegates[0].Receiver.ReceiveRequest(ctx, th.Channel.OtherPeer(), dtPauseRequest) + require.Equal(t, th.Events.ReceivedRequest, dtPauseRequest) + }) + + t.Run("receive resume", func(t *testing.T) { + dtResumeRequest := th.UpdateRequest(false) + request = testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtResumeRequest.ToIPLD()}, graphsync.RequestTypeNew) + // reset hook behavior + th.IncomingRequestHookActions = &testharness.FakeIncomingRequestHookActions{} + th.IncomingRequestHook(request) + // only protect on new and restart requests + require.Len(t, th.DtNet.ProtectedPeers, 1) + require.Equal(t, th.DtNet.ProtectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.IncomingRequestHookActions.PersistenceOption) + require.True(t, 
th.IncomingRequestHookActions.Validated) + require.False(t, th.IncomingBlockHookActions.Paused) + require.NoError(t, th.IncomingRequestHookActions.TerminationError) + th.IncomingRequestHookActions.AssertAugmentedContextKey(t, ctxKey{}, "applesauce") + require.Equal(t, th.Events.ReceivedRequest, dtResumeRequest) + }) + + t.Run("pause", func(t *testing.T) { + th.Channel.SetResponderPaused(true) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(true)) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(true)}) + require.Len(t, th.Fgs.Pauses, 1) + require.Equal(t, th.Fgs.Pauses[0], request.ID()) + }) + + t.Run("pause again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(true)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(true)}) + // should not pause again + require.Len(t, th.Fgs.Pauses, 1) + }) + + t.Run("resume", func(t *testing.T) { + th.Channel.SetResponderPaused(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(true)) + require.Len(t, th.Fgs.Resumes, 3) + resume := th.Fgs.Resumes[2] + require.Equal(t, request.ID(), resume.RequestID) + msg := resume.DTMessage(t) + require.Equal(t, msg, th.UpdateResponse(false)) + }) + t.Run("resume again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(false)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(true)}) + // should not resume again + require.Len(t, th.Fgs.Resumes, 3) + }) + + t.Run("restart request", func(t *testing.T) { + dtRestartRequest := th.RestartRequest(t) + request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtRestartRequest.ToIPLD()}, graphsync.RequestTypeNew) + th.IncomingRequestHook(request) + // protect again for a restart + require.Len(t, th.DtNet.ProtectedPeers, 2) + require.Equal(t, th.DtNet.ProtectedPeers[1], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.IncomingRequestHookActions.PersistenceOption) + require.True(t, th.IncomingRequestHookActions.Validated) + require.False(t, th.IncomingRequestHookActions.Paused) + require.NoError(t, th.IncomingRequestHookActions.TerminationError) + th.IncomingRequestHookActions.AssertAugmentedContextKey(t, ctxKey{}, "applesauce") + }) + + t.Run("complete request", func(t *testing.T) { + th.ResponseCompletedListener(request, graphsync.RequestCompletedFull) + select { + case <-th.CompletedResponses: + case <-ctx.Done(): + t.Fatalf("did not complete request") + } + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: true}) + }) + + t.Run("cleanup request", func(t *testing.T) { + th.Transport.CleanupChannel(th.Channel.ChannelID()) + require.Len(t, th.DtNet.UnprotectedPeers, 1) + require.Equal(t, th.DtNet.UnprotectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + }) +} + +func TestRespondingPushSuccessFlow(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 
5*time.Second) + defer cancel() + th := testharness.SetupHarness(ctx, testharness.Responder()) + var receivedRequest testharness.ReceivedGraphSyncRequest + var request graphsync.RequestData + + contextAugmentedCalls := []struct{}{} + th.Events.ReturnedOnContextAugmentFunc = func(ctx context.Context) context.Context { + contextAugmentedCalls = append(contextAugmentedCalls, struct{}{}) + return ctx + } + t.Run("configures persistence", func(t *testing.T) { + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String())) + }) + t.Run("receive new request", func(t *testing.T) { + dtResponse := th.Response() + th.Events.ReturnedRequestReceivedResponse = dtResponse + th.Channel.SetResponderPaused(true) + th.Channel.SetDataLimit(10000) + + th.DtNet.Delegates[0].Receiver.ReceiveRequest(ctx, th.Channel.OtherPeer(), th.NewRequest(t)) + require.Equal(t, th.NewRequest(t), th.Events.ReceivedRequest) + require.Len(t, th.DtNet.ProtectedPeers, 1) + require.Equal(t, th.DtNet.ProtectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Len(t, th.Fgs.ReceivedRequests, 1) + receivedRequest = th.Fgs.ReceivedRequests[0] + request = receivedRequest.ToRequestData(t) + msg, err := extension.GetTransferData(request, []graphsync.ExtensionName{ + extension.ExtensionDataTransfer1_1, + }) + require.NoError(t, err) + require.Equal(t, dtResponse, msg) + require.Len(t, th.Fgs.Pauses, 1) + require.Equal(t, request.ID(), th.Fgs.Pauses[0]) + require.Len(t, contextAugmentedCalls, 1) + }) + + t.Run("unpause request", func(t *testing.T) { + th.Channel.SetResponderPaused(false) + dtValidationResponse := th.ValidationResultResponse(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), dtValidationResponse) + require.Len(t, th.Fgs.Resumes, 1) + require.Equal(t, dtValidationResponse, th.Fgs.Resumes[0].DTMessage(t)) + }) + + t.Run("receives outgoing request hook", func(t *testing.T) { + th.OutgoingRequestHook(request) + require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.OutgoingRequestHookActions.PersistenceOption) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportOpenedChannel{}) + }) + + t.Run("receives outgoing processing listener", func(t *testing.T) { + th.OutgoingRequestProcessingListener(request) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{}) + }) + response := receivedRequest.Response(t, nil, nil, graphsync.PartialResponse) + t.Run("received block / data limits", func(t *testing.T) { + th.IncomingResponseHook(response) + // consume first block + block := testharness.NewFakeBlockData(8000, 1, true) + th.IncomingBlockHook(response, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + + // consume second block -- should hit data limit + block = testharness.NewFakeBlockData(3000, 2, true) + th.IncomingBlockHook(response, block) + require.True(t, th.IncomingBlockHookActions.Paused) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReachedDataLimit{}) + + // reset data limit + th.Channel.SetResponderPaused(false) + 
th.Channel.SetDataLimit(20000) + dtValidationResponse := th.ValidationResultResponse(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), dtValidationResponse) + require.Len(t, th.Fgs.Resumes, 2) + require.Equal(t, dtValidationResponse, th.Fgs.Resumes[1].DTMessage(t)) + + // block not on wire has no effect + block = testharness.NewFakeBlockData(12345, 3, false) + th.IncomingBlockHook(response, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block with lower index has no effect + block = testharness.NewFakeBlockData(67890, 1, true) + th.OutgoingBlockHook(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + + // consume third block + block = testharness.NewFakeBlockData(5000, 4, true) + th.IncomingBlockHook(response, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + + // consume fourth block should hit data limit again + block = testharness.NewFakeBlockData(5000, 5, true) + th.IncomingBlockHook(response, block) + require.True(t, th.IncomingBlockHookActions.Paused) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReachedDataLimit{}) + + }) + + t.Run("receive pause", func(t *testing.T) { + dtPauseRequest := th.UpdateRequest(true) + pauseResponse := receivedRequest.Response(t, nil, dtPauseRequest, graphsync.RequestPaused) + th.IncomingResponseHook(pauseResponse) + th.Events.ReturnedRequestReceivedResponse = nil + require.Equal(t, th.Events.ReceivedRequest, dtPauseRequest) + }) + + t.Run("receive resume", func(t *testing.T) { + dtResumeRequest := th.UpdateRequest(false) + pauseResponse := receivedRequest.Response(t, nil, dtResumeRequest, graphsync.PartialResponse) + th.IncomingResponseHook(pauseResponse) + require.Equal(t, th.Events.ReceivedRequest, dtResumeRequest) + }) + + t.Run("pause", func(t *testing.T) { + th.Channel.SetResponderPaused(true) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(true)) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(true)}) + require.Len(t, th.Fgs.Pauses, 1) + require.Equal(t, th.Fgs.Pauses[0], request.ID()) + }) + t.Run("pause again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(true)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(true)}) + // should not pause again + require.Len(t, th.Fgs.Pauses, 1) + }) + t.Run("resume", func(t *testing.T) { + th.Channel.SetResponderPaused(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(false)) + require.Len(t, th.Fgs.Resumes, 3) + resume := th.Fgs.Resumes[2] + require.Equal(t, request.ID(), resume.RequestID) + msg := resume.DTMessage(t) + require.Equal(t, msg, th.UpdateResponse(false)) + }) + t.Run("resume again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(false)) + 
// should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(false)}) + // should not resume again + require.Len(t, th.Fgs.Resumes, 3) + }) + + t.Run("restart request", func(t *testing.T) { + restartIndex := int64(5) + th.Channel.SetReceivedIndex(basicnode.NewInt(restartIndex)) + dtResponse := th.RestartResponse(false) + th.Events.ReturnedRequestReceivedResponse = dtResponse + th.DtNet.Delegates[0].Receiver.ReceiveRequest(ctx, th.Channel.OtherPeer(), th.NewRequest(t)) + require.Len(t, th.DtNet.ProtectedPeers, 2) + require.Equal(t, th.DtNet.ProtectedPeers[1], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Len(t, th.Fgs.Cancels, 1) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportTransferCancelled{ErrorMessage: "graphsync request cancelled"}) + require.Equal(t, request.ID(), th.Fgs.Cancels[0]) + require.Len(t, th.Fgs.ReceivedRequests, 2) + receivedRequest = th.Fgs.ReceivedRequests[1] + request = receivedRequest.ToRequestData(t) + msg, err := extension.GetTransferData(request, []graphsync.ExtensionName{ + extension.ExtensionDataTransfer1_1, + }) + require.NoError(t, err) + require.Equal(t, dtResponse, msg) + nd, has := request.Extension(graphsync.ExtensionsDoNotSendFirstBlocks) + require.True(t, has) + val, err := nd.AsInt() + require.NoError(t, err) + require.Equal(t, restartIndex, val) + require.Len(t, contextAugmentedCalls, 2) + }) + + t.Run("complete request", func(t *testing.T) { + close(receivedRequest.ResponseChan) + close(receivedRequest.ResponseErrChan) + select { + case <-th.CompletedRequests: + case <-ctx.Done(): + t.Fatalf("did not complete request") + } + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: true}) + }) + + t.Run("cleanup request", func(t *testing.T) { + th.Transport.CleanupChannel(th.Channel.ChannelID()) + require.Len(t, th.DtNet.UnprotectedPeers, 1) + require.Equal(t, th.DtNet.UnprotectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + }) +} diff --git a/transport/graphsync/testharness/events.go b/transport/graphsync/testharness/events.go new file mode 100644 index 00000000..398a3ba8 --- /dev/null +++ b/transport/graphsync/testharness/events.go @@ -0,0 +1,74 @@ +package testharness + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" +) + +type ReceivedTransportEvent struct { + ChannelID datatransfer.ChannelID + TransportEvent datatransfer.TransportEvent +} + +type FakeEvents struct { + // function return value parameters + ReturnedRequestReceivedResponse datatransfer.Response + ReturnedRequestReceivedError error + ReturnedResponseReceivedError error + ReturnedChannelState datatransfer.ChannelState + ReturnedOnContextAugmentFunc func(context.Context) context.Context + + // recording of actions + OnRequestReceivedCalled bool + ReceivedRequest datatransfer.Request + OnResponseReceivedCalled bool + ReceivedResponse datatransfer.Response + ReceivedTransportEvents []ReceivedTransportEvent +} + +func (fe *FakeEvents) OnTransportEvent(chid datatransfer.ChannelID, evt datatransfer.TransportEvent) { + fe.ReceivedTransportEvents = append(fe.ReceivedTransportEvents, ReceivedTransportEvent{chid, evt}) +} + +func (fe *FakeEvents) AssertTransportEvent(t *testing.T, chid datatransfer.ChannelID, 
evt datatransfer.TransportEvent) { + require.Contains(t, fe.ReceivedTransportEvents, ReceivedTransportEvent{chid, evt}) +} + +func (fe *FakeEvents) AssertTransportEventEventually(t *testing.T, chid datatransfer.ChannelID, evt datatransfer.TransportEvent) { + require.Eventually(t, func() bool { + for _, receivedEvent := range fe.ReceivedTransportEvents { + if (receivedEvent == ReceivedTransportEvent{chid, evt}) { + return true + } + } + return false + }, time.Second, time.Millisecond) +} + +func (fe *FakeEvents) RefuteTransportEvent(t *testing.T, chid datatransfer.ChannelID, evt datatransfer.TransportEvent) { + require.NotContains(t, fe.ReceivedTransportEvents, ReceivedTransportEvent{chid, evt}) +} +func (fe *FakeEvents) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) { + fe.OnRequestReceivedCalled = true + fe.ReceivedRequest = request + return fe.ReturnedRequestReceivedResponse, fe.ReturnedRequestReceivedError +} + +func (fe *FakeEvents) OnResponseReceived(chid datatransfer.ChannelID, response datatransfer.Response) error { + fe.OnResponseReceivedCalled = true + fe.ReceivedResponse = response + return fe.ReturnedResponseReceivedError +} + +func (fe *FakeEvents) OnContextAugment(chid datatransfer.ChannelID) func(context.Context) context.Context { + return fe.ReturnedOnContextAugmentFunc +} + +func (fe *FakeEvents) ChannelState(ctx context.Context, chid datatransfer.ChannelID) (datatransfer.ChannelState, error) { + return fe.ReturnedChannelState, nil +} diff --git a/testutil/fakegraphsync.go b/transport/graphsync/testharness/fakegraphsync.go similarity index 68% rename from testutil/fakegraphsync.go rename to transport/graphsync/testharness/fakegraphsync.go index c6477522..5f31b454 100644 --- a/testutil/fakegraphsync.go +++ b/transport/graphsync/testharness/fakegraphsync.go @@ -1,4 +1,4 @@ -package testutil +package testharness import ( "context" @@ -13,18 +13,20 @@ import ( "github.com/ipld/go-ipld-prime/datamodel" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/traversal" + selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/message" - "github.com/filecoin-project/go-data-transfer/transport/graphsync/extension" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/testutil" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" ) -func matchDtMessage(t *testing.T, extensions []graphsync.ExtensionData) datatransfer.Message { +func matchDtMessage(t *testing.T, extensions []graphsync.ExtensionData, extName graphsync.ExtensionName) datatransfer.Message { var matchedExtension *graphsync.ExtensionData for _, ext := range extensions { - if ext.Name == extension.ExtensionDataTransfer1_1 { + if ext.Name == extName { matchedExtension = &ext break } @@ -40,15 +42,43 @@ type ReceivedGraphSyncRequest struct { Ctx context.Context P peer.ID Root ipld.Link - Selector ipld.Node + Selector datamodel.Node Extensions []graphsync.ExtensionData ResponseChan chan graphsync.ResponseProgress ResponseErrChan chan error } +func (gsRequest ReceivedGraphSyncRequest) ToRequestData(t *testing.T) graphsync.RequestData { + extensions := 
make(map[graphsync.ExtensionName]datamodel.Node) + for _, extension := range gsRequest.Extensions { + extensions[extension.Name] = extension.Data + } + requestID, ok := gsRequest.requestID() + require.True(t, ok) + return NewFakeRequest(requestID, extensions, graphsync.RequestTypeNew) +} + +func (gsRequest ReceivedGraphSyncRequest) Response(t *testing.T, incomingRequestMsg datatransfer.Message, blockMessage datatransfer.Message, code graphsync.ResponseStatusCode) graphsync.ResponseData { + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + if incomingRequestMsg != nil { + extensions[extension.ExtensionIncomingRequest1_1] = incomingRequestMsg.ToIPLD() + } + if blockMessage != nil { + extensions[extension.ExtensionOutgoingBlock1_1] = blockMessage.ToIPLD() + } + requestID, ok := gsRequest.requestID() + require.True(t, ok) + return NewFakeResponse(requestID, extensions, code) +} + +func (gsRequest ReceivedGraphSyncRequest) requestID() (graphsync.RequestID, bool) { + request, ok := gsRequest.Ctx.Value(graphsync.RequestIDContextKey{}).(graphsync.RequestID) + return request, ok +} + // DTMessage returns the data transfer message among the graphsync extensions sent with this request func (gsRequest ReceivedGraphSyncRequest) DTMessage(t *testing.T) datatransfer.Message { - return matchDtMessage(t, gsRequest.Extensions) + return matchDtMessage(t, gsRequest.Extensions, extension.ExtensionDataTransfer1_1) } type Resume struct { @@ -58,7 +88,7 @@ type Resume struct { // DTMessage returns the data transfer message among the graphsync extensions sent with this request func (resume Resume) DTMessage(t *testing.T) datatransfer.Message { - return matchDtMessage(t, resume.Extensions) + return matchDtMessage(t, resume.Extensions, extension.ExtensionDataTransfer1_1) } type Update struct { @@ -68,41 +98,41 @@ type Update struct { // DTMessage returns the data transfer message among the graphsync extensions sent with this request func (update Update) DTMessage(t *testing.T) datatransfer.Message { - return matchDtMessage(t, update.Extensions) + return matchDtMessage(t, update.Extensions, extension.ExtensionDataTransfer1_1) } // FakeGraphSync implements a GraphExchange but does nothing type FakeGraphSync struct { - requests chan ReceivedGraphSyncRequest // records calls to fakeGraphSync.Request - pauses chan graphsync.RequestID - resumes chan Resume - cancels chan graphsync.RequestID - updates chan Update - persistenceOptionsLk sync.RWMutex - persistenceOptions map[string]ipld.LinkSystem - leaveRequestsOpen bool - OutgoingRequestHook graphsync.OnOutgoingRequestHook - IncomingBlockHook graphsync.OnIncomingBlockHook - OutgoingBlockHook graphsync.OnOutgoingBlockHook - IncomingRequestQueuedHook graphsync.OnIncomingRequestQueuedHook - IncomingRequestHook graphsync.OnIncomingRequestHook - CompletedResponseListener graphsync.OnResponseCompletedListener - RequestUpdatedHook graphsync.OnRequestUpdatedHook - IncomingResponseHook graphsync.OnIncomingResponseHook - RequestorCancelledListener graphsync.OnRequestorCancelledListener - BlockSentListener graphsync.OnBlockSentListener - NetworkErrorListener graphsync.OnNetworkErrorListener - ReceiverNetworkErrorListener graphsync.OnReceiverNetworkErrorListener + ReceivedRequests []ReceivedGraphSyncRequest // records calls to fakeGraphSync.Request + Pauses []graphsync.RequestID + Resumes []Resume + Cancels []graphsync.RequestID + Updates []Update + persistenceOptionsLk sync.RWMutex + persistenceOptions map[string]ipld.LinkSystem + leaveRequestsOpen bool + 
OutgoingRequestHook graphsync.OnOutgoingRequestHook + IncomingBlockHook graphsync.OnIncomingBlockHook + OutgoingBlockHook graphsync.OnOutgoingBlockHook + IncomingRequestProcessingListener graphsync.OnRequestProcessingListener + OutgoingRequestProcessingListener graphsync.OnRequestProcessingListener + IncomingRequestHook graphsync.OnIncomingRequestHook + CompletedResponseListener graphsync.OnResponseCompletedListener + RequestUpdatedHook graphsync.OnRequestUpdatedHook + IncomingResponseHook graphsync.OnIncomingResponseHook + RequestorCancelledListener graphsync.OnRequestorCancelledListener + BlockSentListener graphsync.OnBlockSentListener + NetworkErrorListener graphsync.OnNetworkErrorListener + ReceiverNetworkErrorListener graphsync.OnReceiverNetworkErrorListener + ReturnedCancelError error + ReturnedPauseError error + ReturnedResumeError error + ReturnedSendUpdateError error } // NewFakeGraphSync returns a new fake graphsync implementation func NewFakeGraphSync() *FakeGraphSync { return &FakeGraphSync{ - requests: make(chan ReceivedGraphSyncRequest, 2), - pauses: make(chan graphsync.RequestID, 1), - resumes: make(chan Resume, 1), - cancels: make(chan graphsync.RequestID, 1), - updates: make(chan Update, 1), persistenceOptions: make(map[string]ipld.LinkSystem), } } @@ -111,70 +141,6 @@ func (fgs *FakeGraphSync) LeaveRequestsOpen() { fgs.leaveRequestsOpen = true } -// AssertNoRequestReceived asserts that no requests should ahve been received by this graphsync implementation -func (fgs *FakeGraphSync) AssertNoRequestReceived(t *testing.T) { - require.Empty(t, fgs.requests, "should not receive request") -} - -// AssertRequestReceived asserts a request should be received before the context closes (and returns said request) -func (fgs *FakeGraphSync) AssertRequestReceived(ctx context.Context, t *testing.T) ReceivedGraphSyncRequest { - var requestReceived ReceivedGraphSyncRequest - select { - case <-ctx.Done(): - t.Fatal("did not receive message sent") - case requestReceived = <-fgs.requests: - } - return requestReceived -} - -// AssertNoPauseReceived asserts that no pause requests should ahve been received by this graphsync implementation -func (fgs *FakeGraphSync) AssertNoPauseReceived(t *testing.T) { - require.Empty(t, fgs.pauses, "should not receive pause request") -} - -// AssertPauseReceived asserts a pause request should be received before the context closes (and returns said request) -func (fgs *FakeGraphSync) AssertPauseReceived(ctx context.Context, t *testing.T) graphsync.RequestID { - var pauseReceived graphsync.RequestID - select { - case <-ctx.Done(): - t.Fatal("did not receive message sent") - case pauseReceived = <-fgs.pauses: - } - return pauseReceived -} - -// AssertNoResumeReceived asserts that no resume requests should ahve been received by this graphsync implementation -func (fgs *FakeGraphSync) AssertNoResumeReceived(t *testing.T) { - require.Empty(t, fgs.resumes, "should not receive resume request") -} - -// AssertResumeReceived asserts a resume request should be received before the context closes (and returns said request) -func (fgs *FakeGraphSync) AssertResumeReceived(ctx context.Context, t *testing.T) Resume { - var resumeReceived Resume - select { - case <-ctx.Done(): - t.Fatal("did not receive message sent") - case resumeReceived = <-fgs.resumes: - } - return resumeReceived -} - -// AssertNoCancelReceived asserts that no requests were cancelled by thiss graphsync implementation -func (fgs *FakeGraphSync) AssertNoCancelReceived(t *testing.T) { - require.Empty(t, 
fgs.cancels, "should not cancel request") -} - -// AssertCancelReceived asserts a requests was cancelled before the context closes (and returns said request id) -func (fgs *FakeGraphSync) AssertCancelReceived(ctx context.Context, t *testing.T) graphsync.RequestID { - var cancelReceived graphsync.RequestID - select { - case <-ctx.Done(): - t.Fatal("did not receive message sent") - case cancelReceived = <-fgs.cancels: - } - return cancelReceived -} - // AssertHasPersistenceOption verifies that a persistence option was registered func (fgs *FakeGraphSync) AssertHasPersistenceOption(t *testing.T, name string) ipld.LinkSystem { fgs.persistenceOptionsLk.RLock() @@ -193,10 +159,10 @@ func (fgs *FakeGraphSync) AssertDoesNotHavePersistenceOption(t *testing.T, name } // Request initiates a new GraphSync request to the given peer using the given selector spec. -func (fgs *FakeGraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, selector ipld.Node, extensions ...graphsync.ExtensionData) (<-chan graphsync.ResponseProgress, <-chan error) { - errors := make(chan error) - responses := make(chan graphsync.ResponseProgress) - fgs.requests <- ReceivedGraphSyncRequest{ctx, p, root, selector, extensions, responses, errors} +func (fgs *FakeGraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, selector datamodel.Node, extensions ...graphsync.ExtensionData) (<-chan graphsync.ResponseProgress, <-chan error) { + errors := make(chan error, 1) + responses := make(chan graphsync.ResponseProgress, 1) + fgs.ReceivedRequests = append(fgs.ReceivedRequests, ReceivedGraphSyncRequest{ctx, p, root, selector, extensions, responses, errors}) if !fgs.leaveRequestsOpen { close(responses) close(errors) @@ -232,11 +198,11 @@ func (fgs *FakeGraphSync) RegisterIncomingRequestHook(hook graphsync.OnIncomingR } } -// RegisterIncomingRequestQueuedHook adds a hook that runs when an incoming GS request is queued. 
-func (fgs *FakeGraphSync) RegisterIncomingRequestQueuedHook(hook graphsync.OnIncomingRequestQueuedHook) graphsync.UnregisterHookFunc { - fgs.IncomingRequestQueuedHook = hook +// RegisterIncomingRequestProcessingListener adds a hook that runs when an incoming GS request begins processing +func (fgs *FakeGraphSync) RegisterIncomingRequestProcessingListener(hook graphsync.OnRequestProcessingListener) graphsync.UnregisterHookFunc { + fgs.IncomingRequestProcessingListener = hook return func() { - fgs.IncomingRequestQueuedHook = nil + fgs.IncomingRequestProcessingListener = nil } } @@ -290,19 +256,29 @@ func (fgs *FakeGraphSync) RegisterCompletedResponseListener(listener graphsync.O // Unpause unpauses a request that was paused in a block hook based on request ID func (fgs *FakeGraphSync) Unpause(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { - fgs.resumes <- Resume{requestID, extensions} - return nil + fgs.Resumes = append(fgs.Resumes, Resume{requestID, extensions}) + return fgs.ReturnedResumeError } // Pause pauses a request based on request ID func (fgs *FakeGraphSync) Pause(ctx context.Context, requestID graphsync.RequestID) error { - fgs.pauses <- requestID - return nil + fgs.Pauses = append(fgs.Pauses, requestID) + return fgs.ReturnedPauseError } func (fgs *FakeGraphSync) Cancel(ctx context.Context, requestID graphsync.RequestID) error { - fgs.cancels <- requestID - return nil + if fgs.leaveRequestsOpen { + for _, rr := range fgs.ReceivedRequests { + existingRequestID, has := rr.requestID() + if has && requestID.String() == existingRequestID.String() { + close(rr.ResponseChan) + rr.ResponseErrChan <- graphsync.RequestClientCancelledErr{} + close(rr.ResponseErrChan) + } + } + } + fgs.Cancels = append(fgs.Cancels, requestID) + return fgs.ReturnedCancelError } // RegisterRequestorCancelledListener adds a listener on the responder for requests cancelled by the requestor @@ -341,22 +317,25 @@ func (fgs *FakeGraphSync) Stats() graphsync.Stats { return graphsync.Stats{} } -func (fgs *FakeGraphSync) RegisterOutgoingRequestProcessingListener(graphsync.OnOutgoingRequestProcessingListener) graphsync.UnregisterHookFunc { - // TODO: just a stub for now, hopefully nobody needs this - return func() {} +func (fgs *FakeGraphSync) RegisterOutgoingRequestProcessingListener(listener graphsync.OnRequestProcessingListener) graphsync.UnregisterHookFunc { + fgs.OutgoingRequestProcessingListener = listener + return func() { + fgs.OutgoingRequestProcessingListener = nil + } } func (fgs *FakeGraphSync) SendUpdate(ctx context.Context, id graphsync.RequestID, extensions ...graphsync.ExtensionData) error { - fgs.updates <- Update{RequestID: id, Extensions: extensions} - return nil + fgs.Updates = append(fgs.Updates, Update{RequestID: id, Extensions: extensions}) + return fgs.ReturnedSendUpdateError } var _ graphsync.GraphExchange = &FakeGraphSync{} type fakeBlkData struct { - link ipld.Link - size uint64 - index int64 + link ipld.Link + size uint64 + onWire bool + index int64 } func (fbd fakeBlkData) Link() ipld.Link { @@ -368,7 +347,10 @@ func (fbd fakeBlkData) BlockSize() uint64 { } func (fbd fakeBlkData) BlockSizeOnWire() uint64 { - return fbd.size + if fbd.onWire { + return fbd.size + } + return 0 } func (fbd fakeBlkData) Index() int64 { @@ -376,18 +358,19 @@ func (fbd fakeBlkData) Index() int64 { } // NewFakeBlockData returns a fake block that matches the block data interface -func NewFakeBlockData() graphsync.BlockData { +func NewFakeBlockData(size uint64, 
index int64, onWire bool) graphsync.BlockData { return &fakeBlkData{ - link: cidlink.Link{Cid: GenerateCids(1)[0]}, - size: rand.Uint64(), - index: int64(rand.Uint32()), + link: cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, + size: size, + index: index, + onWire: onWire, } } type fakeRequest struct { id graphsync.RequestID root cid.Cid - selector ipld.Node + selector datamodel.Node priority graphsync.Priority requestType graphsync.RequestType extensions map[graphsync.ExtensionName]datamodel.Node @@ -404,7 +387,7 @@ func (fr *fakeRequest) Root() cid.Cid { } // Selector returns the byte representation of the selector for this request -func (fr *fakeRequest) Selector() ipld.Node { +func (fr *fakeRequest) Selector() datamodel.Node { return fr.selector } @@ -426,14 +409,14 @@ func (fr *fakeRequest) Type() graphsync.RequestType { } // NewFakeRequest returns a fake request that matches the request data interface -func NewFakeRequest(id graphsync.RequestID, extensions map[graphsync.ExtensionName]datamodel.Node) graphsync.RequestData { +func NewFakeRequest(id graphsync.RequestID, extensions map[graphsync.ExtensionName]datamodel.Node, requestType graphsync.RequestType) graphsync.RequestData { return &fakeRequest{ id: id, - root: GenerateCids(1)[0], - selector: allSelector, + root: testutil.GenerateCids(1)[0], + selector: selectorparse.CommonSelector_ExploreAllRecursively, priority: graphsync.Priority(rand.Int()), extensions: extensions, - requestType: graphsync.RequestTypeNew, + requestType: requestType, } } @@ -532,6 +515,7 @@ type FakeIncomingRequestHookActions struct { Validated bool SentExtensions []graphsync.ExtensionData Paused bool + CtxAugFuncs []func(context.Context) context.Context } func (fa *FakeIncomingRequestHookActions) SendExtensionData(ext graphsync.ExtensionData) { @@ -557,6 +541,30 @@ func (fa *FakeIncomingRequestHookActions) PauseResponse() { fa.Paused = true } +func (fa *FakeIncomingRequestHookActions) AugmentContext(ctxAugFunc func(reqCtx context.Context) context.Context) { + fa.CtxAugFuncs = append(fa.CtxAugFuncs, ctxAugFunc) +} + +func (fa *FakeIncomingRequestHookActions) AssertAugmentedContextKey(t *testing.T, key interface{}, value interface{}) { + ctx := context.Background() + for _, f := range fa.CtxAugFuncs { + ctx = f(ctx) + } + require.Equal(t, value, ctx.Value(key)) +} + +func (fa *FakeIncomingRequestHookActions) RefuteAugmentedContextKey(t *testing.T, key interface{}) { + ctx := context.Background() + for _, f := range fa.CtxAugFuncs { + ctx = f(ctx) + } + require.Nil(t, ctx.Value(key)) +} + +func (fa *FakeIncomingRequestHookActions) DTMessage(t *testing.T) datatransfer.Message { + return matchDtMessage(t, fa.SentExtensions, extension.ExtensionIncomingRequest1_1) +} + var _ graphsync.IncomingRequestHookActions = &FakeIncomingRequestHookActions{} type FakeRequestUpdatedActions struct { @@ -593,13 +601,3 @@ func (fa *FakeIncomingResponseHookActions) UpdateRequestWithExtensions(extension } var _ graphsync.IncomingResponseHookActions = &FakeIncomingResponseHookActions{} - -type FakeRequestQueuedHookActions struct { - ctxAugFuncs []func(context.Context) context.Context -} - -func (fa *FakeRequestQueuedHookActions) AugmentContext(ctxAugFunc func(reqCtx context.Context) context.Context) { - fa.ctxAugFuncs = append(fa.ctxAugFuncs, ctxAugFunc) -} - -var _ graphsync.RequestQueuedHookActions = &FakeRequestQueuedHookActions{} diff --git a/transport/graphsync/testharness/harness.go b/transport/graphsync/testharness/harness.go new file mode 100644 index 00000000..c545ad8b --- 
/dev/null +++ b/transport/graphsync/testharness/harness.go @@ -0,0 +1,285 @@ +package testharness + +import ( + "context" + "math/rand" + "testing" + + "github.com/ipfs/go-graphsync" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/ipld/go-ipld-prime/traversal/selector/builder" + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/message/types" + "github.com/filecoin-project/go-data-transfer/v2/testutil" + dtgs "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" +) + +type harnessConfig struct { + isPull bool + isResponder bool + makeEvents func(gsData *GsTestHarness) *FakeEvents + makeNetwork func(gsData *GsTestHarness) *FakeNetwork + transportOptions []dtgs.Option +} + +type Option func(*harnessConfig) + +func PullRequest() Option { + return func(hc *harnessConfig) { + hc.isPull = true + } +} + +func Responder() Option { + return func(hc *harnessConfig) { + hc.isResponder = true + } +} + +func Events(makeEvents func(gsData *GsTestHarness) *FakeEvents) Option { + return func(hc *harnessConfig) { + hc.makeEvents = makeEvents + } +} + +func Network(makeNetwork func(gsData *GsTestHarness) *FakeNetwork) Option { + return func(hc *harnessConfig) { + hc.makeNetwork = makeNetwork + } +} + +func TransportOptions(options []dtgs.Option) Option { + return func(hc *harnessConfig) { + hc.transportOptions = options + } +} + +func SetupHarness(ctx context.Context, options ...Option) *GsTestHarness { + hc := &harnessConfig{} + for _, option := range options { + option(hc) + } + peers := testutil.GeneratePeers(2) + transferID := datatransfer.TransferID(rand.Uint32()) + fgs := NewFakeGraphSync() + fgs.LeaveRequestsOpen() + voucher := testutil.NewTestTypedVoucher() + baseCid := testutil.GenerateCids(1)[0] + selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() + chid := datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: transferID} + if hc.isResponder { + chid = datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: transferID} + } + channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ + BaseCID: baseCid, + Voucher: voucher, + Selector: selector, + IsPull: hc.isPull, + Self: peers[0], + ChannelID: chid, + }) + gsData := &GsTestHarness{ + Ctx: ctx, + Fgs: fgs, + Channel: channel, + CompletedRequests: make(chan datatransfer.ChannelID, 16), + CompletedResponses: make(chan datatransfer.ChannelID, 16), + OutgoingRequestHookActions: &FakeOutgoingRequestHookActions{}, + OutgoingBlockHookActions: &FakeOutgoingBlockHookActions{}, + IncomingBlockHookActions: &FakeIncomingBlockHookActions{}, + IncomingRequestHookActions: &FakeIncomingRequestHookActions{}, + RequestUpdateHookActions: &FakeRequestUpdatedActions{}, + IncomingResponseHookActions: &FakeIncomingResponseHookActions{}, + } + if hc.makeEvents != nil { + gsData.Events = hc.makeEvents(gsData) + } else { + gsData.Events = &FakeEvents{ + ReturnedChannelState: channel, + } + } + if hc.makeNetwork != nil { + gsData.DtNet = hc.makeNetwork(gsData) + } else { + gsData.DtNet = NewFakeNetwork(peers[0]) + } + gsData.Transport = dtgs.NewTransport(gsData.Fgs, gsData.DtNet, + append(hc.transportOptions, + 
dtgs.RegisterCompletedRequestListener(gsData.completedRequestListener), + dtgs.RegisterCompletedResponseListener(gsData.completedResponseListener))...) + gsData.Transport.SetEventHandler(gsData.Events) + return gsData +} + +type GsTestHarness struct { + Ctx context.Context + Fgs *FakeGraphSync + Channel *testutil.MockChannelState + RequestID graphsync.RequestID + AltRequestID graphsync.RequestID + Events *FakeEvents + DtNet *FakeNetwork + OutgoingRequestHookActions *FakeOutgoingRequestHookActions + IncomingBlockHookActions *FakeIncomingBlockHookActions + OutgoingBlockHookActions *FakeOutgoingBlockHookActions + IncomingRequestHookActions *FakeIncomingRequestHookActions + RequestUpdateHookActions *FakeRequestUpdatedActions + IncomingResponseHookActions *FakeIncomingResponseHookActions + Transport *dtgs.Transport + CompletedRequests chan datatransfer.ChannelID + CompletedResponses chan datatransfer.ChannelID +} + +func (th *GsTestHarness) completedRequestListener(chid datatransfer.ChannelID) { + th.CompletedRequests <- chid +} +func (th *GsTestHarness) completedResponseListener(chid datatransfer.ChannelID) { + th.CompletedResponses <- chid +} + +func (th *GsTestHarness) NewRequest(t *testing.T) datatransfer.Request { + vouch := th.Channel.Voucher() + message, err := message.NewRequest(th.Channel.TransferID(), false, th.Channel.IsPull(), &vouch, th.Channel.BaseCID(), th.Channel.Selector()) + require.NoError(t, err) + return message +} + +func (th *GsTestHarness) RestartRequest(t *testing.T) datatransfer.Request { + vouch := th.Channel.Voucher() + message, err := message.NewRequest(th.Channel.TransferID(), true, th.Channel.IsPull(), &vouch, th.Channel.BaseCID(), th.Channel.Selector()) + require.NoError(t, err) + return message +} + +func (th *GsTestHarness) VoucherRequest() datatransfer.Request { + newVouch := testutil.NewTestTypedVoucher() + return message.VoucherRequest(th.Channel.TransferID(), &newVouch) +} + +func (th *GsTestHarness) UpdateRequest(pause bool) datatransfer.Request { + return message.UpdateRequest(th.Channel.TransferID(), pause) +} + +func (th *GsTestHarness) Response() datatransfer.Response { + voucherResult := testutil.NewTestTypedVoucher() + return message.NewResponse(th.Channel.TransferID(), true, false, &voucherResult) +} + +func (th *GsTestHarness) ValidationResultResponse(pause bool) datatransfer.Response { + voucherResult := testutil.NewTestTypedVoucher() + return message.ValidationResultResponse(types.VoucherResultMessage, th.Channel.TransferID(), datatransfer.ValidationResult{VoucherResult: &voucherResult, Accepted: true}, nil, pause) +} + +func (th *GsTestHarness) RestartResponse(pause bool) datatransfer.Response { + voucherResult := testutil.NewTestTypedVoucher() + return message.ValidationResultResponse(types.RestartMessage, th.Channel.TransferID(), datatransfer.ValidationResult{VoucherResult: &voucherResult, Accepted: true}, nil, pause) +} + +func (th *GsTestHarness) UpdateResponse(paused bool) datatransfer.Response { + return message.UpdateResponse(th.Channel.TransferID(), true) +} + +func (th *GsTestHarness) OutgoingRequestHook(request graphsync.RequestData) { + th.Fgs.OutgoingRequestHook(th.Channel.OtherPeer(), request, th.OutgoingRequestHookActions) +} + +func (th *GsTestHarness) OutgoingRequestProcessingListener(request graphsync.RequestData) { + th.Fgs.OutgoingRequestProcessingListener(th.Channel.OtherPeer(), request, 0) +} + +func (th *GsTestHarness) IncomingBlockHook(response graphsync.ResponseData, block graphsync.BlockData) { + 
th.Fgs.IncomingBlockHook(th.Channel.OtherPeer(), response, block, th.IncomingBlockHookActions) +} + +func (th *GsTestHarness) OutgoingBlockHook(request graphsync.RequestData, block graphsync.BlockData) { + th.Fgs.OutgoingBlockHook(th.Channel.OtherPeer(), request, block, th.OutgoingBlockHookActions) +} + +func (th *GsTestHarness) IncomingRequestHook(request graphsync.RequestData) { + th.Fgs.IncomingRequestHook(th.Channel.OtherPeer(), request, th.IncomingRequestHookActions) +} + +func (th *GsTestHarness) IncomingRequestProcessingListener(request graphsync.RequestData) { + th.Fgs.IncomingRequestProcessingListener(th.Channel.OtherPeer(), request, 1) +} + +func (th *GsTestHarness) IncomingResponseHook(response graphsync.ResponseData) { + th.Fgs.IncomingResponseHook(th.Channel.OtherPeer(), response, th.IncomingResponseHookActions) +} + +func (th *GsTestHarness) ResponseCompletedListener(request graphsync.RequestData, code graphsync.ResponseStatusCode) { + th.Fgs.CompletedResponseListener(th.Channel.OtherPeer(), request, code) +} + +func (th *GsTestHarness) RequestorCancelledListener(request graphsync.RequestData) { + th.Fgs.RequestorCancelledListener(th.Channel.OtherPeer(), request) +} + +/* +func (ha *GsTestHarness) networkErrorListener(err error) { + ha.Fgs.NetworkErrorListener(ha.other, ha.request, err) +} +func (ha *GsTestHarness) receiverNetworkErrorListener(err error) { + ha.Fgs.ReceiverNetworkErrorListener(ha.other, err) +} +*/ + +func (th *GsTestHarness) BlockSentListener(request graphsync.RequestData, block graphsync.BlockData) { + th.Fgs.BlockSentListener(th.Channel.OtherPeer(), request, block) +} + +func (ha *GsTestHarness) makeRequest(requestID graphsync.RequestID, messageNode datamodel.Node, requestType graphsync.RequestType) graphsync.RequestData { + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + if messageNode != nil { + extensions[extension.ExtensionDataTransfer1_1] = messageNode + } + return NewFakeRequest(requestID, extensions, requestType) +} + +func (ha *GsTestHarness) makeResponse(requestID graphsync.RequestID, messageNode datamodel.Node, responseCode graphsync.ResponseStatusCode) graphsync.ResponseData { + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + if messageNode != nil { + extensions[extension.ExtensionDataTransfer1_1] = messageNode + } + return NewFakeResponse(requestID, extensions, responseCode) +} + +func assertDecodesToMessage(t *testing.T, data datamodel.Node, expected datatransfer.Message) { + actual, err := message.FromIPLD(data) + require.NoError(t, err) + require.Equal(t, expected, actual) +} + +func assertHasOutgoingMessage(t *testing.T, extensions []graphsync.ExtensionData, expected datatransfer.Message) { + nd := expected.ToIPLD() + found := false + for _, e := range extensions { + if e.Name == extension.ExtensionDataTransfer1_1 { + require.True(t, ipld.DeepEqual(nd, e.Data), "data matches") + found = true + } + } + if !found { + require.Fail(t, "extension not found") + } +} + +func assertHasExtensionMessage(t *testing.T, name graphsync.ExtensionName, extensions []graphsync.ExtensionData, expected datatransfer.Message) { + nd := expected.ToIPLD() + found := false + for _, e := range extensions { + if e.Name == name { + require.True(t, ipld.DeepEqual(nd, e.Data), "data matches") + found = true + } + } + if !found { + require.Fail(t, "extension not found") + } +} diff --git a/transport/graphsync/testharness/testnet.go b/transport/graphsync/testharness/testnet.go new file mode 100644 index 00000000..14ab118b --- /dev/null 
+++ b/transport/graphsync/testharness/testnet.go @@ -0,0 +1,106 @@ +package testharness + +import ( + "context" + "testing" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/transport/helpers/network" +) + +// FakeSentMessage is a recording of a message sent on the FakeNetwork +type FakeSentMessage struct { + PeerID peer.ID + TransportID datatransfer.TransportID + Message datatransfer.Message +} + +type FakeDelegates struct { + TransportID datatransfer.TransportID + Versions []datatransfer.Version + Receiver network.Receiver +} + +type ConnectWithRetryAttempt struct { + PeerID peer.ID + TransportID datatransfer.TransportID +} + +type TaggedPeer struct { + PeerID peer.ID + Tag string +} + +// FakeNetwork is a network that satisfies the DataTransferNetwork interface but +// records calls for later assertions instead of performing real network operations +type FakeNetwork struct { + SentMessages []FakeSentMessage + Delegates []FakeDelegates + ConnectWithRetryAttempts []ConnectWithRetryAttempt + ProtectedPeers []TaggedPeer + UnprotectedPeers []TaggedPeer + + ReturnedPeerDescription network.ProtocolDescription + ReturnedPeerID peer.ID + ReturnedSendMessageError error + ReturnedConnectWithRetryError error +} + +// NewFakeNetwork returns a new fake data transfer network instance +func NewFakeNetwork(id peer.ID) *FakeNetwork { + return &FakeNetwork{ReturnedPeerID: id} +} + +var _ network.DataTransferNetwork = (*FakeNetwork)(nil) + +// SendMessage records an outgoing data transfer message to a peer. +func (fn *FakeNetwork) SendMessage(ctx context.Context, + p peer.ID, + t datatransfer.TransportID, + m datatransfer.Message) error { + fn.SentMessages = append(fn.SentMessages, FakeSentMessage{p, t, m}) + return fn.ReturnedSendMessageError +} + +// SetDelegate registers the Receiver to handle messages received from the +// network.
+func (fn *FakeNetwork) SetDelegate(t datatransfer.TransportID, v []datatransfer.Version, r network.Receiver) { + fn.Delegates = append(fn.Delegates, FakeDelegates{t, v, r}) +} + +// ConnectTo establishes a connection to the given peer +func (fn *FakeNetwork) ConnectTo(_ context.Context, _ peer.ID) error { + return nil +} + +func (fn *FakeNetwork) ConnectWithRetry(ctx context.Context, p peer.ID, transportID datatransfer.TransportID) error { + fn.ConnectWithRetryAttempts = append(fn.ConnectWithRetryAttempts, ConnectWithRetryAttempt{p, transportID}) + return fn.ReturnedConnectWithRetryError +} + +// ID returns a stubbed id for host of this network +func (fn *FakeNetwork) ID() peer.ID { + return fn.ReturnedPeerID +} + +// Protect records a protected peer on the fake network +func (fn *FakeNetwork) Protect(id peer.ID, tag string) { + fn.ProtectedPeers = append(fn.ProtectedPeers, TaggedPeer{id, tag}) +} + +// Unprotect records an unprotected peer on the fake network and always returns false +func (fn *FakeNetwork) Unprotect(id peer.ID, tag string) bool { + fn.UnprotectedPeers = append(fn.UnprotectedPeers, TaggedPeer{id, tag}) + return false +} + +func (fn *FakeNetwork) Protocol(ctx context.Context, id peer.ID, transportID datatransfer.TransportID) (network.ProtocolDescription, error) { + return fn.ReturnedPeerDescription, nil +} + +func (fn *FakeNetwork) AssertSentMessage(t *testing.T, sentMessage FakeSentMessage) { + require.Contains(t, fn.SentMessages, sentMessage) +} diff --git a/transport/graphsync/utils.go b/transport/graphsync/utils.go new file mode 100644 index 00000000..71f8568c --- /dev/null +++ b/transport/graphsync/utils.go @@ -0,0 +1,118 @@ +package graphsync + +import ( + "sync" + + "github.com/ipfs/go-graphsync" + peer "github.com/libp2p/go-libp2p-core/peer" + "golang.org/x/xerrors" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/dtchannel" +) + +func (t *Transport) trackDTChannel(chid datatransfer.ChannelID) *dtchannel.Channel { + t.dtChannelsLk.Lock() + defer t.dtChannelsLk.Unlock() + + ch, ok := t.dtChannels[chid] + if !ok { + ch = dtchannel.NewChannel(chid, t.gs) + t.dtChannels[chid] = ch + } + + return ch +} + +func (t *Transport) getDTChannel(chid datatransfer.ChannelID) (*dtchannel.Channel, error) { + if t.events == nil { + return nil, datatransfer.ErrHandlerNotSet + } + + t.dtChannelsLk.RLock() + defer t.dtChannelsLk.RUnlock() + + ch, ok := t.dtChannels[chid] + if !ok { + return nil, xerrors.Errorf("channel %s: %w", chid, datatransfer.ErrChannelNotFound) + } + return ch, nil +} + +func (t *Transport) otherPeer(chid datatransfer.ChannelID) peer.ID { + if chid.Initiator == t.peerID { + return chid.Responder + } + return chid.Initiator +} + +type channelInfo struct { + sending bool + channelID datatransfer.ChannelID +} + +// Used in graphsync callbacks to map from graphsync request to the + // associated data-transfer channel ID.
+type requestIDToChannelIDMap struct { + lk sync.RWMutex + m map[graphsync.RequestID]channelInfo +} + +func newRequestIDToChannelIDMap() *requestIDToChannelIDMap { + return &requestIDToChannelIDMap{ + m: make(map[graphsync.RequestID]channelInfo), + } +} + +// get the value for a key +func (m *requestIDToChannelIDMap) load(key graphsync.RequestID) (datatransfer.ChannelID, bool) { + m.lk.RLock() + defer m.lk.RUnlock() + + val, ok := m.m[key] + return val.channelID, ok +} + +// get the value if any of the keys exists in the map +func (m *requestIDToChannelIDMap) any(ks ...graphsync.RequestID) (datatransfer.ChannelID, bool) { + m.lk.RLock() + defer m.lk.RUnlock() + + for _, k := range ks { + val, ok := m.m[k] + if ok { + return val.channelID, ok + } + } + return datatransfer.ChannelID{}, false +} + +// set the value for a key +func (m *requestIDToChannelIDMap) set(key graphsync.RequestID, sending bool, chid datatransfer.ChannelID) { + m.lk.Lock() + defer m.lk.Unlock() + + m.m[key] = channelInfo{sending, chid} +} + +// call f for each key / value in the map +func (m *requestIDToChannelIDMap) forEach(f func(k graphsync.RequestID, isSending bool, chid datatransfer.ChannelID)) { + m.lk.RLock() + defer m.lk.RUnlock() + + for k, ch := range m.m { + f(k, ch.sending, ch.channelID) + } +} + +// delete any keys that reference this value +func (m *requestIDToChannelIDMap) deleteRefs(id datatransfer.ChannelID) { + m.lk.Lock() + defer m.lk.Unlock() + + for k, ch := range m.m { + if ch.channelID == id { + delete(m.m, k) + } + } +} diff --git a/transport/helpers/network/interface.go b/transport/helpers/network/interface.go new file mode 100644 index 00000000..022a1a37 --- /dev/null +++ b/transport/helpers/network/interface.go @@ -0,0 +1,85 @@ +package network + +import ( + "context" + "errors" + "strings" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" +) + +const ( + // ProtocolFilDataTransfer1_2 is the legacy filecoin data transfer protocol + // that assumes a graphsync transport + ProtocolFilDataTransfer1_2 protocol.ID = "/fil/datatransfer/1.2.0" + // ProtocolDataTransfer1_2 is the protocol identifier for current data transfer + // protocol which wraps transport information in the protocol + ProtocolDataTransfer1_2 protocol.ID = "/datatransfer/1.2.0" +) + +// ProtocolDescription describes how you are connected to a given +// peer on a given transport, if at all +type ProtocolDescription struct { + IsLegacy bool + MessageVersion datatransfer.Version + TransportVersion datatransfer.Version +} + +// MessageVersion extracts the message version from the full protocol +func MessageVersion(protocol protocol.ID) (datatransfer.Version, error) { + protocolParts := strings.Split(string(protocol), "/") + if len(protocolParts) == 0 { + return datatransfer.Version{}, errors.New("no protocol to parse") + } + return datatransfer.MessageVersionFromString(protocolParts[len(protocolParts)-1]) +} + +// DataTransferNetwork provides network connectivity for GraphSync. +type DataTransferNetwork interface { + Protect(id peer.ID, tag string) + Unprotect(id peer.ID, tag string) bool + + // SendMessage sends a GraphSync message to a peer. + SendMessage( + context.Context, + peer.ID, + datatransfer.TransportID, + datatransfer.Message) error + + // SetDelegate registers the Reciver to handle messages received from the + // network. 
+ SetDelegate(datatransfer.TransportID, []datatransfer.Version, Receiver) + + // ConnectTo establishes a connection to the given peer + ConnectTo(context.Context, peer.ID) error + + // ConnectWithRetry establishes a connection to the given peer, retrying if + // necessary, and opens a stream on the data-transfer protocol to verify + // the peer will accept messages on the protocol + ConnectWithRetry(ctx context.Context, p peer.ID, transportID datatransfer.TransportID) error + + // ID returns the peer id of this libp2p host + ID() peer.ID + + // Protocol returns the protocol version of the peer, connecting to + // the peer if necessary + Protocol(context.Context, peer.ID, datatransfer.TransportID) (ProtocolDescription, error) +} + +// Receiver is an interface for receiving messages from the GraphSyncNetwork. +type Receiver interface { + ReceiveRequest( + ctx context.Context, + sender peer.ID, + incoming datatransfer.Request) + + ReceiveResponse( + ctx context.Context, + sender peer.ID, + incoming datatransfer.Response) + + ReceiveRestartExistingChannelRequest(ctx context.Context, sender peer.ID, incoming datatransfer.Request) +} diff --git a/network/libp2p_impl.go b/transport/helpers/network/libp2p_impl.go similarity index 54% rename from network/libp2p_impl.go rename to transport/helpers/network/libp2p_impl.go index 30437717..d4fe0edb 100644 --- a/network/libp2p_impl.go +++ b/transport/helpers/network/libp2p_impl.go @@ -2,8 +2,10 @@ package network import ( "context" + "errors" "fmt" "io" + "strings" "time" logging "github.com/ipfs/go-log/v2" @@ -16,10 +18,9 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" - "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/message" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message" ) var log = logging.Logger("data_transfer_network") @@ -43,7 +44,12 @@ const defaultMaxAttemptDuration = 5 * time.Minute const defaultBackoffFactor = 5 var defaultDataTransferProtocols = []protocol.ID{ - datatransfer.ProtocolDataTransfer1_2, + ProtocolDataTransfer1_2, + ProtocolFilDataTransfer1_2, +} + +func isLegacyProtocol(protocol protocol.ID) bool { + return protocol == ProtocolFilDataTransfer1_2 } // Option is an option for configuring the libp2p storage market network @@ -51,26 +57,26 @@ type Option func(*libp2pDataTransferNetwork) // DataTransferProtocols OVERWRITES the default libp2p protocols we use for data transfer with the given protocols. 
func DataTransferProtocols(protocols []protocol.ID) Option { - return func(impl *libp2pDataTransferNetwork) { - impl.setDataTransferProtocols(protocols) + return func(dtnet *libp2pDataTransferNetwork) { + dtnet.setDataTransferProtocols(protocols) } } // SendMessageParameters changes the default parameters around sending messages func SendMessageParameters(openStreamTimeout time.Duration, sendMessageTimeout time.Duration) Option { - return func(impl *libp2pDataTransferNetwork) { - impl.sendMessageTimeout = sendMessageTimeout - impl.openStreamTimeout = openStreamTimeout + return func(dtnet *libp2pDataTransferNetwork) { + dtnet.sendMessageTimeout = sendMessageTimeout + dtnet.openStreamTimeout = openStreamTimeout } } // RetryParameters changes the default parameters around connection reopening func RetryParameters(minDuration time.Duration, maxDuration time.Duration, attempts float64, backoffFactor float64) Option { - return func(impl *libp2pDataTransferNetwork) { - impl.maxStreamOpenAttempts = attempts - impl.minAttemptDuration = minDuration - impl.maxAttemptDuration = maxDuration - impl.backoffFactor = backoffFactor + return func(dtnet *libp2pDataTransferNetwork) { + dtnet.maxStreamOpenAttempts = attempts + dtnet.minAttemptDuration = minDuration + dtnet.maxAttemptDuration = maxDuration + dtnet.backoffFactor = backoffFactor } } @@ -85,6 +91,8 @@ func NewFromLibp2pHost(host host.Host, options ...Option) DataTransferNetwork { minAttemptDuration: defaultMinAttemptDuration, maxAttemptDuration: defaultMaxAttemptDuration, backoffFactor: defaultBackoffFactor, + receivers: make(map[protocol.ID]receiverData), + transportProtocols: make(map[datatransfer.TransportID]transportProtocols), } dataTransferNetwork.setDataTransferProtocols(defaultDataTransferProtocols) @@ -95,44 +103,54 @@ func NewFromLibp2pHost(host host.Host, options ...Option) DataTransferNetwork { return &dataTransferNetwork } +type transportProtocols struct { + protocols []protocol.ID + protocolStrings []string +} + +type receiverData struct { + ProtocolDescription + transportID datatransfer.TransportID + receiver Receiver +} + // libp2pDataTransferNetwork transforms the libp2p host interface, which sends and receives // NetMessage objects, into the data transfer network interface. 
type libp2pDataTransferNetwork struct { host host.Host // inbound messages from the network are forwarded to the receiver - receiver Receiver - + receivers map[protocol.ID]receiverData + transportProtocols map[datatransfer.TransportID]transportProtocols openStreamTimeout time.Duration sendMessageTimeout time.Duration maxStreamOpenAttempts float64 minAttemptDuration time.Duration maxAttemptDuration time.Duration dtProtocols []protocol.ID - dtProtocolStrings []string backoffFactor float64 } -func (impl *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.ID, protocols ...protocol.ID) (network.Stream, error) { +func (dtnet *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.ID, protocols ...protocol.ID) (network.Stream, error) { b := &backoff.Backoff{ - Min: impl.minAttemptDuration, - Max: impl.maxAttemptDuration, - Factor: impl.backoffFactor, + Min: dtnet.minAttemptDuration, + Max: dtnet.maxAttemptDuration, + Factor: dtnet.backoffFactor, Jitter: true, } start := time.Now() for { - tctx, cancel := context.WithTimeout(ctx, impl.openStreamTimeout) + tctx, cancel := context.WithTimeout(ctx, dtnet.openStreamTimeout) defer cancel() // will use the first among the given protocols that the remote peer supports at := time.Now() - s, err := impl.host.NewStream(tctx, id, protocols...) + s, err := dtnet.host.NewStream(tctx, id, protocols...) if err == nil { nAttempts := b.Attempt() + 1 if b.Attempt() > 0 { log.Debugf("opened stream to %s on attempt %g of %g after %s", - id, nAttempts, impl.maxStreamOpenAttempts, time.Since(start)) + id, nAttempts, dtnet.maxStreamOpenAttempts, time.Since(start)) } return s, err @@ -140,13 +158,13 @@ func (impl *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.I // b.Attempt() starts from zero nAttempts := b.Attempt() + 1 - if nAttempts >= impl.maxStreamOpenAttempts { - return nil, xerrors.Errorf("exhausted %g attempts but failed to open stream to %s, err: %w", impl.maxStreamOpenAttempts, id, err) + if nAttempts >= dtnet.maxStreamOpenAttempts { + return nil, fmt.Errorf("exhausted %g attempts but failed to open stream to %s, err: %w", dtnet.maxStreamOpenAttempts, id, err) } d := b.Duration() log.Warnf("failed to open stream to %s on attempt %g of %g after %s, waiting %s to try again, err: %s", - id, nAttempts, impl.maxStreamOpenAttempts, time.Since(at), d, err) + id, nAttempts, dtnet.maxStreamOpenAttempts, time.Since(at), d, err) select { case <-ctx.Done(): @@ -159,6 +177,7 @@ func (impl *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.I func (dtnet *libp2pDataTransferNetwork) SendMessage( ctx context.Context, p peer.ID, + transportID datatransfer.TransportID, outgoing datatransfer.Message) error { ctx, span := otel.Tracer("data-transfer").Start(ctx, "sendMessage", trace.WithAttributes( @@ -173,22 +192,36 @@ func (dtnet *libp2pDataTransferNetwork) SendMessage( )) defer span.End() - s, err := dtnet.openStream(ctx, p, dtnet.dtProtocols...) + + transportProtocols, ok := dtnet.transportProtocols[transportID] + if !ok { + return datatransfer.ErrUnsupported + } + s, err := dtnet.openStream(ctx, p, transportProtocols.protocols...) 
if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return err } - outgoing, err = outgoing.MessageForProtocol(s.Protocol()) + receiverData, ok := dtnet.receivers[s.Protocol()] + if !ok { + // this shouldn't happen, but let's be careful just in case to avoid a panic + err := errors.New("no receiver set") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return err + } + + outgoing, err = outgoing.MessageForVersion(receiverData.MessageVersion) if err != nil { - err = xerrors.Errorf("failed to convert message for protocol: %w", err) + err = fmt.Errorf("failed to convert message for protocol: %w", err) span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return err } - if err = dtnet.msgToStream(ctx, s, outgoing); err != nil { + if err = dtnet.msgToStream(ctx, s, outgoing, receiverData); err != nil { if err2 := s.Reset(); err2 != nil { log.Error(err) span.RecordError(err2) @@ -203,9 +236,55 @@ func (dtnet *libp2pDataTransferNetwork) SendMessage( return s.Close() } -func (dtnet *libp2pDataTransferNetwork) SetDelegate(r Receiver) { - dtnet.receiver = r - for _, p := range dtnet.dtProtocols { +func (dtnet *libp2pDataTransferNetwork) SetDelegate(transportID datatransfer.TransportID, versions []datatransfer.Version, r Receiver) { + transportProtocols := transportProtocols{} + for _, dtProtocol := range dtnet.dtProtocols { + messageVersion, _ := MessageVersion(dtProtocol) + if isLegacyProtocol(dtProtocol) { + if transportID == datatransfer.LegacyTransportID { + supportsLegacyVersion := false + for _, version := range versions { + if version == datatransfer.LegacyTransportVersion { + supportsLegacyVersion = true + break + } + } + if !supportsLegacyVersion { + continue + } + dtnet.receivers[dtProtocol] = receiverData{ + ProtocolDescription: ProtocolDescription{ + IsLegacy: true, + TransportVersion: datatransfer.LegacyTransportVersion, + MessageVersion: messageVersion, + }, + transportID: transportID, + receiver: r, + } + transportProtocols.protocols = append(transportProtocols.protocols, dtProtocol) + transportProtocols.protocolStrings = append(transportProtocols.protocolStrings, string(dtProtocol)) + } + } else { + for _, version := range versions { + joinedProtocol := strings.Join([]string{string(dtProtocol), string(transportID), version.String()}, "/") + dtnet.receivers[protocol.ID(joinedProtocol)] = receiverData{ + ProtocolDescription: ProtocolDescription{ + IsLegacy: false, + TransportVersion: version, + MessageVersion: messageVersion, + }, + transportID: transportID, + receiver: r, + } + transportProtocols.protocols = append(transportProtocols.protocols, protocol.ID(joinedProtocol)) + transportProtocols.protocolStrings = append(transportProtocols.protocolStrings, joinedProtocol) + } + } + } + + dtnet.transportProtocols[transportID] = transportProtocols + + for _, p := range transportProtocols.protocols { dtnet.host.SetStreamHandler(p, dtnet.handleNewStream) } } @@ -217,10 +296,14 @@ func (dtnet *libp2pDataTransferNetwork) ConnectTo(ctx context.Context, p peer.ID // ConnectWithRetry establishes a connection to the given peer, retrying if // necessary, and opens a stream on the data-transfer protocol to verify // the peer will accept messages on the protocol -func (dtnet *libp2pDataTransferNetwork) ConnectWithRetry(ctx context.Context, p peer.ID) error { +func (dtnet *libp2pDataTransferNetwork) ConnectWithRetry(ctx context.Context, p peer.ID, transportID datatransfer.TransportID) error { + transportProtocols, ok := 
dtnet.transportProtocols[transportID] + if !ok { + return datatransfer.ErrUnsupported + } // Open a stream over the data-transfer protocol, to make sure that the // peer is listening on the protocol - s, err := dtnet.openStream(ctx, p, dtnet.dtProtocols...) + s, err := dtnet.openStream(ctx, p, transportProtocols.protocols...) if err != nil { return err } @@ -234,25 +317,31 @@ func (dtnet *libp2pDataTransferNetwork) ConnectWithRetry(ctx context.Context, p func (dtnet *libp2pDataTransferNetwork) handleNewStream(s network.Stream) { defer s.Close() // nolint: errcheck,gosec - if dtnet.receiver == nil { + if len(dtnet.receivers) == 0 { s.Reset() // nolint: errcheck,gosec return } + receiverData, ok := dtnet.receivers[s.Protocol()] + if !ok { + s.Reset() // nolint: errcheck,gosec + return + } p := s.Conn().RemotePeer() + // if we have no transport handler, reset the stream for { var received datatransfer.Message var err error - switch s.Protocol() { - case datatransfer.ProtocolDataTransfer1_2: + if receiverData.IsLegacy { received, err = message.FromNet(s) + } else { + received, err = message.FromNetWrapped(s) } if err != nil { if err != io.EOF && err != io.ErrUnexpectedEOF { s.Reset() // nolint: errcheck,gosec - go dtnet.receiver.ReceiveError(err) - log.Debugf("net handleNewStream from %s error: %s", p, err) + log.Errorf("net handleNewStream from %s error: %s", p, err) } return } @@ -264,15 +353,15 @@ func (dtnet *libp2pDataTransferNetwork) handleNewStream(s network.Stream) { receivedRequest, ok := received.(datatransfer.Request) if ok { if receivedRequest.IsRestartExistingChannelRequest() { - dtnet.receiver.ReceiveRestartExistingChannelRequest(ctx, p, receivedRequest) + receiverData.receiver.ReceiveRestartExistingChannelRequest(ctx, p, receivedRequest) } else { - dtnet.receiver.ReceiveRequest(ctx, p, receivedRequest) + receiverData.receiver.ReceiveRequest(ctx, p, receivedRequest) } } } else { receivedResponse, ok := received.(datatransfer.Response) if ok { - dtnet.receiver.ReceiveResponse(ctx, p, receivedResponse) + receiverData.receiver.ReceiveResponse(ctx, p, receivedResponse) } } } @@ -290,7 +379,7 @@ func (dtnet *libp2pDataTransferNetwork) Unprotect(id peer.ID, tag string) bool { return dtnet.host.ConnManager().Unprotect(id, tag) } -func (dtnet *libp2pDataTransferNetwork) msgToStream(ctx context.Context, s network.Stream, msg datatransfer.Message) error { +func (dtnet *libp2pDataTransferNetwork) msgToStream(ctx context.Context, s network.Stream, msg datatransfer.Message, receiverData receiverData) error { if msg.IsRequest() { log.Debugf("Outgoing request message for transfer ID: %d", msg.TransferID()) } @@ -308,49 +397,52 @@ func (dtnet *libp2pDataTransferNetwork) msgToStream(ctx context.Context, s netwo } }() - switch s.Protocol() { - case datatransfer.ProtocolDataTransfer1_2: - default: - return fmt.Errorf("unrecognized protocol on remote: %s", s.Protocol()) + if !receiverData.IsLegacy { + msg = msg.WrappedForTransport(receiverData.transportID, receiverData.TransportVersion) } if err := msg.ToNet(s); err != nil { log.Debugf("error: %s", err) return err } - return nil } -func (impl *libp2pDataTransferNetwork) Protocol(ctx context.Context, id peer.ID) (protocol.ID, error) { +func (dtnet *libp2pDataTransferNetwork) Protocol(ctx context.Context, id peer.ID, transportID datatransfer.TransportID) (ProtocolDescription, error) { + transportProtocols, ok := dtnet.transportProtocols[transportID] + if !ok { + return ProtocolDescription{}, datatransfer.ErrUnsupported + } + // Check the cache for the 
peer's protocol version - firstProto, err := impl.host.Peerstore().FirstSupportedProtocol(id, impl.dtProtocolStrings...) + firstProto, err := dtnet.host.Peerstore().FirstSupportedProtocol(id, transportProtocols.protocolStrings...) if err != nil { - return "", err + return ProtocolDescription{}, err } if firstProto != "" { - return protocol.ID(firstProto), nil + receiverData, ok := dtnet.receivers[protocol.ID(firstProto)] + if !ok { + return ProtocolDescription{}, err + } + return receiverData.ProtocolDescription, nil } // The peer's protocol version is not in the cache, so connect to the peer. // Note that when the stream is opened, the peer's protocol will be added // to the cache. - s, err := impl.openStream(ctx, id, impl.dtProtocols...) + s, err := dtnet.openStream(ctx, id, dtnet.dtProtocols...) if err != nil { - return "", err + return ProtocolDescription{}, err } _ = s.Close() - - return s.Protocol(), nil + receiverData, ok := dtnet.receivers[s.Protocol()] + if !ok { + return ProtocolDescription{}, err + } + return receiverData.ProtocolDescription, nil } -func (impl *libp2pDataTransferNetwork) setDataTransferProtocols(protocols []protocol.ID) { - impl.dtProtocols = append([]protocol.ID{}, protocols...) - - // Keep a string version of the protocols for performance reasons - impl.dtProtocolStrings = make([]string, 0, len(impl.dtProtocols)) - for _, proto := range impl.dtProtocols { - impl.dtProtocolStrings = append(impl.dtProtocolStrings, string(proto)) - } +func (dtnet *libp2pDataTransferNetwork) setDataTransferProtocols(protocols []protocol.ID) { + dtnet.dtProtocols = append([]protocol.ID{}, protocols...) } diff --git a/network/libp2p_impl_test.go b/transport/helpers/network/libp2p_impl_test.go similarity index 81% rename from network/libp2p_impl_test.go rename to transport/helpers/network/libp2p_impl_test.go index e70e1cc4..1b689004 100644 --- a/network/libp2p_impl_test.go +++ b/transport/helpers/network/libp2p_impl_test.go @@ -18,10 +18,11 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/xerrors" - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/message" - "github.com/filecoin-project/go-data-transfer/network" - "github.com/filecoin-project/go-data-transfer/testutil" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/message/types" + "github.com/filecoin-project/go-data-transfer/v2/testutil" + "github.com/filecoin-project/go-data-transfer/v2/transport/helpers/network" ) // Receiver is an interface for receiving messages from the DataTransferNetwork. 
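The rewritten `SetDelegate` above derives one libp2p protocol ID per registered transport version by joining the base data transfer protocol with the transport ID and the transport version, while the legacy `/fil/datatransfer/1.2.0` protocol is registered unwrapped. Below is a minimal sketch of that composition, using plain strings in place of `protocol.ID` and `datatransfer.Version`; the transport version `2.0.0` is an illustrative assumption, not a value taken from the module:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	base := "/datatransfer/1.2.0" // ProtocolDataTransfer1_2
	transportID := "graphsync"    // example datatransfer.TransportID
	versions := []string{"2.0.0"} // assumed transport versions passed to SetDelegate

	for _, v := range versions {
		// Mirrors strings.Join([]string{string(dtProtocol), string(transportID), version.String()}, "/")
		wrapped := strings.Join([]string{base, transportID, v}, "/")
		fmt.Println(wrapped) // /datatransfer/1.2.0/graphsync/2.0.0

		// MessageVersion-style parsing: the message version is the trailing
		// path segment of the base protocol.
		parts := strings.Split(base, "/")
		fmt.Println(parts[len(parts)-1]) // 1.2.0
	}
}
```

With that layout, `MessageVersion` only needs the trailing path segment of the base protocol to recover the message version, and each wrapped protocol ID names both the transport and its version.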
@@ -90,8 +91,8 @@ func TestMessageSendAndReceive(t *testing.T) { messageReceived: make(chan struct{}), connectedPeers: make(chan peer.ID, 2), } - dtnet1.SetDelegate(r) - dtnet2.SetDelegate(r) + dtnet1.SetDelegate("graphsync", []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) + dtnet2.SetDelegate("graphsync", []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) err = dtnet1.ConnectTo(ctx, host2.ID()) require.NoError(t, err) @@ -101,10 +102,10 @@ func TestMessageSendAndReceive(t *testing.T) { selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := false id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - request, err := message.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + voucher := testutil.NewTestTypedVoucher() + request, err := message.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) - require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) + require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), "graphsync", request)) select { case <-ctx.Done(): @@ -123,17 +124,16 @@ func TestMessageSendAndReceive(t *testing.T) { assert.Equal(t, request.IsPull(), receivedRequest.IsPull()) assert.Equal(t, request.IsRequest(), receivedRequest.IsRequest()) assert.True(t, receivedRequest.BaseCid().Equals(request.BaseCid())) - testutil.AssertEqualFakeDTVoucher(t, request, receivedRequest) + testutil.AssertEqualTestVoucher(t, request, receivedRequest) testutil.AssertEqualSelector(t, request, receivedRequest) }) t.Run("Send Response", func(t *testing.T) { accepted := false id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message.NewResponse(id, accepted, false, voucherResult.Type(), voucherResult) - require.NoError(t, err) - require.NoError(t, dtnet2.SendMessage(ctx, host1.ID(), response)) + voucherResult := testutil.NewTestTypedVoucher() + response := message.ValidationResultResponse(types.NewMessage, id, datatransfer.ValidationResult{Accepted: accepted, VoucherResult: &voucherResult}, nil, false) + require.NoError(t, dtnet2.SendMessage(ctx, host1.ID(), "graphsync", response)) select { case <-ctx.Done(): @@ -150,7 +150,7 @@ func TestMessageSendAndReceive(t *testing.T) { assert.Equal(t, response.TransferID(), receivedResponse.TransferID()) assert.Equal(t, response.Accepted(), receivedResponse.Accepted()) assert.Equal(t, response.IsRequest(), receivedResponse.IsRequest()) - testutil.AssertEqualFakeDTVoucherResult(t, response, receivedResponse) + testutil.AssertEqualTestVoucherResult(t, response, receivedResponse) }) t.Run("Send Restart Request", func(t *testing.T) { @@ -160,7 +160,7 @@ func TestMessageSendAndReceive(t *testing.T) { Responder: peers[1], ID: id} request := message.RestartExistingChannelRequest(chId) - require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) + require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), "graphsync", request)) select { case <-ctx.Done(): @@ -262,8 +262,8 @@ func TestSendMessageRetry(t *testing.T) { messageReceived: make(chan struct{}), connectedPeers: make(chan peer.ID, 2), } - dtnet1.SetDelegate(r) - dtnet2.SetDelegate(r) + dtnet1.SetDelegate("graphsync", []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) + dtnet2.SetDelegate("graphsync", []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) err = dtnet1.ConnectTo(ctx, host2.ID()) require.NoError(t, err) @@ -272,11 +272,11 @@ func TestSendMessageRetry(t *testing.T) { 
selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := false id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - request, err := message.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + voucher := testutil.NewTestTypedVoucher() + request, err := message.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) - err = dtnet1.SendMessage(ctx, host2.ID(), request) + err = dtnet1.SendMessage(ctx, host2.ID(), "graphsync", request) if !tcase.expSuccess { require.Error(t, err) return diff --git a/types.go b/types.go index cd970e0d..e97c8038 100644 --- a/types.go +++ b/types.go @@ -6,10 +6,9 @@ import ( "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" cbg "github.com/whyrusleeping/cbor-gen" - - "github.com/filecoin-project/go-data-transfer/encoding" ) //go:generate cbor-gen-for ChannelID ChannelStages ChannelStage Log @@ -21,23 +20,18 @@ type TypeIdentifier string // EmptyTypeIdentifier means there is no voucher present const EmptyTypeIdentifier = TypeIdentifier("") -// Registerable is a type of object in a registry. It must be encodable and must -// have a single method that uniquely identifies its type -type Registerable interface { - encoding.Encodable - // Type is a unique string identifier for this voucher type - Type() TypeIdentifier +// TypedVoucher is a voucher or voucher result in IPLD form and an associated +// type identifier for that voucher or voucher result +type TypedVoucher struct { + Voucher datamodel.Node + Type TypeIdentifier } -// Voucher is used to validate -// a data transfer request against the underlying storage or retrieval deal -// that precipitated it. 
The only requirement is a voucher can read and write -// from bytes, and has a string identifier type -type Voucher Registerable - -// VoucherResult is used to provide option additional information about a -// voucher being rejected or accepted -type VoucherResult Registerable +// Equals is a utility to compare that two TypedVouchers are the same - both type +// and the voucher's IPLD content +func (tv1 TypedVoucher) Equals(tv2 TypedVoucher) bool { + return tv1.Type == tv2.Type && ipld.DeepEqual(tv1.Voucher, tv2.Voucher) +} // TransferID is an identifier for a data transfer, shared between // request/responder and unique to the requester @@ -74,10 +68,10 @@ type Channel interface { // Selector returns the IPLD selector for this data transfer (represented as // an IPLD node) - Selector() ipld.Node + Selector() datamodel.Node - // Voucher returns the voucher for this data transfer - Voucher() Voucher + // Voucher returns the initial voucher for this data transfer + Voucher() TypedVoucher // Sender returns the peer id for the node that is sending data Sender() peer.ID @@ -118,32 +112,52 @@ type ChannelState interface { Message() string // Vouchers returns all vouchers sent on this channel - Vouchers() []Voucher + Vouchers() []TypedVoucher // VoucherResults are results of vouchers sent on the channel - VoucherResults() []VoucherResult + VoucherResults() []TypedVoucher // LastVoucher returns the last voucher sent on the channel - LastVoucher() Voucher + LastVoucher() TypedVoucher // LastVoucherResult returns the last voucher result sent on the channel - LastVoucherResult() VoucherResult + LastVoucherResult() TypedVoucher - // ReceivedCidsTotal returns the number of (non-unique) cids received so far - // on the channel - note that a block can exist in more than one place in the DAG - ReceivedCidsTotal() int64 + // ReceivedIndex returns the index, a transport specific identifier for "where" + // we are in receiving data for a transfer + ReceivedIndex() datamodel.Node - // QueuedCidsTotal returns the number of (non-unique) cids queued so far - // on the channel - note that a block can exist in more than one place in the DAG - QueuedCidsTotal() int64 + // QueuedIndex returns the index, a transport specific identifier for "where" + // we are in queing data for a transfer + QueuedIndex() datamodel.Node - // SentCidsTotal returns the number of (non-unique) cids sent so far - // on the channel - note that a block can exist in more than one place in the DAG - SentCidsTotal() int64 + // SentIndex returns the index, a transport specific identifier for "where" + // we are in sending data for a transfer + SentIndex() datamodel.Node // Queued returns the number of bytes read from the node and queued for sending Queued() uint64 + // DataLimit is the maximum data that can be transferred on this channel before + // revalidation. 0 indicates no limit. 
+ DataLimit() uint64 + + // RequiresFinalization indicates that, at the end of the transfer, the channel should + // be left open for a final settlement + RequiresFinalization() bool + + // InitiatorPaused indicates whether the initiator of this channel is in a paused state + InitiatorPaused() bool + + // ResponderPaused indicates whether the responder of this channel is in a paused state + ResponderPaused() bool + + // BothPaused indicates that both sides of the transfer have paused the transfer + BothPaused() bool + + // SelfPaused indicates whether the local peer for this channel is in a paused state + SelfPaused() bool + // Stages returns the timeline of events this data transfer has gone through, // for observability purposes. //
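With `Voucher` and `VoucherResult` collapsed into `TypedVoucher`, equality is now a comparison of the type identifier plus the voucher's IPLD content via `ipld.DeepEqual`. A self-contained sketch of that check follows; the local `TypedVoucher` only mirrors the shape added in types.go (a plain string stands in for `TypeIdentifier`, and the voucher values are made up):

```go
package main

import (
	"fmt"

	ipld "github.com/ipld/go-ipld-prime"
	"github.com/ipld/go-ipld-prime/datamodel"
	"github.com/ipld/go-ipld-prime/node/basicnode"
)

// TypedVoucher mirrors the struct added in types.go: an IPLD node plus a type identifier.
type TypedVoucher struct {
	Voucher datamodel.Node
	Type    string
}

// Equals compares both the type identifier and the IPLD content.
func (tv1 TypedVoucher) Equals(tv2 TypedVoucher) bool {
	return tv1.Type == tv2.Type && ipld.DeepEqual(tv1.Voucher, tv2.Voucher)
}

func main() {
	a := TypedVoucher{Voucher: basicnode.NewString("deal-123"), Type: "ExampleVoucher"}
	b := TypedVoucher{Voucher: basicnode.NewString("deal-123"), Type: "ExampleVoucher"}
	c := TypedVoucher{Voucher: basicnode.NewString("deal-456"), Type: "ExampleVoucher"}
	fmt.Println(a.Equals(b)) // true
	fmt.Println(a.Equals(c)) // false
}
```

The same `ipld.DeepEqual` comparison is what `assertHasOutgoingMessage` in the graphsync test harness relies on when matching data transfer extension payloads.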