Skip to content

Commit b5fd8d0

Browse files
Refactor repository, add currency query by kwargs, account freeze (#80)
* 添加frozen,更新API * 更改逻辑 * 移除不必要import * 更新逻辑 * Update nonebot_plugin_value/repositories/currency.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Update nonebot_plugin_value/services/currency.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Update nonebot_plugin_value/repositories/currency.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * 更新版本 --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
1 parent f76de47 commit b5fd8d0

File tree

8 files changed

+488
-409
lines changed

8 files changed

+488
-409
lines changed

nonebot_plugin_value/api/api_currency.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ..pyd_models.currency_pyd import CurrencyData
44
from ..services.currency import create_currency as _create_currency
55
from ..services.currency import get_currency as _g_currency
6+
from ..services.currency import get_currency_by_kwargs as __currency_by_kwargs
67
from ..services.currency import get_default_currency as _default_currency
78
from ..services.currency import get_or_create_currency as _get_or_create_currency
89
from ..services.currency import list_currencies as _currencies
@@ -47,7 +48,6 @@ async def list_currencies() -> list[CurrencyData]:
4748
for currency in currencies
4849
]
4950

50-
5151
async def get_currency(currency_id: str) -> CurrencyData | None:
5252
"""获取一个货币信息
5353
@@ -64,6 +64,22 @@ async def get_currency(currency_id: str) -> CurrencyData | None:
6464
return CurrencyData.model_validate(currency, from_attributes=True)
6565

6666

67+
async def get_currency_by_kwargs(**kwargs: object) -> CurrencyData | None:
68+
"""获取一个货币信息
69+
70+
Args:
71+
**kwargs (object): 通过货币属性联合查询获取货币信息
72+
73+
Returns:
74+
CurrencyData | None: 货币数据,如果不存在则返回None
75+
"""
76+
async with get_session() as session:
77+
currency = await __currency_by_kwargs(**kwargs, session=session)
78+
if currency is None:
79+
return None
80+
return CurrencyData.model_validate(currency, from_attributes=True)
81+
82+
6783
async def get_default_currency() -> CurrencyData:
6884
"""获取默认货币的信息
6985

