Skip to content

[audit-02] fix: [TRST-H-2] Only agreement owner can collect indexing fee #1199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: ma/indexing-payments-audit-fixes-01-H-1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,8 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
* @dev Caller must be the data service the RCA was issued to.
*/
function collect(IGraphPayments.PaymentTypes paymentType, bytes calldata data) external returns (uint256) {
require(
paymentType == IGraphPayments.PaymentTypes.IndexingFee,
RecurringCollectorInvalidPaymentType(paymentType)
);
try this.decodeCollectData(data) returns (CollectParams memory collectParams) {
return _collect(collectParams);
return _collect(paymentType, collectParams);
} catch {
revert RecurringCollectorInvalidCollectData(data);
}
Expand Down Expand Up @@ -269,10 +265,14 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
*
* Emits {PaymentCollected} and {RCACollected} events.
*
* @param _paymentType The type of payment to collect
* @param _params The decoded parameters for the collection
* @return The amount of tokens collected
*/
function _collect(CollectParams memory _params) private returns (uint256) {
function _collect(
IGraphPayments.PaymentTypes _paymentType,
CollectParams memory _params
) private returns (uint256) {
AgreementData storage agreement = _getAgreementStorage(_params.agreementId);
require(
_isCollectable(agreement),
Expand All @@ -289,7 +289,7 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
tokensToCollect = _requireValidCollect(agreement, _params.agreementId, _params.tokens);

_graphPaymentsEscrow().collect(
IGraphPayments.PaymentTypes.IndexingFee,
_paymentType,
agreement.payer,
agreement.serviceProvider,
tokensToCollect,
Expand All @@ -301,7 +301,7 @@ contract RecurringCollector is EIP712, GraphDirectory, Authorizable, IRecurringC
agreement.lastCollectionAt = uint64(block.timestamp);

emit PaymentCollected(
IGraphPayments.PaymentTypes.IndexingFee,
_paymentType,
_params.collectionId,
agreement.payer,
agreement.serviceProvider,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.27;

import { IGraphPayments } from "../../../../contracts/interfaces/IGraphPayments.sol";

import { IRecurringCollector } from "../../../../contracts/interfaces/IRecurringCollector.sol";

import { RecurringCollectorSharedTest } from "./shared.t.sol";
Expand All @@ -14,32 +12,14 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {

/* solhint-disable graph/func-name-mixedcase */

function test_Collect_Revert_WhenInvalidPaymentType(uint8 unboundedPaymentType, bytes memory data) public {
IGraphPayments.PaymentTypes paymentType = IGraphPayments.PaymentTypes(
bound(
unboundedPaymentType,
uint256(type(IGraphPayments.PaymentTypes).min),
uint256(type(IGraphPayments.PaymentTypes).max)
)
);
vm.assume(paymentType != IGraphPayments.PaymentTypes.IndexingFee);

bytes memory expectedErr = abi.encodeWithSelector(
IRecurringCollector.RecurringCollectorInvalidPaymentType.selector,
paymentType
);
vm.expectRevert(expectedErr);
_recurringCollector.collect(paymentType, data);
}

function test_Collect_Revert_WhenInvalidData(address caller, bytes memory data) public {
function test_Collect_Revert_WhenInvalidData(address caller, uint8 unboundedPaymentType, bytes memory data) public {
bytes memory expectedErr = abi.encodeWithSelector(
IRecurringCollector.RecurringCollectorInvalidCollectData.selector,
data
);
vm.expectRevert(expectedErr);
vm.prank(caller);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(unboundedPaymentType), data);
}

function test_Collect_Revert_WhenCallerNotDataService(
Expand All @@ -61,7 +41,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
vm.expectRevert(expectedErr);
vm.prank(notDataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
}

function test_Collect_Revert_WhenUnknownAgreement(FuzzyTestCollect memory fuzzy, address dataService) public {
Expand All @@ -74,7 +54,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
vm.expectRevert(expectedErr);
vm.prank(dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
}

function test_Collect_Revert_WhenCanceledAgreementByServiceProvider(FuzzyTestCollect calldata fuzzy) public {
Expand All @@ -97,7 +77,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
vm.expectRevert(expectedErr);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
}

function test_Collect_Revert_WhenCollectingTooSoon(
Expand All @@ -116,7 +96,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
)
);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);

uint256 collectionSeconds = boundSkip(unboundedCollectionSeconds, 1, accepted.rca.minSecondsPerCollection - 1);
skip(collectionSeconds);
Expand All @@ -136,7 +116,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
vm.expectRevert(expectedErr);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
}

function test_Collect_Revert_WhenCollectingTooLate(
Expand All @@ -163,7 +143,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
)
);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);

// skip beyond collectable time but still within the agreement endsAt
uint256 collectionSeconds = boundSkip(
Expand All @@ -189,7 +169,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
vm.expectRevert(expectedErr);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
}

function test_Collect_OK_WhenCollectingTooMuch(
Expand Down Expand Up @@ -219,7 +199,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
)
);
vm.prank(accepted.rca.dataService);
_recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, initialData);
_recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), initialData);
}

