Skip to content
Merged
18 changes: 17 additions & 1 deletion nonebot_plugin_value/api/api_currency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..pyd_models.currency_pyd import CurrencyData
from ..services.currency import create_currency as _create_currency
from ..services.currency import get_currency as _g_currency
from ..services.currency import get_currency_by_kwargs as __currency_by_kwargs
from ..services.currency import get_default_currency as _default_currency
from ..services.currency import get_or_create_currency as _get_or_create_currency
from ..services.currency import list_currencies as _currencies
Expand Down Expand Up @@ -47,7 +48,6 @@ async def list_currencies() -> list[CurrencyData]:
for currency in currencies
]


async def get_currency(currency_id: str) -> CurrencyData | None:
"""获取一个货币信息

Expand All @@ -64,6 +64,22 @@ async def get_currency(currency_id: str) -> CurrencyData | None:
return CurrencyData.model_validate(currency, from_attributes=True)


async def get_currency_by_kwargs(**kwargs: object) -> CurrencyData | None:
"""获取一个货币信息

Args:
**kwargs (object): 通过货币属性联合查询获取货币信息

Returns:
CurrencyData | None: 货币数据,如果不存在则返回None
"""
async with get_session() as session:
currency = await __currency_by_kwargs(**kwargs, session=session)
if currency is None:
return None
return CurrencyData.model_validate(currency, from_attributes=True)


async def get_default_currency() -> CurrencyData:
"""获取默认货币的信息

Expand Down
1 change: 1 addition & 0 deletions nonebot_plugin_value/pyd_models/balance_pyd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ class UserAccountData(BaseData):
currency_id: str = Field(default="")
balance: float = Field(default=0.0)
last_updated: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
frozen:bool = Field(default=False)
206 changes: 206 additions & 0 deletions nonebot_plugin_value/repositories/account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from collections.abc import Sequence
from datetime import datetime, timezone

from nonebot_plugin_orm import AsyncSession
from sqlalchemy import delete, select

from ..exception import (
AccountFrozen,
AccountNotFound,
CurrencyNotFound,
TransactionException,
)
from ..models.balance import UserAccount
from ..models.currency import CurrencyMeta
from ..uuid_lib import get_uni_id


class AccountRepository:
"""账户操作"""

def __init__(self, session: AsyncSession):
self.session = session

async def get_or_create_account(
self, user_id: str, currency_id: str
) -> UserAccount:
async with self.session as session:
"""获取或创建用户账户"""
try:
# 获取货币配置
stmt = select(CurrencyMeta).where(CurrencyMeta.id == currency_id)
result = await session.execute(stmt)
currency = result.scalar_one_or_none()
if currency is None:
raise CurrencyNotFound(f"Currency {currency_id} not found")

# 检查账户是否存在
stmt = (
select(UserAccount)
.where(UserAccount.uni_id == get_uni_id(user_id, currency_id))
.with_for_update()
)
result = await session.execute(stmt)
account = result.scalar_one_or_none()

if account is not None:
session.add(account)
return account

session.add(currency)
account = UserAccount(
uni_id=get_uni_id(user_id, currency_id),
id=user_id,
currency_id=currency_id,
balance=currency.default_balance,
last_updated=datetime.now(timezone.utc),
)
session.add(account)
await session.commit()
await session.refresh(account)
return account
except Exception:
await session.rollback()
raise

async def set_account_frozen(
self,
account_id: str,
currency_id: str,
frozen: bool,
) -> None:
"""设置账户冻结状态"""
async with self.session as session:
try:
account = await self.get_or_create_account(account_id, currency_id)
session.add(account)
account.frozen = frozen
except Exception:
await session.rollback()
raise
else:
await session.commit()

async def set_frozen_all(self, account_id: str, frozen: bool):
async with self.session as session:
try:
result = await session.execute(
select(UserAccount).where(UserAccount.id == account_id)
)
accounts = result.scalars().all()
session.add_all(accounts)
for account in accounts:
account.frozen = frozen
except Exception as e:
await session.rollback()
raise e
else:
await session.commit()

