Skip to content

Commit 47951bb

Browse files
committed
refactor: Upgrade to Pydantic v2
1 parent 630aa4e commit 47951bb

File tree

28 files changed

+516
-406
lines changed

28 files changed

+516
-406
lines changed

hathor/conf/settings.py

Lines changed: 30 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from enum import StrEnum, auto, unique
1717
from math import log
1818
from pathlib import Path
19-
from typing import Any, NamedTuple, Optional, Union
19+
from typing import NamedTuple, Optional, Union
2020

2121
import pydantic
2222

@@ -561,37 +561,19 @@ def parse_hex_str(hex_str: Union[str, bytes]) -> bytes:
561561
return hex_str
562562

563563

564-
def _validate_consensus_algorithm(consensus_algorithm: ConsensusSettings, values: dict[str, Any]) -> ConsensusSettings:
565-
"""Validate that if Proof-of-Authority is enabled, block rewards must not be set."""
566-
if consensus_algorithm.is_pow():
567-
return consensus_algorithm
568-
569-
assert consensus_algorithm.is_poa()
570-
blocks_per_halving = values.get('BLOCKS_PER_HALVING')
571-
initial_token_units_per_block = values.get('INITIAL_TOKEN_UNITS_PER_BLOCK')
572-
minimum_token_units_per_block = values.get('MINIMUM_TOKEN_UNITS_PER_BLOCK')
573-
assert initial_token_units_per_block is not None, 'INITIAL_TOKEN_UNITS_PER_BLOCK must be set'
574-
assert minimum_token_units_per_block is not None, 'MINIMUM_TOKEN_UNITS_PER_BLOCK must be set'
575-
576-
if blocks_per_halving is not None or initial_token_units_per_block != 0 or minimum_token_units_per_block != 0:
577-
raise ValueError('PoA networks do not support block rewards')
564+
def _validate_consensus_algorithm(consensus_algorithm: ConsensusSettings) -> ConsensusSettings:
565+
"""Validate that if Proof-of-Authority is enabled, block rewards must not be set.
578566
567+
Note: Full validation with other fields is done in the named_tuple validator.
568+
"""
579569
return consensus_algorithm
580570

581571

582-
def _validate_tokens(genesis_tokens: int, values: dict[str, Any]) -> int:
583-
"""Validate genesis tokens."""
584-
genesis_token_units = values.get('GENESIS_TOKEN_UNITS')
585-
decimal_places = values.get('DECIMAL_PLACES')
586-
assert genesis_token_units is not None, 'GENESIS_TOKEN_UNITS must be set'
587-
assert decimal_places is not None, 'DECIMAL_PLACES must be set'
588-
589-
if genesis_tokens != genesis_token_units * (10 ** decimal_places):
590-
raise ValueError(
591-
f'invalid tokens: GENESIS_TOKENS={genesis_tokens}, GENESIS_TOKEN_UNITS={genesis_token_units}, '
592-
f'DECIMAL_PLACES={decimal_places}',
593-
)
572+
def _validate_tokens(genesis_tokens: int) -> int:
573+
"""Validate genesis tokens.
594574
575+
Note: Full validation with other fields is done in the named_tuple validator.
576+
"""
595577
return genesis_tokens
596578

597579

@@ -607,43 +589,41 @@ def _validate_token_deposit_percentage(token_deposit_percentage: float) -> float
607589

608590

