11from abc import abstractmethod
22from collections .abc import Iterator , Sequence
33from functools import cache , cached_property
4- from typing import TYPE_CHECKING , Any , ClassVar , Generic , Optional , TypeVar , Union
4+ from typing import TYPE_CHECKING , Any , ClassVar , Generic , Optional , Type , TypeVar , Union
55
66from ethpm_types .abi import EventABI , MethodABI
77from pydantic import NonNegativeInt , PositiveInt , field_validator , model_validator
88
99from ape .logging import logger
10+ from ape .types import ContractLog
1011from ape .types .address import AddressType
1112from ape .utils import singledispatchmethod
1213from ape .utils .basemodel import BaseInterface , BaseInterfaceModel , BaseModel
@@ -103,13 +104,14 @@ def extract_fields(item: BaseInterfaceModel, columns: Sequence[str]) -> list[Any
103104
104105
105106class _BaseQuery (BaseModel , Generic [ModelType ]):
106- Model : ClassVar [ModelType ]
107+ Model : ClassVar [Optional [Type [BaseInterfaceModel ]]] = None
108+
107109 columns : set [str ]
108110
109111 @field_validator ("columns" , mode = "before" )
110112 def expand_wildcard (cls , value : Any ) -> Any :
111- if isinstance ( value , str ) and value == "*" :
112- return _basic_columns ( cls .Model )
113+ if cls . Model :
114+ return validate_and_expand_columns ( value , cls .Model )
113115
114116 return value
115117
@@ -180,19 +182,21 @@ def end_index(self) -> int:
180182 return self .stop_block
181183
182184
183- class BlockQuery (_BaseBlockQuery ):
185+ class BlockQuery (_BaseBlockQuery , _BaseQuery [ BlockAPI ] ):
184186 """
185187 A ``QueryType`` that collects properties of ``BlockAPI`` over a range of
186188 blocks between ``start_block`` and ``stop_block``.
187189 """
188190
189191
190- class BlockTransactionQuery (_BaseQuery ):
192+ class BlockTransactionQuery (_BaseQuery [ TransactionAPI ] ):
191193 """
192194 A ``QueryType`` that collects properties of ``TransactionAPI`` over a range of
193195 transactions collected inside the ``BlockAPI` object represented by ``block_id``.
194196 """
195197
198+ Model = TransactionAPI
199+
196200 block_id : Any
197201 num_transactions : NonNegativeInt
198202
@@ -205,18 +209,19 @@ def end_index(self) -> int:
205209 return self .num_transactions - 1
206210
207211
208- class AccountTransactionQuery (_BaseQuery ):
212+ class AccountTransactionQuery (_BaseQuery [ TransactionAPI ] ):
209213 """
210214 A ``QueryType`` that collects properties of ``TransactionAPI`` over a range
211215 of transactions made by ``account`` between ``start_nonce`` and ``stop_nonce``.
212216 """
213217
218+ Model = TransactionAPI
219+
214220 account : AddressType
215221 start_nonce : NonNegativeInt = 0
216222 stop_nonce : NonNegativeInt
217223
218224 @model_validator (mode = "before" )
219- @classmethod
220225 def check_start_nonce_before_stop_nonce (cls , values : dict ) -> dict :
221226 if values ["stop_nonce" ] < values ["start_nonce" ]:
222227 raise ValueError (
@@ -235,16 +240,7 @@ def end_index(self) -> int:
235240 return self .stop_nonce
236241
237242
238- class ContractCreationQuery (_BaseQuery ):
239- """
240- A ``QueryType`` that obtains information about contract deployment.
241- Returns ``ContractCreation(txn_hash, block, deployer, factory)``.
242- """
243-
244- contract : AddressType
245-
246-
247- class ContractCreation (BaseModel , BaseInterface ):
243+ class ContractCreation (BaseInterfaceModel ):
248244 """
249245 Contract-creation metadata, such as the transaction
250246 and deployer. Useful for contract-verification,
@@ -303,18 +299,31 @@ def from_receipt(cls, receipt: ReceiptAPI) -> "ContractCreation":
303299 )
304300
305301
306- class ContractEventQuery (_BaseBlockQuery ):
302+ class ContractCreationQuery (_BaseQuery [ContractCreation ]):
303+ """
304+ A ``QueryType`` that obtains information about contract deployment.
305+ Returns ``ContractCreation(txn_hash, block, deployer, factory)``.
306+ """
307+
308+ Model = ContractCreation
309+
310+ contract : AddressType
311+
312+
313+ class ContractEventQuery (_BaseBlockQuery , _BaseQuery [ContractLog ]):
307314 """
308315 A ``QueryType`` that collects members from ``event`` over a range of
309316 logs emitted by ``contract`` between ``start_block`` and ``stop_block``.
310317 """
311318
319+ Model = ContractLog
320+
312321 contract : Union [list [AddressType ], AddressType ]
313322 event : EventABI
314323 search_topics : Optional [dict [str , Any ]] = None
315324
316325
317- class ContractMethodQuery (_BaseBlockQuery ):
326+ class ContractMethodQuery (_BaseBlockQuery , _BaseQuery [ Any ] ):
318327 """
319328 A ``QueryType`` that collects return values from calling ``method`` in ``contract``
320329 over a range of blocks between ``start_block`` and ``stop_block``.
@@ -325,11 +334,8 @@ class ContractMethodQuery(_BaseBlockQuery):
325334 method_args : dict [str , Any ]
326335
327336
328- QueryType = TypeVar ("QueryType" , bound = _BaseQuery )
329-
330-
331- class BaseCursorAPI (BaseInterfaceModel , Generic [QueryType , ModelType ]):
332- query : QueryType
337+ class BaseCursorAPI (BaseInterfaceModel , Generic [ModelType ]):
338+ query : _BaseQuery [ModelType ]
333339
334340 @abstractmethod
335341 def shrink (
@@ -408,6 +414,16 @@ def as_model_iter(self) -> Iterator[ModelType]:
408414 """
409415
410416
417+ QueryType = Union [
418+ AccountTransactionQuery ,
419+ BlockQuery ,
420+ BlockTransactionQuery ,
421+ ContractCreationQuery ,
422+ ContractEventQuery ,
423+ ContractMethodQuery ,
424+ ]
425+
426+
411427class QueryEngineAPI (BaseInterface ):
412428 @singledispatchmethod
413429 def exec (self , query : QueryType ) -> Iterator [BaseCursorAPI ]:
0 commit comments