|
20 | 20 | from hashlib import sha256
|
21 | 21 | from types import SimpleNamespace
|
22 | 22 | from typing import Any, Optional, Union, TypedDict
|
23 |
| -from unittest.mock import MagicMock |
| 23 | +from unittest.mock import MagicMock, patch |
24 | 24 |
|
25 | 25 | from bittensor_wallet import Wallet
|
| 26 | +from substrateinterface.base import SubstrateInterface |
| 27 | +from websockets.sync.client import ClientConnection |
26 | 28 |
|
27 | 29 | from bittensor.core.chain_data import (
|
28 | 30 | NeuronInfo,
|
|
33 | 35 | from bittensor.core.types import AxonServeCallParams, PrometheusServeCallParams
|
34 | 36 | from bittensor.core.errors import ChainQueryError
|
35 | 37 | from bittensor.core.subtensor import Subtensor
|
| 38 | +import bittensor.core.subtensor as subtensor_module |
36 | 39 | from bittensor.utils import RAOPERTAO, u16_normalized_float
|
37 | 40 | from bittensor.utils.balance import Balance
|
38 | 41 |
|
@@ -248,14 +251,22 @@ def setup(self) -> None:
|
248 | 251 |
|
249 | 252 | self.network = "mock"
|
250 | 253 | self.chain_endpoint = "ws://mock_endpoint.bt"
|
251 |
| - self.substrate = MagicMock() |
| 254 | + self.substrate = MagicMock(autospec=SubstrateInterface) |
252 | 255 |
|
253 | 256 | def __init__(self, *args, **kwargs) -> None:
|
254 |
| - super().__init__() |
255 |
| - self.__dict__ = __GLOBAL_MOCK_STATE__ |
256 |
| - |
257 |
| - if not hasattr(self, "chain_state") or getattr(self, "chain_state") is None: |
258 |
| - self.setup() |
| 257 | + mock_substrate_interface = MagicMock(autospec=SubstrateInterface) |
| 258 | + mock_websocket = MagicMock(autospec=ClientConnection) |
| 259 | + mock_websocket.close_code = None |
| 260 | + with patch.object( |
| 261 | + subtensor_module, |
| 262 | + "SubstrateInterface", |
| 263 | + return_value=mock_substrate_interface, |
| 264 | + ): |
| 265 | + super().__init__(websocket=mock_websocket) |
| 266 | + self.__dict__ = __GLOBAL_MOCK_STATE__ |
| 267 | + |
| 268 | + if not hasattr(self, "chain_state") or getattr(self, "chain_state") is None: |
| 269 | + self.setup() |
259 | 270 |
|
260 | 271 | def get_block_hash(self, block_id: int) -> str:
|
261 | 272 | return "0x" + sha256(str(block_id).encode()).hexdigest()[:64]
|
|
0 commit comments