async def is_account_frozen(
self,
account_id: str,
currency_id: str,
) -> bool:
"""判断账户是否冻结"""
async with self.session:
return (await self.get_or_create_account(account_id, currency_id)).frozen

async def get_balance(self, account_id: str, currency_id: str) -> float | None:
"""获取账户余额"""
uni_id = get_uni_id(account_id, currency_id)
account = await self.session.get(UserAccount, uni_id)
return account.balance if account else None

async def update_balance(
self, account_id: str, amount: float, currency_id: str
) -> tuple[float, float]:
async with self.session as session:
"""更新余额"""
try:
# 获取账户
account = (
await session.execute(
select(UserAccount)
.where(
UserAccount.uni_id == get_uni_id(account_id, currency_id)
)
.with_for_update()
)
).scalar_one_or_none()

if account is None:
raise AccountNotFound("Account not found")
session.add(account)

if account.frozen:
raise AccountFrozen(
f"Account {account_id} on currency {currency_id} is frozen"
)

# 获取货币规则
currency = await session.get(CurrencyMeta, account.currency_id)
session.add(currency)

# 负余额检查
if amount < 0 and not getattr(currency, "allow_negative", False):
raise TransactionException("Insufficient funds")

# 记录原始余额
old_balance = account.balance

# 更新余额
account.balance = amount
await session.commit()

return old_balance, amount
except Exception:
await session.rollback()
raise

async def list_accounts(
self, currency_id: str | None = None
) -> Sequence[UserAccount]:
"""列出所有账户"""
async with self.session as session:
if not currency_id:
result = await session.execute(select(UserAccount).with_for_update())
else:
result = await session.execute(
select(UserAccount)
.where(UserAccount.currency_id == currency_id)
.with_for_update()
)
data = result.scalars().all()
if len(data) > 0:
session.add_all(data)
return data

async def remove_account(self, account_id: str, currency_id: str | None = None):
"""删除账户"""
async with self.session as session:
try:
if not currency_id:
stmt = (
select(UserAccount)
.where(UserAccount.id == account_id)
.with_for_update()
)
else:
stmt = (
select(UserAccount)
.where(
UserAccount.uni_id == get_uni_id(account_id, currency_id)
)
.with_for_update()
)
accounts = (await session.execute(stmt)).scalars().all()
if not accounts:
raise AccountNotFound("Account not found")
for account in accounts:
stmt = delete(UserAccount).where(UserAccount.id == account.id)
await session.execute(stmt)
except Exception:
await session.rollback()
else:
await session.commit()
131 changes: 131 additions & 0 deletions nonebot_plugin_value/repositories/currency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Repository,更加底层的数据库操作接口
from collections.abc import Sequence
from functools import singledispatch

from nonebot import logger
from nonebot_plugin_orm import AsyncSession
from sqlalchemy import delete, select, update

from ..exception import (
CurrencyNotFound,
)
from ..models.currency import CurrencyMeta
from ..pyd_models.currency_pyd import CurrencyData


class CurrencyRepository:
"""货币元数据操作"""

def __init__(self, session: AsyncSession):
self.session = session

@singledispatch
async def get_currency(self, currency_id: str) -> CurrencyMeta | None:
"""获取货币信息"""
async with self.session as session:
result = await self.session.execute(
select(CurrencyMeta).where(CurrencyMeta.id == currency_id)
)
currency_meta = result.scalar_one_or_none()
if currency_meta:
session.add(currency_meta)
return currency_meta
return None

async def get_currency_by_kwargs(self, **kwargs: object) -> CurrencyMeta | None:
"""获取货币信息"""
async with self.session as session:
result = await session.execute(
select(CurrencyMeta).where(
*(
getattr(CurrencyMeta, key) == value
for key, value in kwargs.items()
if hasattr(CurrencyMeta, key)
)
)
)
currency_meta = result.scalar_one_or_none()
if currency_meta:
session.add(currency_meta)
return currency_meta
return None

