diff --git a/contracts/src/periphery/ejection/EigenDAEjectionManager.sol b/contracts/src/periphery/ejection/EigenDAEjectionManager.sol index d4dba5ab34..067e8dcbd9 100644 --- a/contracts/src/periphery/ejection/EigenDAEjectionManager.sol +++ b/contracts/src/periphery/ejection/EigenDAEjectionManager.sol @@ -63,6 +63,7 @@ contract EigenDAEjectionManager is IEigenDAEjectionManager, IEigenDASemVer { /// @inheritdoc IEigenDAEjectionManager function cancelEjectionByEjector(address operator) external onlyEjector(msg.sender) { + require(EigenDAEjectionLib.getEjectionRecord(operator).ejector == msg.sender, "only ejector that issued ejection can cancel"); operator.cancelEjection(); } diff --git a/contracts/test/unit/EigenDAEjectionManager.t.sol b/contracts/test/unit/EigenDAEjectionManager.t.sol index 7b2d686c1a..0341f16aea 100644 --- a/contracts/test/unit/EigenDAEjectionManager.t.sol +++ b/contracts/test/unit/EigenDAEjectionManager.t.sol @@ -191,4 +191,28 @@ contract EigenDAEjectionManagerTest is Test { ejectionManager.startEjection(ejectee, "0x"); vm.stopPrank(); } + + function testCancelEjectionByEjectorRevertsWhenCalledByDifferentEjector() public { + // 0) create a second ejector and grant access for ejection role + address ejector2 = makeAddr("ejector2"); + accessControl.grantRole(AccessControlConstants.EJECTOR_ROLE, ejector); + accessControl.grantRole(AccessControlConstants.EJECTOR_ROLE, ejector2); + accessControl.grantRole(AccessControlConstants.OWNER_ROLE, ejector); + + // 1) first ejector starts an ejection + vm.startPrank(ejector); + ejectionManager.setCooldown(0); + ejectionManager.setDelay(0); + ejectionManager.startEjection(ejectee, "0x"); + vm.stopPrank(); + + // 2) verify the ejection was created with the first ejector + assertEq(ejectionManager.getEjector(ejectee), ejector); + + // 3) attempting to cancel the ejection from a different ejector should revert + vm.startPrank(ejector2); + vm.expectRevert("only ejector that issued ejection can cancel"); + ejectionManager.cancelEjectionByEjector(ejectee); + vm.stopPrank(); + } }