609591
_VALIDATORS = dict(
610-
_parse_hex_str=pydantic.validator(
592+
_parse_hex_str=pydantic.field_validator(
611593
'P2PKH_VERSION_BYTE',
612594
'MULTISIG_VERSION_BYTE',
613595
'GENESIS_OUTPUT_SCRIPT',
614596
'GENESIS_BLOCK_HASH',
615597
'GENESIS_TX1_HASH',
616598
'GENESIS_TX2_HASH',
617-
pre=True,
618-
allow_reuse=True
599+
mode='before',
619600
)(parse_hex_str),
620-
_parse_soft_voided_tx_id=pydantic.validator(
601+
_parse_soft_voided_tx_id=pydantic.field_validator(
621602
'SOFT_VOIDED_TX_IDS',
622-
pre=True,
623-
allow_reuse=True,
624-
each_item=True
625-
)(parse_hex_str),
626-
_parse_skipped_verification_tx_id=pydantic.validator(
603+
mode='before',
604+
)(lambda v: [parse_hex_str(x) for x in v] if isinstance(v, list) else v),
605+
_parse_skipped_verification_tx_id=pydantic.field_validator(
627606
'SKIP_VERIFICATION',
628-
pre=True,
629-
allow_reuse=True,
630-
each_item=True
631-
)(parse_hex_str),
632-
_parse_checkpoints=pydantic.validator(
607+
mode='before',
608+
)(lambda v: [parse_hex_str(x) for x in v] if isinstance(v, list) else v),
609+
_parse_checkpoints=pydantic.field_validator(
633610
'CHECKPOINTS',
634-
pre=True
611+
mode='before',
635612
)(_parse_checkpoints),
636-
_parse_blueprints=pydantic.validator(
613+
_parse_blueprints=pydantic.field_validator(
637614
'BLUEPRINTS',
638-
pre=True
615+
mode='before',
639616
)(_parse_blueprints),
640-
_validate_consensus_algorithm=pydantic.validator(
641-
'CONSENSUS_ALGORITHM'
617+
_validate_consensus_algorithm=pydantic.field_validator(
618+
'CONSENSUS_ALGORITHM',
619+
mode='after',
642620
)(_validate_consensus_algorithm),
643-
_validate_tokens=pydantic.validator(
644-
'GENESIS_TOKENS'
621+
_validate_tokens=pydantic.field_validator(
622+
'GENESIS_TOKENS',
623+
mode='after',
645624
)(_validate_tokens),
646-
_validate_token_deposit_percentage=pydantic.validator(
647-
'TOKEN_DEPOSIT_PERCENTAGE'
625+
_validate_token_deposit_percentage=pydantic.field_validator(
626+
'TOKEN_DEPOSIT_PERCENTAGE',
627+
mode='after',
648628
)(_validate_token_deposit_percentage),
649629
)

hathor/consensus/consensus_settings.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
import hashlib
1818
from abc import ABC, abstractmethod
1919
from enum import Enum, unique
20-
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias
20+
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, Union
2121

22-
from pydantic import Field, NonNegativeInt, PrivateAttr, validator
22+
from pydantic import Discriminator, NonNegativeInt, PrivateAttr, Tag, field_validator
2323
from typing_extensions import override
2424

2525
from hathor.transaction import TxVersion
2626
from hathor.util import json_dumpb
27-
from hathor.utils.pydantic import BaseModel
27+
from hathor.utils.pydantic import BaseModel, Hex
2828

2929
if TYPE_CHECKING:
3030
from hathor.conf.settings import HathorSettings
@@ -97,34 +97,19 @@ def get_peer_hello_hash(self) -> str | None:
9797

9898

9999
class PoaSignerSettings(BaseModel):
100-
public_key: bytes
100+
public_key: Hex[bytes]
101101
start_height: NonNegativeInt = 0
102102
end_height: NonNegativeInt | None = None
103103

104-
@validator('public_key', pre=True)
105-
def _parse_hex_str(cls, hex_str: str | bytes) -> bytes:
106-
from hathor.conf.settings import parse_hex_str
107-
return parse_hex_str(hex_str)
108-
109-
@validator('end_height')
110-
def _validate_end_height(cls, end_height: int | None, values: dict[str, Any]) -> int | None:
111-
start_height = values.get('start_height')
112-
assert start_height is not None, 'start_height must be set'
113-
114-
if end_height is None:
115-
return None
116-
117-
if end_height <= start_height:
118-
raise ValueError(f'end_height ({end_height}) must be greater than start_height ({start_height})')
119-
120-
return end_height
104+
def __init__(self, **data: Any) -> None:
105+
super().__init__(**data)
106+
# Validate end_height > start_height after initialization
107+
if self.end_height is not None and self.end_height <= self.start_height:
108+
raise ValueError(f'end_height ({self.end_height}) must be greater than start_height ({self.start_height})')
121109

122110
def to_json_dict(self) -> dict[str, Any]:
123111
"""Return this signer settings instance as a json dict."""
124-
json_dict = self.dict()
125-
# TODO: We can use a custom serializer to convert bytes to hex when we update to Pydantic V2.
126-
json_dict['public_key'] = self.public_key.hex()
127-
return json_dict
112+
return self.model_dump()
128113

129114

130115
class PoaSettings(_BaseConsensusSettings):
@@ -133,7 +118,8 @@ class PoaSettings(_BaseConsensusSettings):
133118
# A list of Proof-of-Authority signer public keys that have permission to produce blocks.
134119
signers: tuple[PoaSignerSettings, ...]
135120

136-
@validator('signers')
121+
@field_validator('signers', mode='after')
122+
@classmethod
137123
def _validate_signers(cls, signers: tuple[PoaSignerSettings, ...]) -> tuple[PoaSignerSettings, ...]:
138124
if len(signers) == 0:
139125
raise ValueError('At least one signer must be provided in PoA networks')
@@ -165,4 +151,10 @@ def _calculate_peer_hello_hash(self) -> str | None:
165151
return hashlib.sha256(data).digest().hex()
166152

167153

168-
ConsensusSettings: TypeAlias = Annotated[PowSettings | PoaSettings, Field(discriminator='type')]
154+
ConsensusSettings: TypeAlias = Annotated[
155+
Union[
156+
Annotated[PowSettings, Tag(ConsensusType.PROOF_OF_WORK)],
157+
Annotated[PoaSettings, Tag(ConsensusType.PROOF_OF_AUTHORITY)],
158+
],
159+
Discriminator('type')
160+
]

hathor/consensus/poa/poa_signer.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from cryptography.hazmat.primitives import hashes
2121
from cryptography.hazmat.primitives.asymmetric import ec
22-
from pydantic import Field, validator
22+
from pydantic import ConfigDict, Field, field_validator, model_validator
2323

2424
from hathor.consensus import poa
2525
from hathor.crypto.util import (
@@ -33,46 +33,62 @@
3333
from hathor.transaction.poa import PoaBlock
3434

3535

36-
class PoaSignerFile(BaseModel, arbitrary_types_allowed=True):
36+
class PoaSignerFile(BaseModel):
3737
"""Class that represents a Proof-of-Authority signer configuration file."""
38+
model_config = ConfigDict(arbitrary_types_allowed=True)
39+
3840
private_key: ec.EllipticCurvePrivateKeyWithSerialization = Field(alias='private_key_hex')
3941
public_key: ec.EllipticCurvePublicKey = Field(alias='public_key_hex')
4042
address: str
43+
_public_key_hex: bytes = bytes() # Store original hex for validation
4144

42-
@validator('private_key', pre=True)
45+
@field_validator('private_key', mode='before')
46+
@classmethod
4347
def _parse_private_key(cls, private_key_hex: str) -> ec.EllipticCurvePrivateKeyWithSerialization:
4448
"""Parse a private key hex into a private key instance."""
4549
private_key_bytes = bytes.fromhex(private_key_hex)
4650
return get_private_key_from_bytes(private_key_bytes)
4751

48-
@validator('public_key', pre=True)
49-
def _validate_public_key_first_bytes(
50-
cls,
51-
public_key_hex: str,
52-
values: dict[str, Any]
53-
) -> ec.EllipticCurvePublicKey:
54-
"""Parse a public key hex into a public key instance, and validate that it corresponds to the private key."""
55-
private_key = values.get('private_key')
56-
assert isinstance(private_key, ec.EllipticCurvePrivateKey), 'private_key must be set'
57-
58-
public_key_bytes = bytes.fromhex(public_key_hex)
59-
actual_public_key = private_key.public_key()
60-
61-
if public_key_bytes != get_public_key_bytes_compressed(actual_public_key):
52+
@model_validator(mode='before')
53+
@classmethod
54+
def _store_public_key_hex(cls, data: dict[str, Any]) -> dict[str, Any]:
55+
"""Store the public key hex for validation and convert to actual key."""
56+
public_key_hex = data.get('public_key_hex')
57+
if public_key_hex is not None:
58+
data['_public_key_hex_input'] = bytes.fromhex(public_key_hex)
59+
return data
60+
61+
@model_validator(mode='after')
62+
def _validate_keys_and_address(self) -> 'PoaSignerFile':
63+
"""Validate that public key and address correspond to the private key."""
64+
actual_public_key = self.private_key.public_key()
65+
actual_public_key_bytes = get_public_key_bytes_compressed(actual_public_key)
66+
67+
# Validate the provided public key matches the one derived from private key
68+
provided_public_key_bytes = get_public_key_bytes_compressed(self.public_key)
69+
if provided_public_key_bytes != actual_public_key_bytes:
6270
raise ValueError('invalid public key')
6371

64-
return actual_public_key
72+
if self.address != get_address_b58_from_public_key(actual_public_key):
73+
raise ValueError('invalid address')
6574

66-
@validator('address')
67-
def _validate_address(cls, address: str, values: dict[str, Any]) -> str:
68-
"""Validate that the provided address corresponds to the provided private key."""
69-
private_key = values.get('private_key')
70-
assert isinstance(private_key, ec.EllipticCurvePrivateKey), 'private_key must be set'
75+
return self
7176

72-
if address != get_address_b58_from_public_key(private_key.public_key()):
73-
raise ValueError('invalid address')
77+
@field_validator('public_key', mode='before')
78+
@classmethod
79+
def _parse_public_key(cls, public_key_hex: str | ec.EllipticCurvePublicKey) -> ec.EllipticCurvePublicKey:
80+
"""Parse public key hex to public key object."""
81+
if isinstance(public_key_hex, ec.EllipticCurvePublicKey):
82+
return public_key_hex
83+
# The public key is provided as compressed bytes
84+
public_key_bytes = bytes.fromhex(public_key_hex)
85+
# For compressed public keys, we need to use a different approach
86+
from cryptography.hazmat.primitives.asymmetric import ec as ec_module
7487

75-
return address
88+
# Load compressed public key
89+
return ec_module.EllipticCurvePublicKey.from_encoded_point(
90+
ec_module.SECP256K1(), public_key_bytes
91+
)
7692

7793
def get_signer(self) -> PoaSigner:
7894
"""Get a PoaSigner for this file."""

hathor/event/model/base_event.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Optional
15+
from typing import Optional
1616

17-
from pydantic import NonNegativeInt, validator
17+
from pydantic import ConfigDict, NonNegativeInt, model_validator
1818

19-
from hathor.event.model.event_data import BaseEventData, EventData
19+
from hathor.event.model.event_data import EventData
2020
from hathor.event.model.event_type import EventType
2121
from hathor.pubsub import EventArguments
2222
from hathor.utils.pydantic import BaseModel
2323

2424

25-
class BaseEvent(BaseModel, use_enum_values=True):
25+
class BaseEvent(BaseModel):
26+
model_config = ConfigDict(use_enum_values=True)
27+
2628
# Event unique id, determines event order
2729
id: NonNegativeInt
2830
# Timestamp in which the event was emitted, this follows the unix_timestamp format, it's only informative, events
@@ -57,12 +59,12 @@ def from_event_arguments(
5759
group_id=group_id,
5860
)
5961

60-
@validator('data')
61-
def data_type_must_match_event_type(cls, v: BaseEventData, values: dict[str, Any]) -> BaseEventData:
62-
event_type = EventType(values['type'])
62+
@model_validator(mode='after')
63+
def data_type_must_match_event_type(self) -> 'BaseEvent':
64+
event_type = EventType(self.type)
6365
expected_data_type = event_type.data_type()
6466

65-
if type(v) is not expected_data_type:
67+
if type(self.data) is not expected_data_type:
6668
raise ValueError('event data type does not match event type')
6769

68-
return v
70+
return self

0 commit comments

Comments
 (0)