async def get_or_create_currency(
self, currency_data: CurrencyData
) -> tuple[CurrencyMeta, bool]:
"""获取或创建货币"""
async with self.session as session:
stmt = await session.execute(
select(CurrencyMeta).where(
CurrencyMeta.id == currency_data.id,
)
)
if (currency := stmt.scalars().first()) is not None:
session.add(currency)
return currency, True
result = await self.createcurrency(currency_data)
return result, False

async def createcurrency(self, currency_data: CurrencyData) -> CurrencyMeta:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: 方法名“createcurrency”与 Python 命名约定不一致。

将“createcurrency”重命名为“create_currency”,以符合 Python 标准和代码库的约定。

Original comment in English

nitpick: Method name 'createcurrency' is inconsistent with Python naming conventions.

Rename 'createcurrency' to 'create_currency' for consistency with Python standards and the codebase.

async with self.session as session:
"""创建新货币"""
currency = CurrencyMeta(**currency_data.model_dump())
session.add(currency)
await session.commit()
await session.refresh(currency)
return currency
Comment on lines +64 to +74

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: 方法名 'createcurrency' 不符合标准 Python 命名约定。

重命名为 'create_currency' 以符合 PEP8 并保持一致性。

Suggested change
result = await self.createcurrency(currency_data)
return result, False
async def createcurrency(self, currency_data: CurrencyData) -> CurrencyMeta:
async with self.session as session:
"""创建新货币"""
currency = CurrencyMeta(**currency_data.model_dump())
session.add(currency)
await session.commit()
await session.refresh(currency)
return currency
result = await self.create_currency(currency_data)
return result, False
async def create_currency(self, currency_data: CurrencyData) -> CurrencyMeta:
async with self.session as session:
"""创建新货币"""
currency = CurrencyMeta(**currency_data.model_dump())
session.add(currency)
await session.commit()
await session.refresh(currency)
return currency
Original comment in English

suggestion: Method name 'createcurrency' does not follow standard Python naming conventions.

Rename to 'create_currency' to align with PEP8 and maintain consistency.

Suggested change
result = await self.createcurrency(currency_data)
return result, False
async def createcurrency(self, currency_data: CurrencyData) -> CurrencyMeta:
async with self.session as session:
"""创建新货币"""
currency = CurrencyMeta(**currency_data.model_dump())
session.add(currency)
await session.commit()
await session.refresh(currency)
return currency
result = await self.create_currency(currency_data)
return result, False
async def create_currency(self, currency_data: CurrencyData) -> CurrencyMeta:
async with self.session as session:
"""创建新货币"""
currency = CurrencyMeta(**currency_data.model_dump())
session.add(currency)
await session.commit()
await session.refresh(currency)
return currency


async def update_currency(self, currency_data: CurrencyData) -> CurrencyMeta:
"""更新货币信息"""
async with self.session as session:
try:
stmt = (
update(CurrencyMeta)
.where(CurrencyMeta.id == currency_data.id)
.values(**dict(currency_data))
)
await session.execute(stmt)
await session.commit()
stmt = (
select(CurrencyMeta)
.where(CurrencyMeta.id == currency_data.id)
.with_for_update()
)
result = await session.execute(stmt)
currency_meta = result.scalar_one()
session.add(currency_meta)
return currency_meta
except Exception:
await session.rollback()
raise


async def list_currencies(self) -> Sequence[CurrencyMeta]:
"""列出所有货币"""
async with self.session as session:
result = await self.session.execute(select(CurrencyMeta))
data = result.scalars().all()
session.add_all(data)
return data

async def remove_currency(self, currency_id: str):
"""删除货币(警告!会同时删除所有关联账户!)"""
async with self.session as session:
currency = (
await session.execute(
select(CurrencyMeta)
.where(CurrencyMeta.id == currency_id)
.with_for_update()
)
).scalar()
if currency is None:
raise CurrencyNotFound(f"Currency {currency_id} not found")
try:
logger.warning(f"Deleting currency {currency_id}")
stmt = delete(CurrencyMeta).where(CurrencyMeta.id == currency_id)
await session.execute(stmt)
except Exception:
await session.rollback()
raise
else:
await session.commit()
Loading
Loading