|
9 | 9 | from genericpath import exists |
10 | 10 | from tortoise import Tortoise |
11 | 11 | from tortoise.exceptions import OperationalError |
12 | | -from tortoise.transactions import in_transaction |
| 12 | +from tortoise.transactions import get_connection |
13 | 13 | from tortoise.utils import get_schema_sql |
14 | 14 |
|
15 | 15 | import dipdup.utils as utils |
|
30 | 30 | from dipdup.datasources import DatasourceT |
31 | 31 | from dipdup.datasources.bcd.datasource import BcdDatasource |
32 | 32 | from dipdup.datasources.tzkt.datasource import TzktDatasource |
33 | | -from dipdup.exceptions import HandlerImportError |
| 33 | +from dipdup.exceptions import ConfigurationError, HandlerImportError |
34 | 34 | from dipdup.hasura import configure_hasura |
35 | 35 | from dipdup.index import BigMapIndex, HandlerContext, Index, OperationIndex |
36 | 36 | from dipdup.models import BigMapData, IndexType, OperationData, State |
@@ -264,27 +264,39 @@ async def _initialize_database(self, reindex: bool = False) -> None: |
264 | 264 |
|
265 | 265 | if schema_state is None: |
266 | 266 | await Tortoise.generate_schemas() |
| 267 | + await self._execute_sql_scripts(reindex=True) |
| 268 | + |
267 | 269 | schema_state = State(index_type=IndexType.schema, index_name=connection_name, hash=schema_hash) |
268 | 270 | await schema_state.save() |
269 | 271 | elif schema_state.hash != schema_hash: |
270 | 272 | self._logger.warning('Schema hash mismatch, reindexing') |
271 | 273 | await self._ctx.reindex() |
272 | 274 |
|
| 275 | + await self._execute_sql_scripts(reindex=False) |
| 276 | + |
| 277 | + async def _execute_sql_scripts(self, reindex: bool) -> None: |
| 278 | + """Execute SQL included with project""" |
273 | 279 | sql_path = join(self._config.package_path, 'sql') |
274 | 280 | if not exists(sql_path): |
275 | 281 | return |
| 282 | + if any(map(lambda p: p not in ('on_reindex', 'on_restart'), listdir(sql_path))): |
| 283 | + raise ConfigurationError( |
| 284 | + f'SQL scripts must be placed either to `{self._config.package}/sql/on_restart` or to `{self._config.package}/sql/on_reindex` directory' |
| 285 | + ) |
276 | 286 | if not isinstance(self._config.database, PostgresDatabaseConfig): |
277 | | - self._logger.warning('Injecting raw SQL supported on PostgreSQL only') |
| 287 | + self._logger.warning('Execution of user SQL scripts is supported on PostgreSQL only, skipping') |
278 | 288 | return |
279 | 289 |
|
280 | | - for filename in listdir(sql_path): |
| 290 | + sql_path = join(sql_path, 'on_reindex' if reindex else 'on_restart') |
| 291 | + if not exists(sql_path): |
| 292 | + return |
| 293 | + self._logger.info('Executing SQL scripts from `%s`', sql_path) |
| 294 | + for filename in sorted(listdir(sql_path)): |
281 | 295 | if not filename.endswith('.sql'): |
282 | 296 | continue |
283 | 297 |
|
284 | 298 | with open(join(sql_path, filename)) as file: |
285 | 299 | sql = file.read() |
286 | 300 |
|
287 | | - self._logger.info('Applying raw SQL from `%s`', filename) |
288 | | - |
289 | | - async with in_transaction() as conn: |
290 | | - await conn.execute_query(sql) |
| 301 | + self._logger.info('Executing `%s`', filename) |
| 302 | + await get_connection(None).execute_script(sql) |
0 commit comments