Skip to content

Commit 4b7aab2

Browse files
Operation pattern matching improvements (#43)
1 parent ead7c6b commit 4b7aab2

File tree

9 files changed

+223
-72
lines changed

9 files changed

+223
-72
lines changed

src/demo_quipuswap/dipdup.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ templates:
7474
- type: transaction
7575
destination: <dex_contract>
7676
entrypoint: withdrawProfit
77+
- type: transaction
78+
source: <dex_contract>
79+
optional: True
7780

7881
quipuswap_fa2:
7982
kind: operation
@@ -118,6 +121,9 @@ templates:
118121
- type: transaction
119122
destination: <dex_contract>
120123
entrypoint: withdrawProfit
124+
- type: transaction
125+
source: <dex_contract>
126+
optional: True
121127

122128
indexes:
123129
kusd_mainnet:
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
from decimal import Decimal
2+
from typing import Optional
23

34
import demo_quipuswap.models as models
45
from demo_quipuswap.types.quipu_fa12.parameter.withdraw_profit import WithdrawProfitParameter
56
from demo_quipuswap.types.quipu_fa12.storage import QuipuFa12Storage
6-
from dipdup.models import OperationHandlerContext, TransactionContext
7+
from dipdup.models import OperationData, OperationHandlerContext, OriginationContext, TransactionContext
78

89

910
async def on_fa12_withdraw_profit(
1011
ctx: OperationHandlerContext,
1112
withdraw_profit: TransactionContext[WithdrawProfitParameter, QuipuFa12Storage],
13+
transaction_0: Optional[OperationData],
1214
) -> None:
13-
1415
if ctx.template_values is None:
1516
raise Exception('This index must be templated')
1617

1718
symbol = ctx.template_values['symbol']
1819
trader = withdraw_profit.data.sender_address
1920

2021
position, _ = await models.Position.get_or_create(trader=trader, symbol=symbol)
21-
transaction = next(op for op in ctx.operations if op.amount)
22-
23-
assert transaction.amount is not None
24-
position.realized_pl += Decimal(transaction.amount) / (10 ** 6) # type: ignore
22+
if transaction_0:
23+
assert transaction_0.amount is not None
24+
position.realized_pl += Decimal(transaction_0.amount) / (10 ** 6) # type: ignore
2525

26-
await position.save()
26+
await position.save()
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from decimal import Decimal
2+
from typing import Optional
23

34
import demo_quipuswap.models as models
45
from demo_quipuswap.types.quipu_fa2.parameter.withdraw_profit import WithdrawProfitParameter
56
from demo_quipuswap.types.quipu_fa2.storage import QuipuFa2Storage
6-
from dipdup.models import OperationHandlerContext, TransactionContext
7+
from dipdup.models import OperationData, OperationHandlerContext, OriginationContext, TransactionContext
78

89

910
async def on_fa20_withdraw_profit(
1011
ctx: OperationHandlerContext,
1112
withdraw_profit: TransactionContext[WithdrawProfitParameter, QuipuFa2Storage],
13+
transaction_0: Optional[OperationData],
1214
) -> None:
1315

1416
if ctx.template_values is None:
@@ -18,9 +20,9 @@ async def on_fa20_withdraw_profit(
1820
trader = withdraw_profit.data.sender_address
1921

2022
position, _ = await models.Position.get_or_create(trader=trader, symbol=symbol)
21-
transaction = next(op for op in ctx.operations if op.amount)
2223

23-
assert transaction.amount is not None
24-
position.realized_pl += Decimal(transaction.amount) / (10 ** 6) # type: ignore
24+
if transaction_0:
25+
assert transaction_0.amount is not None
26+
position.realized_pl += Decimal(transaction_0.amount) / (10 ** 6) # type: ignore
2527

26-
await position.save()
28+
await position.save()

src/dipdup/codegen.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
DipDupConfig,
1818
IndexTemplateConfig,
1919
OperationHandlerConfig,
20+
OperationHandlerOriginationPatternConfig,
2021
OperationHandlerTransactionPatternConfig,
2122
OperationIndexConfig,
2223
TzktDatasourceConfig,
@@ -91,7 +92,17 @@ async def fetch_schemas(config: DipDupConfig):
9192
if isinstance(index_config, OperationIndexConfig):
9293
for operation_handler_config in index_config.handlers:
9394
for operation_pattern_config in operation_handler_config.pattern:
94-
contract_config = operation_pattern_config.contract_config
95+
96+
if (
97+
isinstance(operation_pattern_config, OperationHandlerTransactionPatternConfig)
98+
and operation_pattern_config.entrypoint
99+
):
100+
contract_config = operation_pattern_config.destination_contract_config
101+
elif isinstance(operation_pattern_config, OperationHandlerOriginationPatternConfig):
102+
contract_config = operation_pattern_config.contract_config
103+
else:
104+
continue
105+
95106
contract_schemas = await schemas_cache.get(index_config.datasource_config, contract_config)
96107

97108
contract_schemas_path = join(schemas_path, contract_config.module_name)

src/dipdup/config.py

Lines changed: 109 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66
import re
77
import sys
88
from collections import defaultdict
9-
from dataclasses import field
109
from enum import Enum
1110
from os import environ as env
1211
from os.path import dirname
1312
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
1413
from urllib.parse import urlparse
1514

16-
from pydantic import Field, validator
15+
from pydantic import validator
1716
from pydantic.dataclasses import dataclass
1817
from pydantic.json import pydantic_encoder
1918
from ruamel.yaml import YAML
@@ -155,21 +154,35 @@ class OperationHandlerTransactionPatternConfig:
155154
:param entrypoint: Contract entrypoint
156155
"""
157156

158-
type: Literal['transaction']
159-
destination: Union[str, ContractConfig]
160-
entrypoint: str
157+
type: Literal['transaction'] = 'transaction'
158+
source: Optional[Union[str, ContractConfig]] = None
159+
destination: Optional[Union[str, ContractConfig]] = None
160+
entrypoint: Optional[str] = None
161+
optional: bool = False
161162

162163
def __post_init_post_parse__(self):
164+
if self.entrypoint and not self.destination:
165+
raise ConfigurationError('Transactions with entrypoint must also have destination')
163166
self._parameter_type_cls = None
164167
self._storage_type_cls = None
168+
self._transaction_id = None
165169

166170
@property
167-
def contract_config(self) -> ContractConfig:
168-
assert isinstance(self.destination, ContractConfig)
171+
def source_contract_config(self) -> ContractConfig:
172+
if not isinstance(self.source, ContractConfig):
173+
raise RuntimeError('Config is not initialized')
174+
return self.source
175+
176+
@property
177+
def destination_contract_config(self) -> ContractConfig:
178+
if not isinstance(self.destination, ContractConfig):
179+
raise RuntimeError('Config is not initialized')
169180
return self.destination
170181

171182
@property
172183
def parameter_type_cls(self) -> Optional[Type]:
184+
if not self.entrypoint:
185+
raise RuntimeError('entrypoint is empty')
173186
if self._parameter_type_cls is None:
174187
raise RuntimeError('Config is not initialized')
175188
return self._parameter_type_cls
@@ -180,6 +193,8 @@ def parameter_type_cls(self, typ: Type) -> None:
180193

181194
@property
182195
def storage_type_cls(self) -> Type:
196+
if not self.entrypoint:
197+
raise RuntimeError('entrypoint is empty')
183198
if self._storage_type_cls is None:
184199
raise RuntimeError('Config is not initialized')
185200
return self._storage_type_cls
@@ -188,22 +203,51 @@ def storage_type_cls(self) -> Type:
188203
def storage_type_cls(self, typ: Type) -> None:
189204
self._storage_type_cls = typ
190205

206+
@property
207+
def transaction_id(self) -> int:
208+
if self._transaction_id is None:
209+
raise RuntimeError('Config is not initialized')
210+
return self._transaction_id
211+
212+
@transaction_id.setter
213+
def transaction_id(self, id_: int) -> None:
214+
self._transaction_id = id_
215+
191216
def get_handler_imports(self, package: str) -> str:
192-
return '\n'.join(
193-
[
194-
f'from {package}.types.{self.contract_config.module_name}.parameter.{camel_to_snake(self.entrypoint)} import {snake_to_camel(self.entrypoint)}Parameter',
195-
f'from {package}.types.{self.contract_config.module_name}.storage import {snake_to_camel(self.contract_config.module_name)}Storage',
196-
]
197-
)
217+
if self.entrypoint:
218+
module_name = self.destination_contract_config.module_name
219+
entrypoint = camel_to_snake(self.entrypoint)
220+
parameter_cls = f'{snake_to_camel(self.entrypoint)}Parameter'
221+
storage_cls = f'{snake_to_camel(module_name)}Storage'
222+
return '\n'.join(
223+
[
224+
f'from {package}.types.{module_name}.parameter.{entrypoint} import {parameter_cls}',
225+
f'from {package}.types.{module_name}.storage import {storage_cls}',
226+
]
227+
)
228+
else:
229+
return ''
198230

199231
def get_handler_argument(self) -> str:
200-
return f'{camel_to_snake(self.entrypoint)}: TransactionContext[{snake_to_camel(self.entrypoint)}Parameter, {snake_to_camel(self.contract_config.module_name)}Storage],'
232+
if self.entrypoint:
233+
module_name = self.destination_contract_config.module_name
234+
entrypoint = camel_to_snake(self.entrypoint)
235+
parameter_cls = f'{snake_to_camel(self.entrypoint)}Parameter'
236+
storage_cls = f'{snake_to_camel(module_name)}Storage'
237+
if self.optional:
238+
return f'{entrypoint}: Optional[TransactionContext[{parameter_cls}, {storage_cls}]],'
239+
return f'{entrypoint}: TransactionContext[{parameter_cls}, {storage_cls}],'
240+
else:
241+
if self.optional:
242+
return f'transaction_{self._transaction_id}: Optional[OperationData],'
243+
return f'transaction_{self._transaction_id}: OperationData,'
201244

202245

203246
@dataclass
204247
class OperationHandlerOriginationPatternConfig:
205-
type: Literal['origination']
206248
originated_contract: Union[str, ContractConfig]
249+
type: Literal['origination'] = 'origination'
250+
optional: bool = False
207251

208252
def __post_init_post_parse__(self):
209253
self._storage_type_cls = None
@@ -228,10 +272,16 @@ def storage_type_cls(self, typ: Type) -> None:
228272
self._storage_type_cls = typ
229273

230274
def get_handler_imports(self, package: str) -> str:
231-
return f'from {package}.types.{self.contract_config.module_name}.storage import {snake_to_camel(self.contract_config.module_name)}Storage'
275+
module_name = self.contract_config.module_name
276+
storage_cls = f'{snake_to_camel(module_name)}Storage'
277+
return f'from {package}.types.{module_name}.storage import {storage_cls}'
232278

233279
def get_handler_argument(self) -> str:
234-
return f'{self.contract_config.module_name}_origination: OriginationContext[{snake_to_camel(self.contract_config.module_name)}Storage],'
280+
module_name = self.contract_config.module_name
281+
storage_cls = f'{snake_to_camel(module_name)}Storage'
282+
if self.optional:
283+
return f'{module_name}_origination: Optional[OriginationContext[{storage_cls}]],'
284+
return f'{module_name}_origination: OriginationContext[{storage_cls}],'
235285

236286

237287
OperationHandlerPatternConfig = Union[OperationHandlerOriginationPatternConfig, OperationHandlerTransactionPatternConfig]
@@ -481,6 +531,7 @@ def __post_init_post_parse__(self):
481531
except KeyError as e:
482532
raise ConfigurationError(f'Contract `{contract}` not found in `contracts` config section') from e
483533

534+
transaction_id = 0
484535
for handler_config in index_config.handlers:
485536
callback_patterns[handler_config.callback].append(handler_config.pattern)
486537
for pattern_config in handler_config.pattern:
@@ -492,6 +543,17 @@ def __post_init_post_parse__(self):
492543
raise ConfigurationError(
493544
f'Contract `{pattern_config.destination}` not found in `contracts` config section'
494545
) from e
546+
if isinstance(pattern_config.source, str):
547+
try:
548+
pattern_config.source = self.contracts[pattern_config.source]
549+
except KeyError as e:
550+
raise ConfigurationError(
551+
f'Contract `{pattern_config.source}` not found in `contracts` config section'
552+
) from e
553+
if not pattern_config.entrypoint:
554+
pattern_config.transaction_id = transaction_id
555+
transaction_id += 1
556+
495557
elif isinstance(pattern_config, OperationHandlerOriginationPatternConfig):
496558
if isinstance(pattern_config.originated_contract, str):
497559
try:
@@ -522,7 +584,13 @@ def __post_init_post_parse__(self):
522584
if len(patterns) > 1:
523585

524586
def get_pattern_type(pattern: List[OperationHandlerPatternConfig]):
525-
return '::'.join(map(lambda x: x.contract_config.module_name, pattern))
587+
module_names = []
588+
for pattern_config in pattern:
589+
if isinstance(pattern_config, OperationHandlerTransactionPatternConfig) and pattern_config.entrypoint:
590+
module_names.append(pattern_config.destination_contract_config.module_name)
591+
elif isinstance(pattern_config, OperationHandlerOriginationPatternConfig):
592+
module_names.append(pattern_config.contract_config.module_name)
593+
return '::'.join(module_names)
526594

527595
pattern_types = list(map(get_pattern_type, patterns))
528596
if any(map(lambda x: x != pattern_types[0], pattern_types)):
@@ -616,11 +684,14 @@ async def initialize(self) -> None:
616684

617685
for operation_pattern_config in operation_handler_config.pattern:
618686
if isinstance(operation_pattern_config, OperationHandlerTransactionPatternConfig):
687+
if not operation_pattern_config.entrypoint:
688+
continue
689+
619690
_logger.info('Registering parameter type for entrypoint `%s`', operation_pattern_config.entrypoint)
620691
parameter_type_module = importlib.import_module(
621692
f'{self.package}'
622693
f'.types'
623-
f'.{operation_pattern_config.contract_config.module_name}'
694+
f'.{operation_pattern_config.destination_contract_config.module_name}'
624695
f'.parameter'
625696
f'.{camel_to_snake(operation_pattern_config.entrypoint)}'
626697
)
@@ -629,14 +700,25 @@ async def initialize(self) -> None:
629700
)
630701
operation_pattern_config.parameter_type_cls = parameter_type_cls
631702

632-
_logger.info('Registering storage type')
633-
storage_type_module = importlib.import_module(
634-
f'{self.package}.types.{operation_pattern_config.contract_config.module_name}.storage'
635-
)
636-
storage_type_cls = getattr(
637-
storage_type_module, snake_to_camel(operation_pattern_config.contract_config.module_name) + 'Storage'
638-
)
639-
operation_pattern_config.storage_type_cls = storage_type_cls
703+
_logger.info('Registering storage type')
704+
storage_type_module = importlib.import_module(
705+
f'{self.package}.types.{operation_pattern_config.destination_contract_config.module_name}.storage'
706+
)
707+
storage_type_cls = getattr(
708+
storage_type_module,
709+
snake_to_camel(operation_pattern_config.destination_contract_config.module_name) + 'Storage',
710+
)
711+
operation_pattern_config.storage_type_cls = storage_type_cls
712+
713+
elif isinstance(operation_handler_config, OperationHandlerOriginationPatternConfig):
714+
_logger.info('Registering storage type')
715+
storage_type_module = importlib.import_module(
716+
f'{self.package}.types.{operation_pattern_config.contract_config.module_name}.storage'
717+
)
718+
storage_type_cls = getattr(
719+
storage_type_module, snake_to_camel(operation_pattern_config.contract_config.module_name) + 'Storage'
720+
)
721+
operation_pattern_config.storage_type_cls = storage_type_cls
640722

641723
elif isinstance(index_config, BigMapIndexConfig):
642724
for big_map_handler_config in index_config.handlers:

0 commit comments

Comments
 (0)