Skip to content

Commit 322c54f

Browse files
authored
rework class protocol check pattern (#15134)
1 parent 808514a commit 322c54f

File tree

6 files changed

+36
-42
lines changed

6 files changed

+36
-42
lines changed

chia/data_layer/data_layer_wallet.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import time
66
from operator import attrgetter
7-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
7+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, cast
88

99
from blspy import G1Element, G2Element
1010
from clvm.EvalError import EvalError
@@ -105,6 +105,11 @@ def from_json_dict(cls, json_dict: Dict[str, Any]) -> "Mirror":
105105

106106
@final
107107
class DataLayerWallet:
108+
if TYPE_CHECKING:
109+
from chia.wallet.wallet_protocol import WalletProtocol
110+
111+
_protocol_check: ClassVar[WalletProtocol] = cast("DataLayerWallet", None)
112+
108113
wallet_state_manager: WalletStateManager
109114
log: logging.Logger
110115
wallet_info: WalletInfo
@@ -1390,9 +1395,3 @@ def verify_offer(
13901395

13911396
if taker_from_offer != taker_from_reference:
13921397
raise OfferIntegrityError("taker: reference and offer inclusions do not match")
1393-
1394-
1395-
if TYPE_CHECKING:
1396-
from chia.wallet.wallet_protocol import WalletProtocol
1397-
1398-
_dummy: WalletProtocol = DataLayerWallet()

chia/pools/pool_wallet.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import dataclasses
44
import logging
55
import time
6-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, cast
6+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, cast
77

88
from blspy import G1Element, G2Element, PrivateKey
99
from typing_extensions import final
@@ -59,6 +59,11 @@
5959
@final
6060
@dataclasses.dataclass
6161
class PoolWallet:
62+
if TYPE_CHECKING:
63+
from chia.wallet.wallet_protocol import WalletProtocol
64+
65+
_protocol_check: ClassVar[WalletProtocol] = cast("PoolWallet", None)
66+
6267
MINIMUM_INITIAL_BALANCE = 1
6368
MINIMUM_RELATIVE_LOCK_HEIGHT = 5
6469
MAXIMUM_RELATIVE_LOCK_HEIGHT = 1000
@@ -988,9 +993,3 @@ def puzzle_hash_for_pk(self, pubkey: G1Element) -> bytes32:
988993

989994
def get_name(self) -> str:
990995
return self.wallet_info.name
991-
992-
993-
if TYPE_CHECKING:
994-
from chia.wallet.wallet_protocol import WalletProtocol
995-
996-
_dummy: WalletProtocol = cast(PoolWallet, None)

chia/wallet/cat_wallet/cat_wallet.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import traceback
77
from secrets import token_bytes
8-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
8+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, cast
99

1010
from blspy import AugSchemeMPL, G1Element, G2Element
1111

@@ -68,6 +68,11 @@
6868

6969

7070
class CATWallet:
71+
if TYPE_CHECKING:
72+
from chia.wallet.wallet_protocol import WalletProtocol
73+
74+
_protocol_check: ClassVar[WalletProtocol] = cast("CATWallet", None)
75+
7176
wallet_state_manager: WalletStateManager
7277
log: logging.Logger
7378
wallet_info: WalletInfo
@@ -925,9 +930,3 @@ async def get_coins_to_offer(
925930
if balance < amount:
926931
raise Exception(f"insufficient funds in wallet {self.id()}")
927932
return await self.select_coins(amount, min_coin_amount=min_coin_amount, max_coin_amount=max_coin_amount)
928-
929-
930-
if TYPE_CHECKING:
931-
from chia.wallet.wallet_protocol import WalletProtocol
932-
933-
_dummy: WalletProtocol = CATWallet()

chia/wallet/did_wallet/did_wallet.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
import time
88
from secrets import token_bytes
9-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
9+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, cast
1010

1111
from blspy import AugSchemeMPL, G1Element, G2Element
1212

@@ -53,6 +53,11 @@
5353

5454

5555
class DIDWallet:
56+
if TYPE_CHECKING:
57+
from chia.wallet.wallet_protocol import WalletProtocol
58+
59+
_protocol_check: ClassVar[WalletProtocol] = cast("DIDWallet", None)
60+
5661
wallet_state_manager: Any
5762
log: logging.Logger
5863
wallet_info: WalletInfo
@@ -1455,9 +1460,3 @@ def deserialize_backup_data(backup_data: str) -> DIDInfo:
14551460

14561461
def require_derivation_paths(self) -> bool:
14571462
return True
1458-
1459-
1460-
if TYPE_CHECKING:
1461-
from chia.wallet.wallet_protocol import WalletProtocol
1462-
1463-
_dummy: WalletProtocol = DIDWallet()

chia/wallet/nft_wallet/nft_wallet.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import math
77
import time
8-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, TypeVar
8+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
99

1010
from blspy import AugSchemeMPL, G1Element, G2Element
1111
from clvm.casts import int_from_bytes, int_to_bytes
@@ -61,6 +61,11 @@
6161

6262

6363
class NFTWallet:
64+
if TYPE_CHECKING:
65+
from chia.wallet.wallet_protocol import WalletProtocol
66+
67+
_protocol_check: ClassVar[WalletProtocol] = cast("NFTWallet", None)
68+
6469
wallet_state_manager: Any
6570
log: logging.Logger
6671
wallet_info: WalletInfo
@@ -1723,9 +1728,3 @@ def puzzle_hash_for_pk(self, pubkey: G1Element) -> bytes32:
17231728

17241729
def get_name(self) -> str:
17251730
return self.wallet_info.name
1726-
1727-
1728-
if TYPE_CHECKING:
1729-
from chia.wallet.wallet_protocol import WalletProtocol
1730-
1731-
_dummy: WalletProtocol = NFTWallet()

chia/wallet/wallet.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
import time
5-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
5+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, cast
66

77
from blspy import AugSchemeMPL, G1Element, G2Element
88

@@ -56,6 +56,11 @@
5656

5757

5858
class Wallet:
59+
if TYPE_CHECKING:
60+
from chia.wallet.wallet_protocol import WalletProtocol
61+
62+
_protocol_check: ClassVar[WalletProtocol] = cast("Wallet", None)
63+
5964
wallet_info: WalletInfo
6065
wallet_state_manager: Any
6166
log: logging.Logger
@@ -608,9 +613,3 @@ async def coin_added(
608613

609614
def get_name(self) -> str:
610615
return "Standard Wallet"
611-
612-
613-
if TYPE_CHECKING:
614-
from chia.wallet.wallet_protocol import WalletProtocol
615-
616-
_dummy: WalletProtocol = Wallet()

0 commit comments

Comments
 (0)