diff --git a/tests/beacon/test_helpers.py b/tests/beacon/test_helpers.py index b7e218b781..6ed9c245e0 100644 --- a/tests/beacon/test_helpers.py +++ b/tests/beacon/test_helpers.py @@ -12,6 +12,7 @@ denoms, ValidationError, ) +from eth_utils.toolz import assoc from eth.constants import ( ZERO_HASH32, @@ -779,28 +780,33 @@ def test_verify_vote_count(max_casper_votes, sample_slashable_vote_data_params, assert verify_vote_count(votes, max_casper_votes) -def _get_indices_and_signatures(num_validators, message, privkeys): +def _get_indices_and_signatures(num_validators, message, privkeys, fork_data, slot): num_indices = 5 assert num_validators >= num_indices indices = random.sample(range(num_validators), num_indices) privkeys = [privkeys[i] for i in indices] - domain = SignatureDomain.DOMAIN_ATTESTATION + domain_type = SignatureDomain.DOMAIN_ATTESTATION + domain = get_domain( + fork_data=fork_data, + slot=slot, + domain_type=domain_type, + ) signatures = tuple( map(lambda key: bls.sign(message, key, domain), privkeys) ) return (indices, signatures) -def _correct_slashable_vote_data_params(params, validators, messages, privkeys): +def _correct_slashable_vote_data_params(num_validators, params, messages, privkeys, fork_data): valid_params = copy.deepcopy(params) - num_validators = len(validators) - key = "custody_bit_0_indices" (poc_0_indices, poc_0_signatures) = _get_indices_and_signatures( num_validators, messages[0], privkeys, + fork_data, + params["data"].slot, ) valid_params[key] = poc_0_indices @@ -810,6 +816,8 @@ def _correct_slashable_vote_data_params(params, validators, messages, privkeys): num_validators, messages[1], privkeys, + fork_data, + params["data"].slot, ) valid_params[key] = poc_1_indices @@ -821,23 +829,32 @@ def _correct_slashable_vote_data_params(params, validators, messages, privkeys): return valid_params -def _corrupt_signature(params): - params = copy.deepcopy(params) +def _corrupt_signature(params, fork_data): message = bytes.fromhex("deadbeefcafe") privkey = 42 - domain = SignatureDomain.DOMAIN_ATTESTATION - params["aggregate_signature"] = bls.sign(message, privkey, domain) - return params + domain_type = SignatureDomain.DOMAIN_ATTESTATION + domain = get_domain( + fork_data=fork_data, + slot=params["data"].slot, + domain_type=domain_type, + ) + corrupt_signature = bls.sign(message, privkey, domain) + + return assoc(params, "aggregate_signature", corrupt_signature) def _corrupt_vote_count(params): - params = copy.deepcopy(params) key = "custody_bit_0_indices" for i in itertools.count(): if i not in params[key]: - params[key].append(i) - break - return params + new_vote_count = params[key] + [i] + return assoc( + params, + key, + new_vote_count, + ) + else: + raise Exception("Unreachable code path") def _create_slashable_vote_data_messages(params): @@ -846,27 +863,47 @@ def _create_slashable_vote_data_messages(params): return votes.messages -def test_verify_slashable_vote_data_signature(privkeys, +@pytest.mark.parametrize( + ( + 'num_validators', + ), + [ + (40,), + ] +) +def test_verify_slashable_vote_data_signature(num_validators, + privkeys, sample_beacon_state_params, genesis_validators, - sample_slashable_vote_data_params): - sample_beacon_state_params["validator_registry"] = genesis_validators - state = BeaconState(**sample_beacon_state_params) + sample_slashable_vote_data_params, + sample_fork_data_params): + beacon_state_params_with_genesis_validators = assoc( + sample_beacon_state_params, + "validator_registry", + genesis_validators, + ) + beacon_state_params_with_fork_data = assoc( + beacon_state_params_with_genesis_validators, + "fork_data", + ForkData(**sample_fork_data_params), + ) + state = BeaconState(**beacon_state_params_with_fork_data) # NOTE: we can do this before "correcting" the params as they # touch disjoint subsets of the provided params messages = _create_slashable_vote_data_messages(sample_slashable_vote_data_params) valid_params = _correct_slashable_vote_data_params( + num_validators, sample_slashable_vote_data_params, - genesis_validators, messages, privkeys, + state.fork_data, ) valid_votes = SlashableVoteData(**valid_params) assert verify_slashable_vote_data_signature(state, valid_votes) - invalid_params = _corrupt_signature(valid_params) + invalid_params = _corrupt_signature(valid_params, state.fork_data) invalid_votes = SlashableVoteData(**invalid_params) assert not verify_slashable_vote_data_signature(state, invalid_votes) @@ -882,37 +919,64 @@ def _run_verify_slashable_vote(params, state, max_casper_votes, should_succeed): @pytest.mark.parametrize( ( - 'param_mapper,' - 'should_succeed' + 'num_validators', ), [ - (lambda params: params, True), - (lambda params: _corrupt_vote_count(params), False), - (lambda params: _corrupt_signature(params), False), - (lambda params: _corrupt_vote_count(_corrupt_signature(params)), False), + (40,), + ] +) +@pytest.mark.parametrize( + ( + 'param_mapper', + 'should_succeed', + 'needs_fork_data', + ), + [ + (lambda params: params, True, False), + (_corrupt_vote_count, False, False), + (_corrupt_signature, False, True), + (lambda params, fork_data: _corrupt_vote_count( + _corrupt_signature(params, fork_data) + ), False, True), ], ) -def test_verify_slashable_vote_data(param_mapper, +def test_verify_slashable_vote_data(num_validators, + param_mapper, should_succeed, + needs_fork_data, privkeys, sample_beacon_state_params, genesis_validators, sample_slashable_vote_data_params, + sample_fork_data_params, max_casper_votes): - sample_beacon_state_params["validator_registry"] = genesis_validators - state = BeaconState(**sample_beacon_state_params) + beacon_state_params_with_genesis_validators = assoc( + sample_beacon_state_params, + "validator_registry", + genesis_validators, + ) + beacon_state_params_with_fork_data = assoc( + beacon_state_params_with_genesis_validators, + "fork_data", + ForkData(**sample_fork_data_params), + ) + state = BeaconState(**beacon_state_params_with_fork_data) # NOTE: we can do this before "correcting" the params as they # touch disjoint subsets of the provided params messages = _create_slashable_vote_data_messages(sample_slashable_vote_data_params) params = _correct_slashable_vote_data_params( + num_validators, sample_slashable_vote_data_params, - genesis_validators, messages, privkeys, + state.fork_data, ) - params = param_mapper(params) + if needs_fork_data: + params = param_mapper(params, state.fork_data) + else: + params = param_mapper(params) _run_verify_slashable_vote(params, state, max_casper_votes, should_succeed)