99from .models .balance import Transaction , UserAccount
1010from .models .currency import CurrencyMeta
1111from .pyd_models .currency_pyd import CurrencyData
12- from .uuid_lib import NAMESPACE_VALUE
12+ from .uuid_lib import NAMESPACE_VALUE , get_uni_id
1313
1414DEFAULT_NAME = "DEFAULT_CURRENCY_USD"
1515DEFAULT_CURRENCY_UUID = uuid5 (NAMESPACE_VALUE , DEFAULT_NAME )
@@ -53,7 +53,7 @@ async def update_currency(self, currency_data: CurrencyData) -> CurrencyMeta:
5353 session .add (currency_meta )
5454 return currency_meta
5555
56- async def getcurrency (self , currency_id : str ) -> CurrencyMeta | None :
56+ async def get_currency (self , currency_id : str ) -> CurrencyMeta | None :
5757 """获取货币信息"""
5858 async with self .session as session :
5959 result = await self .session .execute (
@@ -117,10 +117,7 @@ async def get_or_create_account(
117117 # 检查账户是否存在
118118 stmt = (
119119 select (UserAccount )
120- .where (
121- UserAccount .id == user_id ,
122- UserAccount .currency_id == currency_id ,
123- )
120+ .where (UserAccount .uni_id == get_uni_id (user_id , currency_id ))
124121 .with_for_update ()
125122 )
126123 result = await session .execute (stmt )
@@ -132,7 +129,7 @@ async def get_or_create_account(
132129
133130 session .add (currency )
134131 account = UserAccount (
135- uni_id = uuid5 ( NAMESPACE_VALUE , f" { user_id } { currency_id } " ). hex ,
132+ uni_id = get_uni_id ( user_id , currency_id ) ,
136133 id = user_id ,
137134 currency_id = currency_id ,
138135 balance = currency .default_balance ,
@@ -142,17 +139,17 @@ async def get_or_create_account(
142139 await session .commit ()
143140
144141 stmt = select (UserAccount ).where (
145- UserAccount .id == user_id ,
146- UserAccount .currency_id == currency_id ,
142+ UserAccount .uni_id == get_uni_id (user_id , currency_id )
147143 )
148144 result = await session .execute (stmt )
149145 account = result .scalar_one ()
150146 session .add (account )
151147 return account
152148
153- async def get_balance (self , account_id : str ) -> float | None :
149+ async def get_balance (self , account_id : str , currency_id : str ) -> float | None :
154150 """获取账户余额"""
155- account = await self .session .get (UserAccount , account_id )
151+ uni_id = get_uni_id (account_id , currency_id )
152+ account = await self .session .get (UserAccount , uni_id )
156153 return account .balance if account else None
157154
158155 async def update_balance (
@@ -165,10 +162,7 @@ async def update_balance(
165162 account = (
166163 await session .execute (
167164 select (UserAccount )
168- .where (
169- UserAccount .id == account_id ,
170- UserAccount .currency_id == currency_id ,
171- )
165+ .where (UserAccount .uni_id == get_uni_id (account_id , currency_id ))
172166 .with_for_update ()
173167 )
174168 ).scalar_one_or_none ()
@@ -212,14 +206,21 @@ async def list_accounts(
212206 session .add_all (data )
213207 return data
214208
215- async def remove_account (self , account_id : str ):
209+ async def remove_account (self , account_id : str , currency_id : str | None = None ):
216210 """删除账户"""
217211 async with self .session as session :
218- stmt = (
219- select (UserAccount )
220- .where (UserAccount .id == account_id )
221- .with_for_update ()
222- )
212+ if not currency_id :
213+ stmt = (
214+ select (UserAccount )
215+ .where (UserAccount .id == account_id )
216+ .with_for_update ()
217+ )
218+ else :
219+ stmt = (
220+ select (UserAccount )
221+ .where (UserAccount .uni_id == get_uni_id (account_id , currency_id ))
222+ .with_for_update ()
223+ )
223224 accounts = (await session .execute (stmt )).scalars ().all ()
224225 if not accounts :
225226 raise ValueError ("Account not found" )
@@ -250,7 +251,7 @@ async def create_transaction(
250251 if timestamp is None :
251252 timestamp = datetime .now (timezone .utc )
252253 uuid = uuid1 ().hex
253- stmt = insert ( Transaction ). values (
254+ transaction_data = Transaction (
254255 id = uuid ,
255256 account_id = account_id ,
256257 currency_id = currency_id ,
@@ -261,18 +262,18 @@ async def create_transaction(
261262 balance_after = balance_after ,
262263 timestamp = timestamp ,
263264 )
264- await session .execute ( stmt )
265+ session .add ( transaction_data )
265266 await session .commit ()
266267 stmt = (
267268 select (Transaction )
268269 .where (
269270 Transaction .id == uuid ,
270- Transaction .timestamp == timestamp ,
271271 )
272272 .with_for_update ()
273273 )
274274 result = await session .execute (stmt )
275- transaction = result .scalars ().one ()
275+ transaction = result .scalar_one_or_none ()
276+ assert transaction , f"无法读取到交易记录[... WHERE id = { uuid } ...]!"
276277 session .add (transaction )
277278 return transaction
278279
0 commit comments