diff --git a/packages/horizon/contracts/interfaces/IRecurringCollector.sol b/packages/horizon/contracts/interfaces/IRecurringCollector.sol index 954b1be94..ef7ba05f7 100644 --- a/packages/horizon/contracts/interfaces/IRecurringCollector.sol +++ b/packages/horizon/contracts/interfaces/IRecurringCollector.sol @@ -253,6 +253,11 @@ interface IRecurringCollector is IAuthorizable, IPaymentsCollector { * @param unauthorizedDataService The address of the unauthorized data service */ error RecurringCollectorDataServiceNotAuthorized(bytes16 agreementId, address unauthorizedDataService); + /** + * @notice Thrown when the data service is not authorized for the service provider + * @param dataService The address of the unauthorized data service + */ + error RecurringCollectorUnauthorizedDataService(address dataService); /** * @notice Thrown when interacting with an agreement with an elapsed deadline diff --git a/packages/horizon/contracts/payments/collectors/RecurringCollector.sol b/packages/horizon/contracts/payments/collectors/RecurringCollector.sol index 662dc549f..5f43c482e 100644 --- a/packages/horizon/contracts/payments/collectors/RecurringCollector.sol +++ b/packages/horizon/contracts/payments/collectors/RecurringCollector.sol @@ -284,6 +284,17 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC RecurringCollectorDataServiceNotAuthorized(_params.agreementId, msg.sender) ); + // Check the service provider has an active provision with the data service + // This prevents an attack where the payer can deny the service provider from collecting payments + // by using a signer as data service to syphon off the tokens in the escrow to an account they control + { + uint256 tokensAvailable = _graphStaking().getProviderTokensAvailable( + agreement.serviceProvider, + agreement.dataService + ); + require(tokensAvailable > 0, RecurringCollectorUnauthorizedDataService(agreement.dataService)); + } + uint256 tokensToCollect = 0; if (_params.tokens != 0) { tokensToCollect = _requireValidCollect(agreement, _params.agreementId, _params.tokens); diff --git a/packages/horizon/test/unit/mocks/HorizonStakingMock.t.sol b/packages/horizon/test/unit/mocks/HorizonStakingMock.t.sol index 647df06f7..d08975e09 100644 --- a/packages/horizon/test/unit/mocks/HorizonStakingMock.t.sol +++ b/packages/horizon/test/unit/mocks/HorizonStakingMock.t.sol @@ -29,4 +29,9 @@ contract HorizonStakingMock { function setIsAuthorized(address serviceProvider, address verifier, address operator, bool authorized) external { authorizations[serviceProvider][verifier][operator] = authorized; } + + function getProviderTokensAvailable(address serviceProvider, address verifier) external view returns (uint256) { + IHorizonStakingTypes.Provision memory provision = provisions[serviceProvider][verifier]; + return provision.tokens - provision.tokensThawing; + } } diff --git a/packages/horizon/test/unit/payments/recurring-collector/collect.t.sol b/packages/horizon/test/unit/payments/recurring-collector/collect.t.sol index 4382fa852..0002c68af 100644 --- a/packages/horizon/test/unit/payments/recurring-collector/collect.t.sol +++ b/packages/horizon/test/unit/payments/recurring-collector/collect.t.sol @@ -2,6 +2,7 @@ pragma solidity 0.8.27; import { IRecurringCollector } from "../../../../contracts/interfaces/IRecurringCollector.sol"; +import { IHorizonStakingTypes } from "../../../../contracts/interfaces/internal/IHorizonStakingTypes.sol"; import { RecurringCollectorSharedTest } from "./shared.t.sol"; @@ -44,6 +45,42 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest { _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); } + function test_Collect_Revert_WhenUnauthorizedDataService(FuzzyTestCollect calldata fuzzy) public { + (IRecurringCollector.SignedRCA memory accepted, ) = _sensibleAuthorizeAndAccept(fuzzy.fuzzyTestAccept); + IRecurringCollector.CollectParams memory collectParams = fuzzy.collectParams; + + collectParams.agreementId = accepted.rca.agreementId; + collectParams.tokens = bound(collectParams.tokens, 1, type(uint256).max); + bytes memory data = _generateCollectData(collectParams); + + // Set up the scenario where service provider has no tokens staked with data service + // This simulates an unauthorized data service attack + _horizonStaking.setProvision( + accepted.rca.serviceProvider, + accepted.rca.dataService, + IHorizonStakingTypes.Provision({ + tokens: 0, // No tokens staked - this triggers the vulnerability + tokensThawing: 0, + sharesThawing: 0, + maxVerifierCut: 100000, + thawingPeriod: 604800, + createdAt: uint64(block.timestamp), + maxVerifierCutPending: 100000, + thawingPeriodPending: 604800, + lastParametersStagedAt: 0, + thawingNonce: 0 + }) + ); + + bytes memory expectedErr = abi.encodeWithSelector( + IRecurringCollector.RecurringCollectorUnauthorizedDataService.selector, + accepted.rca.dataService + ); + vm.expectRevert(expectedErr); + vm.prank(accepted.rca.dataService); + _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data); + } + function test_Collect_Revert_WhenUnknownAgreement(FuzzyTestCollect memory fuzzy, address dataService) public { bytes memory data = _generateCollectData(fuzzy.collectParams); diff --git a/packages/horizon/test/unit/payments/recurring-collector/shared.t.sol b/packages/horizon/test/unit/payments/recurring-collector/shared.t.sol index 2dbd0e1a0..d8d9483e7 100644 --- a/packages/horizon/test/unit/payments/recurring-collector/shared.t.sol +++ b/packages/horizon/test/unit/payments/recurring-collector/shared.t.sol @@ -6,10 +6,12 @@ import { Test } from "forge-std/Test.sol"; import { IGraphPayments } from "../../../../contracts/interfaces/IGraphPayments.sol"; import { IPaymentsCollector } from "../../../../contracts/interfaces/IPaymentsCollector.sol"; import { IRecurringCollector } from "../../../../contracts/interfaces/IRecurringCollector.sol"; +import { IHorizonStakingTypes } from "../../../../contracts/interfaces/internal/IHorizonStakingTypes.sol"; import { RecurringCollector } from "../../../../contracts/payments/collectors/RecurringCollector.sol"; import { Bounder } from "../../../unit/utils/Bounder.t.sol"; import { PartialControllerMock } from "../../mocks/PartialControllerMock.t.sol"; +import { HorizonStakingMock } from "../../mocks/HorizonStakingMock.t.sol"; import { PaymentsEscrowMock } from "./PaymentsEscrowMock.t.sol"; import { RecurringCollectorHelper } from "./RecurringCollectorHelper.t.sol"; @@ -32,12 +34,15 @@ contract RecurringCollectorSharedTest is Test, Bounder { RecurringCollector internal _recurringCollector; PaymentsEscrowMock internal _paymentsEscrow; + HorizonStakingMock internal _horizonStaking; RecurringCollectorHelper internal _recurringCollectorHelper; function setUp() public { _paymentsEscrow = new PaymentsEscrowMock(); - PartialControllerMock.Entry[] memory entries = new PartialControllerMock.Entry[](1); + _horizonStaking = new HorizonStakingMock(); + PartialControllerMock.Entry[] memory entries = new PartialControllerMock.Entry[](2); entries[0] = PartialControllerMock.Entry({ name: "PaymentsEscrow", addr: address(_paymentsEscrow) }); + entries[1] = PartialControllerMock.Entry({ name: "Staking", addr: address(_horizonStaking) }); _recurringCollector = new RecurringCollector( "RecurringCollector", "1", @@ -71,6 +76,9 @@ contract RecurringCollectorSharedTest is Test, Bounder { } function _accept(IRecurringCollector.SignedRCA memory _signedRCA) internal { + // Set up valid staking provision by default to allow collections to succeed + _setupValidProvision(_signedRCA.rca.serviceProvider, _signedRCA.rca.dataService); + vm.expectEmit(address(_recurringCollector)); emit IRecurringCollector.AgreementAccepted( _signedRCA.rca.dataService, @@ -88,6 +96,25 @@ contract RecurringCollectorSharedTest is Test, Bounder { _recurringCollector.accept(_signedRCA); } + function _setupValidProvision(address _serviceProvider, address _dataService) internal { + _horizonStaking.setProvision( + _serviceProvider, + _dataService, + IHorizonStakingTypes.Provision({ + tokens: 1000 ether, + tokensThawing: 0, + sharesThawing: 0, + maxVerifierCut: 100000, // 10% + thawingPeriod: 604800, // 7 days + createdAt: uint64(block.timestamp), + maxVerifierCutPending: 100000, + thawingPeriodPending: 604800, + lastParametersStagedAt: 0, + thawingNonce: 0 + }) + ); + } + function _cancel( IRecurringCollector.RecurringCollectionAgreement memory _rca, IRecurringCollector.CancelAgreementBy _by