// skip to collectable time
Expand All @@ -240,7 +220,7 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
);
bytes memory data = _generateCollectData(collectParams);
vm.prank(accepted.rca.dataService);
uint256 collected = _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
uint256 collected = _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
assertEq(collected, maxTokens);
}

Expand All @@ -258,9 +238,9 @@ contract RecurringCollectorCollectTest is RecurringCollectorSharedTest {
unboundedTokens
);
skip(collectionSeconds);
_expectCollectCallAndEmit(accepted.rca, fuzzy.collectParams, tokens);
_expectCollectCallAndEmit(accepted.rca, _paymentType(fuzzy.unboundedPaymentType), fuzzy.collectParams, tokens);
vm.prank(accepted.rca.dataService);
uint256 collected = _recurringCollector.collect(IGraphPayments.PaymentTypes.IndexingFee, data);
uint256 collected = _recurringCollector.collect(_paymentType(fuzzy.unboundedPaymentType), data);
assertEq(collected, tokens);
}
/* solhint-enable graph/func-name-mixedcase */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { RecurringCollectorHelper } from "./RecurringCollectorHelper.t.sol";
contract RecurringCollectorSharedTest is Test, Bounder {
struct FuzzyTestCollect {
FuzzyTestAccept fuzzyTestAccept;
uint8 unboundedPaymentType;
IRecurringCollector.CollectParams collectParams;
}

Expand Down Expand Up @@ -106,6 +107,7 @@ contract RecurringCollectorSharedTest is Test, Bounder {

function _expectCollectCallAndEmit(
IRecurringCollector.RecurringCollectionAgreement memory _rca,
IGraphPayments.PaymentTypes __paymentType,
IRecurringCollector.CollectParams memory _fuzzyParams,
uint256 _tokens
) internal {
Expand All @@ -114,7 +116,7 @@ contract RecurringCollectorSharedTest is Test, Bounder {
abi.encodeCall(
_paymentsEscrow.collect,
(
IGraphPayments.PaymentTypes.IndexingFee,
__paymentType,
_rca.payer,
_rca.serviceProvider,
_tokens,
Expand All @@ -126,7 +128,7 @@ contract RecurringCollectorSharedTest is Test, Bounder {
);
vm.expectEmit(address(_recurringCollector));
emit IPaymentsCollector.PaymentCollected(
IGraphPayments.PaymentTypes.IndexingFee,
__paymentType,
_fuzzyParams.collectionId,
_rca.payer,
_rca.serviceProvider,
Expand Down Expand Up @@ -193,4 +195,15 @@ contract RecurringCollectorSharedTest is Test, Bounder {
bound(_seed, 0, uint256(IRecurringCollector.CancelAgreementBy.Payer))
);
}

function _paymentType(uint8 _unboundedPaymentType) internal pure returns (IGraphPayments.PaymentTypes) {
return
IGraphPayments.PaymentTypes(
bound(
_unboundedPaymentType,
uint256(type(IGraphPayments.PaymentTypes).min),
uint256(type(IGraphPayments.PaymentTypes).max)
)
);
}
}
10 changes: 9 additions & 1 deletion packages/subgraph-service/contracts/SubgraphService.sol
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,12 @@ contract SubgraphService is
paymentCollected = _collectIndexingRewards(indexer, data);
} else if (paymentType == IGraphPayments.PaymentTypes.IndexingFee) {
(bytes16 agreementId, bytes memory iaCollectionData) = IndexingAgreementDecoder.decodeCollectData(data);
paymentCollected = _collectIndexingFees(agreementId, paymentsDestination[indexer], iaCollectionData);
paymentCollected = _collectIndexingFees(
indexer,
agreementId,
paymentsDestination[indexer],
iaCollectionData
);
} else {
revert SubgraphServiceInvalidPaymentType(paymentType);
}
Expand Down Expand Up @@ -754,19 +759,22 @@ contract SubgraphService is
* Emits a {StakeClaimLocked} event.
* Emits a {IndexingFeesCollectedV1} event.
*
* @param _indexer The address of the indexer
* @param _agreementId The id of the indexing agreement
* @param _paymentsDestination The address where the fees should be sent
* @param _data The indexing agreement collection data
* @return The amount of fees collected
*/
function _collectIndexingFees(
address _indexer,
bytes16 _agreementId,
address _paymentsDestination,
bytes memory _data
) private returns (uint256) {
(address indexer, uint256 tokensCollected) = IndexingAgreement._getStorageManager().collect(
_allocations,
IndexingAgreement.CollectParams({
indexer: _indexer,
agreementId: _agreementId,
currentEpoch: _graphEpochManager().currentEpoch(),
receiverDestination: _paymentsDestination,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ library IndexingAgreement {

/**
* @notice Parameters for collecting indexing fees
* @param indexer The address of the indexer
* @param agreementId The ID of the indexing agreement
* @param currentEpoch The current epoch
* @param receiverDestination The address where the collected fees should be sent
* @param data The encoded data containing the number of entities indexed, proof of indexing, and epoch
*/
struct CollectParams {
address indexer;
bytes16 agreementId;
uint256 currentEpoch;
address receiverDestination;
Expand Down Expand Up @@ -523,6 +525,10 @@ library IndexingAgreement {
wrapper.agreement.allocationId,
wrapper.collectorAgreement.serviceProvider
);
require(
allocation.indexer == params.indexer,
IndexingAgreementNotAuthorized(params.agreementId, params.indexer)
);
require(_isCollectable(wrapper), IndexingAgreementNotCollectable(params.agreementId));

require(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,35 @@ contract SubgraphServiceIndexingAgreementCollectTest is SubgraphServiceIndexingA
);
}

function test_SubgraphService_CollectIndexingFees_Reverts_WhenIndexingAgreementNotAuthorized(
Seed memory seed,
uint256 entities,
bytes32 poi
) public {
Context storage ctx = _newCtx(seed);
IndexerState memory indexerState = _withIndexer(ctx);
IndexerState memory otherIndexerState = _withIndexer(ctx);
IRecurringCollector.SignedRCA memory accepted = _withAcceptedIndexingAgreement(ctx, indexerState);

vm.assume(otherIndexerState.addr != indexerState.addr);

resetPrank(otherIndexerState.addr);

uint256 currentEpochBlock = epochManager.currentEpochBlock();

bytes memory expectedErr = abi.encodeWithSelector(
IndexingAgreement.IndexingAgreementNotAuthorized.selector,
accepted.rca.agreementId,
otherIndexerState.addr
);
vm.expectRevert(expectedErr);
subgraphService.collect(
otherIndexerState.addr,
IGraphPayments.PaymentTypes.IndexingFee,
_encodeCollectDataV1(accepted.rca.agreementId, entities, poi, currentEpochBlock, bytes(""))
);
}

function test_SubgraphService_CollectIndexingFees_Reverts_WhenStopService(
Seed memory seed,
uint256 entities,
Expand Down