Skip to content

Commit 0c65f67

Browse files
committed
Fixes bug with calculating the domain for the BLS signature
The previous implementation was using the `domain_type` as the `domain` when it should in fact be the value calculated by `get_domain`.
1 parent d910d42 commit 0c65f67

File tree

1 file changed

+41
-14
lines changed

1 file changed

+41
-14
lines changed

tests/beacon/test_helpers.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -810,19 +810,24 @@ def test_verify_vote_count(max_casper_votes, sample_slashable_vote_data_params,
810810
assert verify_vote_count(votes, max_casper_votes)
811811

812812

813-
def _get_indices_and_signatures(num_validators, message, privkeys):
813+
def _get_indices_and_signatures(num_validators, message, privkeys, fork_data, slot):
814814
num_indices = 5
815815
assert num_validators >= num_indices
816816
indices = random.sample(range(num_validators), num_indices)
817817
privkeys = [privkeys[i] for i in indices]
818-
domain = SignatureDomain.DOMAIN_ATTESTATION
818+
domain_type = SignatureDomain.DOMAIN_ATTESTATION
819+
domain = get_domain(
820+
fork_data=fork_data,
821+
slot=slot,
822+
domain_type=domain_type,
823+
)
819824
signatures = tuple(
820825
map(lambda key: bls.sign(message, key, domain), privkeys)
821826
)
822827
return (indices, signatures)
823828

824829

825-
def _correct_slashable_vote_data_params(params, validators, messages, privkeys):
830+
def _correct_slashable_vote_data_params(params, validators, messages, privkeys, fork_data):
826831
valid_params = copy.deepcopy(params)
827832

828833
num_validators = len(validators)
@@ -832,6 +837,8 @@ def _correct_slashable_vote_data_params(params, validators, messages, privkeys):
832837
num_validators,
833838
messages[0],
834839
privkeys,
840+
fork_data,
841+
params["data"].slot,
835842
)
836843
valid_params[key] = poc_0_indices
837844

@@ -841,6 +848,8 @@ def _correct_slashable_vote_data_params(params, validators, messages, privkeys):
841848
num_validators,
842849
messages[1],
843850
privkeys,
851+
fork_data,
852+
params["data"].slot,
844853
)
845854
valid_params[key] = poc_1_indices
846855

@@ -852,10 +861,15 @@ def _correct_slashable_vote_data_params(params, validators, messages, privkeys):
852861
return valid_params
853862

854863

855-
def _corrupt_signature(params):
864+
def _corrupt_signature(params, fork_data):
856865
message = bytes.fromhex("deadbeefcafe")
857866
privkey = 42
858-
domain = SignatureDomain.DOMAIN_ATTESTATION
867+
domain_type = SignatureDomain.DOMAIN_ATTESTATION
868+
domain = get_domain(
869+
fork_data=fork_data,
870+
slot=params["data"].slot,
871+
domain_type=domain_type,
872+
)
859873
corrupt_signature = bls.sign(message, privkey, domain)
860874

861875
return cytoolz.assoc(params, "aggregate_signature", corrupt_signature)
@@ -885,24 +899,27 @@ def _create_slashable_vote_data_messages(params):
885899
def test_verify_slashable_vote_data_signature(privkeys,
886900
sample_beacon_state_params,
887901
genesis_validators,
888-
sample_slashable_vote_data_params):
902+
sample_slashable_vote_data_params,
903+
sample_fork_data_params):
889904
sample_beacon_state_params["validator_registry"] = genesis_validators
890905
state = BeaconState(**sample_beacon_state_params)
891906

892907
# NOTE: we can do this before "correcting" the params as they
893908
# touch disjoint subsets of the provided params
894909
messages = _create_slashable_vote_data_messages(sample_slashable_vote_data_params)
895910

911+
fork_data = ForkData(**sample_fork_data_params)
896912
valid_params = _correct_slashable_vote_data_params(
897913
sample_slashable_vote_data_params,
898914
genesis_validators,
899915
messages,
900916
privkeys,
917+
fork_data,
901918
)
902919
valid_votes = SlashableVoteData(**valid_params)
903920
assert verify_slashable_vote_data_signature(state, valid_votes)
904921

905-
invalid_params = _corrupt_signature(valid_params)
922+
invalid_params = _corrupt_signature(valid_params, fork_data)
906923
invalid_votes = SlashableVoteData(**invalid_params)
907924
assert not verify_slashable_vote_data_signature(state, invalid_votes)
908925

@@ -918,22 +935,27 @@ def _run_verify_slashable_vote(params, state, max_casper_votes, should_succeed):
918935

919936
@pytest.mark.parametrize(
920937
(
921-
'param_mapper,'
922-
'should_succeed'
938+
'param_mapper',
939+
'should_succeed',
940+
'needs_fork_data',
923941
),
924942
[
925-
(lambda params: params, True),
926-
(lambda params: _corrupt_vote_count(params), False),
927-
(lambda params: _corrupt_signature(params), False),
928-
(lambda params: _corrupt_vote_count(_corrupt_signature(params)), False),
943+
(lambda params: params, True, False),
944+
(_corrupt_vote_count, False, False),
945+
(_corrupt_signature, False, True),
946+
(lambda params, fork_data: _corrupt_vote_count(
947+
_corrupt_signature(params, fork_data)
948+
), False, True),
929949
],
930950
)
931951
def test_verify_slashable_vote_data(param_mapper,
932952
should_succeed,
953+
needs_fork_data,
933954
privkeys,
934955
sample_beacon_state_params,
935956
genesis_validators,
936957
sample_slashable_vote_data_params,
958+
sample_fork_data_params,
937959
max_casper_votes):
938960
sample_beacon_state_params["validator_registry"] = genesis_validators
939961
state = BeaconState(**sample_beacon_state_params)
@@ -942,13 +964,18 @@ def test_verify_slashable_vote_data(param_mapper,
942964
# touch disjoint subsets of the provided params
943965
messages = _create_slashable_vote_data_messages(sample_slashable_vote_data_params)
944966

967+
fork_data = ForkData(**sample_fork_data_params)
945968
params = _correct_slashable_vote_data_params(
946969
sample_slashable_vote_data_params,
947970
genesis_validators,
948971
messages,
949972
privkeys,
973+
fork_data,
950974
)
951-
params = param_mapper(params)
975+
if needs_fork_data:
976+
params = param_mapper(params, fork_data)
977+
else:
978+
params = param_mapper(params)
952979
_run_verify_slashable_vote(params, state, max_casper_votes, should_succeed)
953980

954981

0 commit comments

Comments
 (0)