Skip to content

Commit 8f6f24c

Browse files
committed
refactor: fix type heirarchy and other small bugs
1 parent 782fb61 commit 8f6f24c

File tree

2 files changed

+46
-32
lines changed

2 files changed

+46
-32
lines changed

src/ape/api/query.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from abc import abstractmethod
22
from collections.abc import Iterator, Sequence
33
from functools import cache, cached_property
4-
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union
4+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, Type, TypeVar, Union
55

66
from ethpm_types.abi import EventABI, MethodABI
77
from pydantic import NonNegativeInt, PositiveInt, field_validator, model_validator
88

99
from ape.logging import logger
10+
from ape.types import ContractLog
1011
from ape.types.address import AddressType
1112
from ape.utils import singledispatchmethod
1213
from ape.utils.basemodel import BaseInterface, BaseInterfaceModel, BaseModel
@@ -103,13 +104,14 @@ def extract_fields(item: BaseInterfaceModel, columns: Sequence[str]) -> list[Any
103104

104105

105106
class _BaseQuery(BaseModel, Generic[ModelType]):
106-
Model: ClassVar[ModelType]
107+
Model: ClassVar[Optional[Type[BaseInterfaceModel]]] = None
108+
107109
columns: set[str]
108110

109111
@field_validator("columns", mode="before")
110112
def expand_wildcard(cls, value: Any) -> Any:
111-
if isinstance(value, str) and value == "*":
112-
return _basic_columns(cls.Model)
113+
if cls.Model:
114+
return validate_and_expand_columns(value, cls.Model)
113115

114116
return value
115117

@@ -180,19 +182,21 @@ def end_index(self) -> int:
180182
return self.stop_block
181183

182184

183-
class BlockQuery(_BaseBlockQuery):
185+
class BlockQuery(_BaseBlockQuery, _BaseQuery[BlockAPI]):
184186
"""
185187
A ``QueryType`` that collects properties of ``BlockAPI`` over a range of
186188
blocks between ``start_block`` and ``stop_block``.
187189
"""
188190

189191

190-
class BlockTransactionQuery(_BaseQuery):
192+
class BlockTransactionQuery(_BaseQuery[TransactionAPI]):
191193
"""
192194
A ``QueryType`` that collects properties of ``TransactionAPI`` over a range of
193195
transactions collected inside the ``BlockAPI` object represented by ``block_id``.
194196
"""
195197

198+
Model = TransactionAPI
199+
196200
block_id: Any
197201
num_transactions: NonNegativeInt
198202

@@ -205,18 +209,19 @@ def end_index(self) -> int:
205209
return self.num_transactions - 1
206210

207211

208-
class AccountTransactionQuery(_BaseQuery):
212+
class AccountTransactionQuery(_BaseQuery[TransactionAPI]):
209213
"""
210214
A ``QueryType`` that collects properties of ``TransactionAPI`` over a range
211215
of transactions made by ``account`` between ``start_nonce`` and ``stop_nonce``.
212216
"""
213217

218+
Model = TransactionAPI
219+
214220
account: AddressType
215221
start_nonce: NonNegativeInt = 0
216222
stop_nonce: NonNegativeInt
217223

218224
@model_validator(mode="before")
219-
@classmethod
220225
def check_start_nonce_before_stop_nonce(cls, values: dict) -> dict:
221226
if values["stop_nonce"] < values["start_nonce"]:
222227
raise ValueError(
@@ -235,16 +240,7 @@ def end_index(self) -> int:
235240
return self.stop_nonce
236241

237242

238-
class ContractCreationQuery(_BaseQuery):
239-
"""
240-
A ``QueryType`` that obtains information about contract deployment.
241-
Returns ``ContractCreation(txn_hash, block, deployer, factory)``.
242-
"""
243-
244-
contract: AddressType
245-
246-
247-
class ContractCreation(BaseModel, BaseInterface):
243+
class ContractCreation(BaseInterfaceModel):
248244
"""
249245
Contract-creation metadata, such as the transaction
250246
and deployer. Useful for contract-verification,
@@ -303,18 +299,31 @@ def from_receipt(cls, receipt: ReceiptAPI) -> "ContractCreation":
303299
)
304300

305301

306-
class ContractEventQuery(_BaseBlockQuery):
302+
class ContractCreationQuery(_BaseQuery[ContractCreation]):
303+
"""
304+
A ``QueryType`` that obtains information about contract deployment.
305+
Returns ``ContractCreation(txn_hash, block, deployer, factory)``.
306+
"""
307+
308+
Model = ContractCreation
309+
310+
contract: AddressType
311+
312+
313+
class ContractEventQuery(_BaseBlockQuery, _BaseQuery[ContractLog]):
307314
"""
308315
A ``QueryType`` that collects members from ``event`` over a range of
309316
logs emitted by ``contract`` between ``start_block`` and ``stop_block``.
310317
"""
311318

319+
Model = ContractLog
320+
312321
contract: Union[list[AddressType], AddressType]
313322
event: EventABI
314323
search_topics: Optional[dict[str, Any]] = None
315324

316325

317-
class ContractMethodQuery(_BaseBlockQuery):
326+
class ContractMethodQuery(_BaseBlockQuery, _BaseQuery[Any]):
318327
"""
319328
A ``QueryType`` that collects return values from calling ``method`` in ``contract``
320329
over a range of blocks between ``start_block`` and ``stop_block``.
@@ -325,11 +334,8 @@ class ContractMethodQuery(_BaseBlockQuery):
325334
method_args: dict[str, Any]
326335

327336

328-
QueryType = TypeVar("QueryType", bound=_BaseQuery)
329-
330-
331-
class BaseCursorAPI(BaseInterfaceModel, Generic[QueryType, ModelType]):
332-
query: QueryType
337+
class BaseCursorAPI(BaseInterfaceModel, Generic[ModelType]):
338+
query: _BaseQuery[ModelType]
333339

334340
@abstractmethod
335341
def shrink(
@@ -408,6 +414,16 @@ def as_model_iter(self) -> Iterator[ModelType]:
408414
"""
409415

410416

417+
QueryType = Union[
418+
AccountTransactionQuery,
419+
BlockQuery,
420+
BlockTransactionQuery,
421+
ContractCreationQuery,
422+
ContractEventQuery,
423+
ContractMethodQuery,
424+
]
425+
426+
411427
class QueryEngineAPI(BaseInterface):
412428
@singledispatchmethod
413429
def exec(self, query: QueryType) -> Iterator[BaseCursorAPI]:

src/ape/managers/query.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
except ImportError:
3636
# TODO: Remove when 3.9 dropped (`itertools.pairwise` introduced in 3.10)
37-
from more_itertools import pairwise # type: ignore[no-redef,assignment]
37+
from more_itertools import pairwise # type: ignore[import-not-found,no-redef,assignment]
3838

3939

4040
if TYPE_CHECKING:
@@ -268,8 +268,8 @@ def perform_account_transactions_query(
268268
)
269269

270270

271-
class QueryResult(BaseCursorAPI):
272-
cursors: list[BaseCursorAPI]
271+
class QueryResult(BaseCursorAPI[ModelType]):
272+
cursors: list[BaseCursorAPI[ModelType]]
273273
"""The optimal set of cursors (in sorted order) that fulfill this query."""
274274

275275
@model_validator(mode="after")
@@ -320,13 +320,11 @@ def as_dataframe(
320320
backend: Union[str, nw.Implementation, None] = None,
321321
) -> "Frame":
322322
if backend is None:
323-
backend = cast(nw.Implementation, self.config_manager.config.query.backend)
323+
backend = cast(nw.Implementation, self.config_manager.query.backend)
324324

325325
elif isinstance(backend, str):
326326
backend = nw.Implementation.from_backend(backend)
327327

328-
assert isinstance(backend, str)
329-
330328
# TODO: Source `backend` from core `query:` config if defaulted to `None`
331329
return nw.concat([c.as_dataframe(backend=backend) for c in self.cursors], how="vertical")
332330

@@ -459,7 +457,7 @@ def _experimental_query(
459457
)
460458

461459
logger.debug("Sorted cursors:\n " + "\n ".join(map(str, all_cursors)))
462-
result = QueryResult(
460+
result: QueryResult = QueryResult(
463461
query=query,
464462
cursors=list(self._solve_optimal_coverage(query, all_cursors)),
465463
)

0 commit comments

Comments
 (0)