|
5 | 5 |
|
6 | 6 | from nonebot import logger |
7 | 7 | from nonebot_plugin_orm import AsyncSession |
8 | | -from sqlalchemy import delete, insert, select, update |
| 8 | +from sqlalchemy import delete, select, update |
9 | 9 |
|
10 | 10 | from .exception import ( |
11 | 11 | AccountFrozen, |
@@ -40,18 +40,19 @@ async def get_or_create_currency( |
40 | 40 | ) |
41 | 41 | ) |
42 | 42 | if (currency := stmt.scalars().first()) is not None: |
| 43 | + session.add(currency) |
43 | 44 | return currency, True |
44 | | - await self.createcurrency(currency_data) |
45 | | - result = await self.get_currency(currency_data.id) |
46 | | - assert result is not None |
| 45 | + result = await self.createcurrency(currency_data) |
47 | 46 | return result, False |
48 | 47 |
|
49 | | - async def createcurrency(self, currency_data: CurrencyData): |
| 48 | + async def createcurrency(self, currency_data: CurrencyData) -> CurrencyMeta: |
50 | 49 | async with self.session as session: |
51 | 50 | """创建新货币""" |
52 | | - stmt = insert(CurrencyMeta).values(**dict(currency_data)) |
53 | | - await session.execute(stmt) |
| 51 | + currency = CurrencyMeta(**currency_data.model_dump()) |
| 52 | + session.add(currency) |
54 | 53 | await session.commit() |
| 54 | + await session.refresh(currency) |
| 55 | + return currency |
55 | 56 |
|
56 | 57 | async def update_currency(self, currency_data: CurrencyData) -> CurrencyMeta: |
57 | 58 | """更新货币信息""" |
@@ -162,6 +163,7 @@ async def get_or_create_account( |
162 | 163 | ) |
163 | 164 | session.add(account) |
164 | 165 | await session.commit() |
| 166 | + await session.refresh(account) |
165 | 167 | return account |
166 | 168 | except Exception: |
167 | 169 | await session.rollback() |
@@ -345,18 +347,9 @@ async def create_transaction( |
345 | 347 | ) |
346 | 348 | session.add(transaction_data) |
347 | 349 | await session.commit() |
348 | | - stmt = ( |
349 | | - select(Transaction) |
350 | | - .where( |
351 | | - Transaction.id == uuid, |
352 | | - ) |
353 | | - .with_for_update() |
354 | | - ) |
355 | | - result = await session.execute(stmt) |
356 | | - transaction = result.scalar_one_or_none() |
357 | | - assert transaction, f"无法读取到交易记录[... WHERE id = {uuid} ...]!" |
358 | | - session.add(transaction) |
359 | | - return transaction |
| 350 | + await session.refresh(transaction_data) |
| 351 | + session.add(transaction_data) |
| 352 | + return transaction_data |
360 | 353 |
|
361 | 354 | async def get_transaction_history( |
362 | 355 | self, account_id: str, limit: int = 100 |
|
0 commit comments