diff --git a/core/connection/manager.go b/core/connection/manager.go index e87d5a32a..2e68c945b 100644 --- a/core/connection/manager.go +++ b/core/connection/manager.go @@ -145,6 +145,8 @@ type connectionManager struct { validator validator p2pDialer p2p.Dialer timeGetter TimeGetter + priceCheckInterval time.Duration + priceDropPercent float64 // These are populated by Connect at runtime. ctx context.Context @@ -207,6 +209,8 @@ func NewManager( preReconnect: preReconnect, postReconnect: postReconnect, uuid: uuid.String(), + priceDropPercent: 10, // reconnect if price dropped 10% or more + priceCheckInterval: 30 * time.Second, } m.eventBus.SubscribeAsync(connectionstate.AppTopicConnectionState, m.reconnectOnHold) @@ -303,6 +307,7 @@ func (m *connectionManager) Connect(consumerID identity.Identity, hermesID commo go m.consumeConnectionStates(m.activeConnection.State()) go m.checkSessionIP(m.channel, m.connectOptions.ConsumerID, m.connectOptions.SessionID, originalPublicIP) + go m.monitorPrice(prc, proposalLookup) return nil } @@ -384,7 +389,7 @@ func (m *connectionManager) initSession(tracer *trace.Tracer, prc market.Price) m.setStatus(func(status *connectionstate.Status) { status.SessionID = sessionID }) - m.publishSessionCreate(sessionID) + m.publishSessionCreate() paymentSession.SetSessionID(string(sessionID)) tracer.EndStage(traceStart) @@ -662,7 +667,7 @@ func (m *connectionManager) createP2PSession(c Connection, opts ConnectOptions, return &sessionResponse, nil } -func (m *connectionManager) publishSessionCreate(sessionID session.ID) { +func (m *connectionManager) publishSessionCreate() { sessionInfo := m.Status() // avoid printing IP address in logs sessionInfo.ConsumerLocation.IP = "" @@ -1031,3 +1036,30 @@ func logDisconnectError(err error) { log.Error().Err(err).Msg("Disconnect error") } } + +func (m *connectionManager) monitorPrice(currentPrice market.Price, proposalLookup ProposalLookup) { + t := time.NewTicker(m.priceCheckInterval) + for { + select { + case <-m.currentCtx().Done(): + return + case <-t.C: + proposal, err := proposalLookup() + if err != nil { + log.Error().Err(err).Msg("Failed to lookup proposal") + continue + } + newPrice := m.priceFromProposal(*proposal) + + // Check if both GiB and Hourly prices dropped by at least 10% + giBDrop := float64(currentPrice.PricePerGiB.Int64()-newPrice.PricePerGiB.Int64()) / float64(currentPrice.PricePerGiB.Int64()) + hourDrop := float64(currentPrice.PricePerHour.Int64()-newPrice.PricePerHour.Int64()) / float64(currentPrice.PricePerHour.Int64()) + + if giBDrop*100 >= m.priceDropPercent || hourDrop*100 >= m.priceDropPercent { + log.Info().Msgf("Price dropped significantly from %q to %q, disconnecting", currentPrice.String(), newPrice.String()) + m.Disconnect() + return + } + } + } +} diff --git a/core/connection/manager_test.go b/core/connection/manager_test.go index 030445967..d3c658eb7 100644 --- a/core/connection/manager_test.go +++ b/core/connection/manager_test.go @@ -392,6 +392,36 @@ func (tc *testContext) TestConnectMethodReturnsErrorIfConnectionExitsDuringConne assert.Equal(tc.T(), ErrConnectionFailed, err) } +func (tc *testContext) TestDisconnectDueToPriceDrop() { + tc.fakeConnectionFactory.mockConnection.onStartReportStates = []fakeState{ + connectedState, + } + tc.connManager.priceCheckInterval = time.Millisecond + + mux := sync.RWMutex{} + proposalLookup := func() (proposal *proposal.PricedServiceProposal, err error) { + mux.RLock() + defer mux.RUnlock() + return &activeProposal, nil + } + + err := tc.connManager.Connect(consumerID, hermesID, proposalLookup, ConnectParams{}) + assert.NoError(tc.T(), err) + assert.Equal(tc.T(), connectionstate.Connected, tc.connManager.Status().State) + + newPrice := market.Price{ + PricePerHour: big.NewInt(1), + PricePerGiB: big.NewInt(1), + } + mux.Lock() + activeProposal.Price = newPrice + mux.Unlock() + + waitABit() + + assert.Equal(tc.T(), connectionstate.NotConnected, tc.connManager.Status().State) +} + func (tc *testContext) Test_PaymentManager_WhenManagerMadeConnectionIsStarted() { err := tc.connManager.Connect(consumerID, hermesID, activeProposalLookup, ConnectParams{}) waitABit()