nonebot_plugin_value/pyd_models/balance_pyd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ class UserAccountData(BaseData):
1111
currency_id: str = Field(default="")
1212
balance: float = Field(default=0.0)
1313
last_updated: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
14+
frozen:bool = Field(default=False)
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
from collections.abc import Sequence
2+
from datetime import datetime, timezone
3+
4+
from nonebot_plugin_orm import AsyncSession
5+
from sqlalchemy import delete, select
6+
7+
from ..exception import (
8+
AccountFrozen,
9+
AccountNotFound,
10+
CurrencyNotFound,
11+
TransactionException,
12+
)
13+
from ..models.balance import UserAccount
14+
from ..models.currency import CurrencyMeta
15+
from ..uuid_lib import get_uni_id
16+
17+
18+
class AccountRepository:
19+
"""账户操作"""
20+
21+
def __init__(self, session: AsyncSession):
22+
self.session = session
23+
24+
async def get_or_create_account(
25+
self, user_id: str, currency_id: str
26+
) -> UserAccount:
27+
async with self.session as session:
28+
"""获取或创建用户账户"""
29+
try:
30+
# 获取货币配置
31+
stmt = select(CurrencyMeta).where(CurrencyMeta.id == currency_id)
32+
result = await session.execute(stmt)
33+
currency = result.scalar_one_or_none()
34+
if currency is None:
35+
raise CurrencyNotFound(f"Currency {currency_id} not found")
36+
37+
# 检查账户是否存在
38+
stmt = (
39+
select(UserAccount)
40+
.where(UserAccount.uni_id == get_uni_id(user_id, currency_id))
41+
.with_for_update()
42+
)
43+
result = await session.execute(stmt)
44+
account = result.scalar_one_or_none()
45+
46+
if account is not None:
47+
session.add(account)
48+
return account
49+
50+
session.add(currency)
51+
account = UserAccount(
52+
uni_id=get_uni_id(user_id, currency_id),
53+
id=user_id,
54+
currency_id=currency_id,
55+
balance=currency.default_balance,
56+
last_updated=datetime.now(timezone.utc),
57+
)
58+
session.add(account)
59+
await session.commit()
60+
await session.refresh(account)
61+
return account
62+
except Exception:
63+
await session.rollback()
64+
raise
65+
66+
async def set_account_frozen(
67+
self,
68+
account_id: str,
69+
currency_id: str,
70+
frozen: bool,
71+
) -> None:
72+
"""设置账户冻结状态"""
73+
async with self.session as session:
74+
try:
75+
account = await self.get_or_create_account(account_id, currency_id)
76+
session.add(account)
77+
account.frozen = frozen
78+
except Exception:
79+
await session.rollback()
80+
raise
81+
else:
82+
await session.commit()
83+
84+
async def set_frozen_all(self, account_id: str, frozen: bool):
85+
async with self.session as session:
86+
try:
87+
result = await session.execute(
88+
select(UserAccount).where(UserAccount.id == account_id)
89+
)
90+
accounts = result.scalars().all()
91+
session.add_all(accounts)
92+
for account in accounts:
93+
account.frozen = frozen
94+
except Exception as e:
95+
await session.rollback()
96+
raise e
97+
else:
98+
await session.commit()
99+
100+
async def is_account_frozen(
101+
self,
102+
account_id: str,
103+
currency_id: str,
104+
) -> bool:
105+
"""判断账户是否冻结"""
106+
async with self.session:
107+
return (await self.get_or_create_account(account_id, currency_id)).frozen
108+
109+
async def get_balance(self, account_id: str, currency_id: str) -> float | None:
110+
"""获取账户余额"""
111+
uni_id = get_uni_id(account_id, currency_id)
112+
account = await self.session.get(UserAccount, uni_id)
113+
return account.balance if account else None
114+
115+
async def update_balance(
116+
self, account_id: str, amount: float, currency_id: str
117+
) -> tuple[float, float]:
118+
async with self.session as session:
119+
"""更新余额"""
120+
try:
121+
# 获取账户
122+
account = (
123+
await session.execute(
124+
select(UserAccount)
125+
.where(
126+
UserAccount.uni_id == get_uni_id(account_id, currency_id)
127+
)
128+
.with_for_update()
129+
)
130+
).scalar_one_or_none()
131+
132+
if account is None:
133+
raise AccountNotFound("Account not found")
134+
session.add(account)
135+
136+
if account.frozen:
137+
raise AccountFrozen(
138+
f"Account {account_id} on currency {currency_id} is frozen"
139+
)
140+
141+
# 获取货币规则
142+
currency = await session.get(CurrencyMeta, account.currency_id)
143+
session.add(currency)
144+
145+
# 负余额检查
146+
if amount < 0 and not getattr(currency, "allow_negative", False):
147+
raise TransactionException("Insufficient funds")
148+
149+
# 记录原始余额
150+
old_balance = account.balance
151+
152+
# 更新余额
153+
account.balance = amount
154+
await session.commit()
155+
156+
return old_balance, amount
157+
except Exception:
158+
await session.rollback()
159+
raise
160+
161+
async def list_accounts(
162+
self, currency_id: str | None = None
163+
) -> Sequence[UserAccount]:
164+
"""列出所有账户"""
165+
async with self.session as session:
166+
if not currency_id:
167+
result = await session.execute(select(UserAccount).with_for_update())
168+
else:
169+
result = await session.execute(
170+
select(UserAccount)
171+
.where(UserAccount.currency_id == currency_id)
172+
.with_for_update()
173+
)
174+
data = result.scalars().all()
175+
if len(data) > 0:
176+
session.add_all(data)
177+
return data
178+
179+
async def remove_account(self, account_id: str, currency_id: str | None = None):
180+
"""删除账户"""
181+
async with self.session as session:
182+
try:
183+
if not currency_id:
184+
stmt = (
185+
select(UserAccount)
186+
.where(UserAccount.id == account_id)
187+
.with_for_update()
188+
)
189+
else:
190+
stmt = (
191+
select(UserAccount)
192+
.where(
193+
UserAccount.uni_id == get_uni_id(account_id, currency_id)
194+
)
195+
.with_for_update()
196+
)
197+
accounts = (await session.execute(stmt)).scalars().all()
198+
if not accounts:
199+
raise AccountNotFound("Account not found")
200+
for account in accounts:
201+
stmt = delete(UserAccount).where(UserAccount.id == account.id)
202+
await session.execute(stmt)
203+
except Exception:
204+
await session.rollback()
205+
else:
206+
await session.commit()
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Repository,更加底层的数据库操作接口
2+
from collections.abc import Sequence
3+
from functools import singledispatch
4+
5+
from nonebot import logger
6+
from nonebot_plugin_orm import AsyncSession
7+
from sqlalchemy import delete, select, update
8+
9+
from ..exception import (
10+
CurrencyNotFound,
11+
)
12+
from ..models.currency import CurrencyMeta
13+
from ..pyd_models.currency_pyd import CurrencyData
14+
15+
16+
class CurrencyRepository:
17+
"""货币元数据操作"""
18+
19+
def __init__(self, session: AsyncSession):
20+
self.session = session
21+
22+
@singledispatch
23+
async def get_currency(self, currency_id: str) -> CurrencyMeta | None:
24+
"""获取货币信息"""
25+
async with self.session as session:
26+
result = await self.session.execute(
27+
select(CurrencyMeta).where(CurrencyMeta.id == currency_id)
28+
)
29+
if currency_meta := result.scalar_one_or_none():
30+
session.add(currency_meta)
31+
return currency_meta
32+
return None
33+
34+
async def get_currency_by_kwargs(self, **kwargs: object) -> CurrencyMeta | None:
35+
"""获取货币信息"""
36+
async with self.session as session:
37+
result = await session.execute(
38+
select(CurrencyMeta).where(
39+
*(
40+
getattr(CurrencyMeta, key) == value
41+
for key, value in kwargs.items()
42+
if hasattr(CurrencyMeta, key)
43+
)
44+
)
45+
)
46+
if currency_meta := result.scalar_one_or_none():
47+
session.add(currency_meta)
48+
return currency_meta
49+
return None
50+
51+
async def get_or_create_currency(
52+
self, currency_data: CurrencyData
53+
) -> tuple[CurrencyMeta, bool]:
54+
"""获取或创建货币"""
55+
async with self.session as session:
56+
stmt = await session.execute(
57+
select(CurrencyMeta).where(
58+
CurrencyMeta.id == currency_data.id,
59+
)
60+
)
61+
if (currency := stmt.scalars().first()) is not None:
62+
session.add(currency)
63+
return currency, True
64+
result = await self.createcurrency(currency_data)
65+
return result, False
66+
67+
async def createcurrency(self, currency_data: CurrencyData) -> CurrencyMeta:
68+
async with self.session as session:
69+
"""创建新货币"""
70+
currency = CurrencyMeta(**currency_data.model_dump())
71+
session.add(currency)
72+
await session.commit()
73+
await session.refresh(currency)
74+
return currency
75+
76+
async def update_currency(self, currency_data: CurrencyData) -> CurrencyMeta:
77+
"""更新货币信息"""
78+
async with self.session as session:
79+
try:
80+
stmt = (
81+
update(CurrencyMeta)
82+
.where(CurrencyMeta.id == currency_data.id)
83+
.values(**dict(currency_data))
84+
)
85+
await session.execute(stmt)
86+
await session.commit()
87+
stmt = (
88+
select(CurrencyMeta)
89+
.where(CurrencyMeta.id == currency_data.id)
90+
.with_for_update()
91+
)
92+
result = await session.execute(stmt)
93+
currency_meta = result.scalar_one()
94+
session.add(currency_meta)
95+
return currency_meta
96+
except Exception:
97+
await session.rollback()
98+
raise
99+
100+
101+
async def list_currencies(self) -> Sequence[CurrencyMeta]:
102+
"""列出所有货币"""
103+
async with self.session as session:
104+
result = await self.session.execute(select(CurrencyMeta))
105+
data = result.scalars().all()
106+
session.add_all(data)
107+
return data
108+
109+
async def remove_currency(self, currency_id: str):
110+
"""删除货币(警告!会同时删除所有关联账户!)"""
111+
async with self.session as session:
112+
currency = (
113+
await session.execute(
114+
select(CurrencyMeta)
115+
.where(CurrencyMeta.id == currency_id)
116+
.with_for_update()
117+
)
118+
).scalar()
119+
if currency is None:
120+
raise CurrencyNotFound(f"Currency {currency_id} not found")
121+
try:
122+
logger.warning(f"Deleting currency {currency_id}")
123+
stmt = delete(CurrencyMeta).where(CurrencyMeta.id == currency_id)
124+
await session.execute(stmt)
125+
except Exception:
126+
await session.rollback()
127+
raise
128+
else:
129+
await session.commit()

0 commit comments

Comments
 (0)