Skip to content

Commit 9fe6b38

Browse files
committed
Fixes bug with calculating the domain for the BLS signature
The previous implementation was using the `domain_type` as the `domain` when it should in fact be the value calculated by `get_domain`.
1 parent d8c7885 commit 9fe6b38

File tree

1 file changed

+41
-14
lines changed

1 file changed

+41
-14
lines changed

tests/beacon/test_helpers.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -785,19 +785,24 @@ def test_verify_vote_count(max_casper_votes, sample_slashable_vote_data_params,
785785
assert verify_vote_count(votes, max_casper_votes)
786786

787787

788-
def _get_indices_and_signatures(num_validators, message, privkeys):
788+
def _get_indices_and_signatures(num_validators, message, privkeys, fork_data, slot):
789789
num_indices = 5
790790
assert num_validators >= num_indices
791791
indices = random.sample(range(num_validators), num_indices)
792792
privkeys = [privkeys[i] for i in indices]
793-
domain = SignatureDomain.DOMAIN_ATTESTATION
793+
domain_type = SignatureDomain.DOMAIN_ATTESTATION
794+
domain = get_domain(
795+
fork_data=fork_data,
796+
slot=slot,
797+
domain_type=domain_type,
798+
)
794799
signatures = tuple(
795800
map(lambda key: bls.sign(message, key, domain), privkeys)
796801
)
797802
return (indices, signatures)
798803

799804

800-
def _correct_slashable_vote_data_params(params, validators, messages, privkeys):
805+
def _correct_slashable_vote_data_params(params, validators, messages, privkeys, fork_data):
801806
valid_params = copy.deepcopy(params)
802807

803808
num_validators = len(validators)
@@ -807,6 +812,8 @@ def _correct_slashable_vote_data_params(params, validators, messages, privkeys):
807812
num_validators,
808813
messages[0],
809814
privkeys,
815+
fork_data,
816+
params["data"].slot,
810817
)
811818
valid_params[key] = poc_0_indices
812819

@@ -816,6 +823,8 @@ def _correct_slashable_vote_data_params(params, validators, messages, privkeys):
816823
num_validators,
817824
messages[1],
818825
privkeys,
826+
fork_data,
827+
params["data"].slot,
819828
)
820829
valid_params[key] = poc_1_indices
821830

@@ -827,10 +836,15 @@ def _correct_slashable_vote_data_params(params, validators, messages, privkeys):
827836
return valid_params
828837

829838

830-
def _corrupt_signature(params):
839+
def _corrupt_signature(params, fork_data):
831840
message = bytes.fromhex("deadbeefcafe")
832841
privkey = 42
833-
domain = SignatureDomain.DOMAIN_ATTESTATION
842+
domain_type = SignatureDomain.DOMAIN_ATTESTATION
843+
domain = get_domain(
844+
fork_data=fork_data,
845+
slot=params["data"].slot,
846+
domain_type=domain_type,
847+
)
834848
corrupt_signature = bls.sign(message, privkey, domain)
835849

836850
return cytoolz.assoc(params, "aggregate_signature", corrupt_signature)
@@ -860,24 +874,27 @@ def _create_slashable_vote_data_messages(params):
860874
def test_verify_slashable_vote_data_signature(privkeys,
861875
sample_beacon_state_params,
862876
genesis_validators,
863-
sample_slashable_vote_data_params):
877+
sample_slashable_vote_data_params,
878+
sample_fork_data_params):
864879
sample_beacon_state_params["validator_registry"] = genesis_validators
865880
state = BeaconState(**sample_beacon_state_params)
866881

867882
# NOTE: we can do this before "correcting" the params as they
868883
# touch disjoint subsets of the provided params
869884
messages = _create_slashable_vote_data_messages(sample_slashable_vote_data_params)
870885

886+
fork_data = ForkData(**sample_fork_data_params)
871887
valid_params = _correct_slashable_vote_data_params(
872888
sample_slashable_vote_data_params,
873889
genesis_validators,
874890
messages,
875891
privkeys,
892+
fork_data,
876893
)
877894
valid_votes = SlashableVoteData(**valid_params)
878895
assert verify_slashable_vote_data_signature(state, valid_votes)
879896

880-
invalid_params = _corrupt_signature(valid_params)
897+
invalid_params = _corrupt_signature(valid_params, fork_data)
881898
invalid_votes = SlashableVoteData(**invalid_params)
882899
assert not verify_slashable_vote_data_signature(state, invalid_votes)
883900

@@ -893,22 +910,27 @@ def _run_verify_slashable_vote(params, state, max_casper_votes, should_succeed):
893910

894911
@pytest.mark.parametrize(
895912
(
896-
'param_mapper,'
897-
'should_succeed'
913+
'param_mapper',
914+
'should_succeed',
915+
'needs_fork_data',
898916
),
899917
[
900-
(lambda params: params, True),
901-
(lambda params: _corrupt_vote_count(params), False),
902-
(lambda params: _corrupt_signature(params), False),
903-
(lambda params: _corrupt_vote_count(_corrupt_signature(params)), False),
918+
(lambda params: params, True, False),
919+
(_corrupt_vote_count, False, False),
920+
(_corrupt_signature, False, True),
921+
(lambda params, fork_data: _corrupt_vote_count(
922+
_corrupt_signature(params, fork_data)
923+
), False, True),
904924
],
905925
)
906926
def test_verify_slashable_vote_data(param_mapper,
907927
should_succeed,
928+
needs_fork_data,
908929
privkeys,
909930
sample_beacon_state_params,
910931
genesis_validators,
911932
sample_slashable_vote_data_params,
933+
sample_fork_data_params,
912934
max_casper_votes):
913935
sample_beacon_state_params["validator_registry"] = genesis_validators
914936
state = BeaconState(**sample_beacon_state_params)
@@ -917,13 +939,18 @@ def test_verify_slashable_vote_data(param_mapper,
917939
# touch disjoint subsets of the provided params
918940
messages = _create_slashable_vote_data_messages(sample_slashable_vote_data_params)
919941

942+
fork_data = ForkData(**sample_fork_data_params)
920943
params = _correct_slashable_vote_data_params(
921944
sample_slashable_vote_data_params,
922945
genesis_validators,
923946
messages,
924947
privkeys,
948+
fork_data,
925949
)
926-
params = param_mapper(params)
950+
if needs_fork_data:
951+
params = param_mapper(params, fork_data)
952+
else:
953+
params = param_mapper(params)
927954
_run_verify_slashable_vote(params, state, max_casper_votes, should_succeed)
928955

929956

0 commit comments

Comments
 (0)