Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion nonebot_plugin_value/api/api_currency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..pyd_models.currency_pyd import CurrencyData
from ..services.currency import create_currency as _create_currency
from ..services.currency import get_currency as _g_currency
from ..services.currency import get_currency_by_kwargs as __currency_by_kwargs
from ..services.currency import get_default_currency as _default_currency
from ..services.currency import get_or_create_currency as _get_or_create_currency
from ..services.currency import list_currencies as _currencies
Expand Down Expand Up @@ -47,7 +48,6 @@ async def list_currencies() -> list[CurrencyData]:
for currency in currencies
]


async def get_currency(currency_id: str) -> CurrencyData | None:
"""获取一个货币信息

Expand All @@ -64,6 +64,22 @@ async def get_currency(currency_id: str) -> CurrencyData | None:
return CurrencyData.model_validate(currency, from_attributes=True)


async def get_currency_by_kwargs(**kwargs: object) -> CurrencyData | None:
"""获取一个货币信息

Args:
**kwargs (object): 通过货币属性联合查询获取货币信息

Returns:
CurrencyData | None: 货币数据,如果不存在则返回None
"""
async with get_session() as session:
currency = await __currency_by_kwargs(**kwargs, session=session)
if currency is None:
return None
return CurrencyData.model_validate(currency, from_attributes=True)


async def get_default_currency() -> CurrencyData:
"""获取默认货币的信息

Expand Down
1 change: 1 addition & 0 deletions nonebot_plugin_value/pyd_models/balance_pyd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ class UserAccountData(BaseData):
currency_id: str = Field(default="")
balance: float = Field(default=0.0)
last_updated: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
frozen:bool = Field(default=False)
206 changes: 206 additions & 0 deletions nonebot_plugin_value/repositories/account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from collections.abc import Sequence
from datetime import datetime, timezone

from nonebot_plugin_orm import AsyncSession
from sqlalchemy import delete, select

from ..exception import (
AccountFrozen,
AccountNotFound,
CurrencyNotFound,
TransactionException,
)
from ..models.balance import UserAccount
from ..models.currency import CurrencyMeta
from ..uuid_lib import get_uni_id


class AccountRepository:
"""账户操作"""

def __init__(self, session: AsyncSession):
self.session = session

async def get_or_create_account(
self, user_id: str, currency_id: str
) -> UserAccount:
async with self.session as session:
"""获取或创建用户账户"""
try:
# 获取货币配置
stmt = select(CurrencyMeta).where(CurrencyMeta.id == currency_id)
result = await session.execute(stmt)
currency = result.scalar_one_or_none()
if currency is None:
raise CurrencyNotFound(f"Currency {currency_id} not found")

# 检查账户是否存在
stmt = (
select(UserAccount)
.where(UserAccount.uni_id == get_uni_id(user_id, currency_id))
.with_for_update()
)
result = await session.execute(stmt)
account = result.scalar_one_or_none()

if account is not None:
session.add(account)
return account

session.add(currency)
account = UserAccount(
uni_id=get_uni_id(user_id, currency_id),
id=user_id,
currency_id=currency_id,
balance=currency.default_balance,
last_updated=datetime.now(timezone.utc),
)
session.add(account)
await session.commit()
await session.refresh(account)
return account
except Exception:
await session.rollback()
raise

async def set_account_frozen(
self,
account_id: str,
currency_id: str,
frozen: bool,
) -> None:
"""设置账户冻结状态"""
async with self.session as session:
try:
account = await self.get_or_create_account(account_id, currency_id)
session.add(account)
account.frozen = frozen
except Exception:
await session.rollback()
raise
else:
await session.commit()

async def set_frozen_all(self, account_id: str, frozen: bool):
async with self.session as session:
try:
result = await session.execute(
select(UserAccount).where(UserAccount.id == account_id)
)
accounts = result.scalars().all()
session.add_all(accounts)
for account in accounts:
account.frozen = frozen
except Exception as e:
await session.rollback()
raise e
else:
await session.commit()

async def is_account_frozen(
self,
account_id: str,
currency_id: str,
) -> bool:
"""判断账户是否冻结"""
async with self.session:
return (await self.get_or_create_account(account_id, currency_id)).frozen

async def get_balance(self, account_id: str, currency_id: str) -> float | None:
"""获取账户余额"""
uni_id = get_uni_id(account_id, currency_id)
account = await self.session.get(UserAccount, uni_id)
return account.balance if account else None

async def update_balance(
self, account_id: str, amount: float, currency_id: str
) -> tuple[float, float]:
async with self.session as session:
"""更新余额"""
try:
# 获取账户
account = (
await session.execute(
select(UserAccount)
.where(
UserAccount.uni_id == get_uni_id(account_id, currency_id)
)
.with_for_update()
)
).scalar_one_or_none()

if account is None:
raise AccountNotFound("Account not found")
session.add(account)

if account.frozen:
raise AccountFrozen(
f"Account {account_id} on currency {currency_id} is frozen"
)

# 获取货币规则
currency = await session.get(CurrencyMeta, account.currency_id)
session.add(currency)

# 负余额检查
if amount < 0 and not getattr(currency, "allow_negative", False):
raise TransactionException("Insufficient funds")

# 记录原始余额
old_balance = account.balance

# 更新余额
account.balance = amount
await session.commit()

return old_balance, amount
except Exception:
await session.rollback()
raise

async def list_accounts(
self, currency_id: str | None = None
) -> Sequence[UserAccount]:
"""列出所有账户"""
async with self.session as session:
if not currency_id:
result = await session.execute(select(UserAccount).with_for_update())
else:
result = await session.execute(
select(UserAccount)
.where(UserAccount.currency_id == currency_id)
.with_for_update()
)
data = result.scalars().all()
if len(data) > 0:
session.add_all(data)
return data

async def remove_account(self, account_id: str, currency_id: str | None = None):
"""删除账户"""
async with self.session as session:
try:
if not currency_id:
stmt = (
select(UserAccount)
.where(UserAccount.id == account_id)
.with_for_update()
)
else:
stmt = (
select(UserAccount)
.where(
UserAccount.uni_id == get_uni_id(account_id, currency_id)
)
.with_for_update()
)
accounts = (await session.execute(stmt)).scalars().all()
if not accounts:
raise AccountNotFound("Account not found")
for account in accounts:
stmt = delete(UserAccount).where(UserAccount.id == account.id)
await session.execute(stmt)
except Exception:
await session.rollback()
else:
await session.commit()
129 changes: 129 additions & 0 deletions nonebot_plugin_value/repositories/currency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Repository,更加底层的数据库操作接口
from collections.abc import Sequence
from functools import singledispatch

from nonebot import logger
from nonebot_plugin_orm import AsyncSession
from sqlalchemy import delete, select, update

from ..exception import (
CurrencyNotFound,
)
from ..models.currency import CurrencyMeta
from ..pyd_models.currency_pyd import CurrencyData


class CurrencyRepository:
"""货币元数据操作"""

def __init__(self, session: AsyncSession):
self.session = session

@singledispatch
async def get_currency(self, currency_id: str) -> CurrencyMeta | None:
"""获取货币信息"""
async with self.session as session:
result = await self.session.execute(
select(CurrencyMeta).where(CurrencyMeta.id == currency_id)
)
if currency_meta := result.scalar_one_or_none():
session.add(currency_meta)
return currency_meta
return None

async def get_currency_by_kwargs(self, **kwargs: object) -> CurrencyMeta | None:
"""获取货币信息"""
async with self.session as session:
result = await session.execute(
select(CurrencyMeta).where(
*(
getattr(CurrencyMeta, key) == value
for key, value in kwargs.items()
if hasattr(CurrencyMeta, key)
)
)
)
if currency_meta := result.scalar_one_or_none():
session.add(currency_meta)
return currency_meta
return None

async def get_or_create_currency(
self, currency_data: CurrencyData
) -> tuple[CurrencyMeta, bool]:
"""获取或创建货币"""
async with self.session as session:
stmt = await session.execute(
select(CurrencyMeta).where(
CurrencyMeta.id == currency_data.id,
)
)
if (currency := stmt.scalars().first()) is not None:
session.add(currency)
return currency, True
result = await self.createcurrency(currency_data)
return result, False

async def createcurrency(self, currency_data: CurrencyData) -> CurrencyMeta:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: 方法名“createcurrency”与 Python 命名约定不一致。

将“createcurrency”重命名为“create_currency”,以符合 Python 标准和代码库的约定。

Original comment in English

nitpick: Method name 'createcurrency' is inconsistent with Python naming conventions.

Rename 'createcurrency' to 'create_currency' for consistency with Python standards and the codebase.

async with self.session as session:
"""创建新货币"""
currency = CurrencyMeta(**currency_data.model_dump())
session.add(currency)
await session.commit()
await session.refresh(currency)
return currency
Comment on lines +64 to +74

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: 方法名 'createcurrency' 不符合标准 Python 命名约定。

重命名为 'create_currency' 以符合 PEP8 并保持一致性。

Suggested change
result = await self.createcurrency(currency_data)
return result, False
async def createcurrency(self, currency_data: CurrencyData) -> CurrencyMeta:
async with self.session as session:
"""创建新货币"""
currency = CurrencyMeta(**currency_data.model_dump())
session.add(currency)
await session.commit()
await session.refresh(currency)
return currency
result = await self.create_currency(currency_data)
return result, False
async def create_currency(self, currency_data: CurrencyData) -> CurrencyMeta:
async with self.session as session:
"""创建新货币"""
currency = CurrencyMeta(**currency_data.model_dump())
session.add(currency)
await session.commit()
await session.refresh(currency)
return currency
Original comment in English

suggestion: Method name 'createcurrency' does not follow standard Python naming conventions.

Rename to 'create_currency' to align with PEP8 and maintain consistency.

Suggested change
result = await self.createcurrency(currency_data)
return result, False
async def createcurrency(self, currency_data: CurrencyData) -> CurrencyMeta:
async with self.session as session:
"""创建新货币"""
currency = CurrencyMeta(**currency_data.model_dump())
session.add(currency)
await session.commit()
await session.refresh(currency)
return currency
result = await self.create_currency(currency_data)
return result, False
async def create_currency(self, currency_data: CurrencyData) -> CurrencyMeta:
async with self.session as session:
"""创建新货币"""
currency = CurrencyMeta(**currency_data.model_dump())
session.add(currency)
await session.commit()
await session.refresh(currency)
return currency


async def update_currency(self, currency_data: CurrencyData) -> CurrencyMeta:
"""更新货币信息"""
async with self.session as session:
try:
stmt = (
update(CurrencyMeta)
.where(CurrencyMeta.id == currency_data.id)
.values(**dict(currency_data))
)
await session.execute(stmt)
await session.commit()
stmt = (
select(CurrencyMeta)
.where(CurrencyMeta.id == currency_data.id)
.with_for_update()
)
result = await session.execute(stmt)
currency_meta = result.scalar_one()
session.add(currency_meta)
return currency_meta
except Exception:
await session.rollback()
raise


async def list_currencies(self) -> Sequence[CurrencyMeta]:
"""列出所有货币"""
async with self.session as session:
result = await self.session.execute(select(CurrencyMeta))
data = result.scalars().all()
session.add_all(data)
return data

async def remove_currency(self, currency_id: str):
"""删除货币(警告!会同时删除所有关联账户!)"""
async with self.session as session:
currency = (
await session.execute(
select(CurrencyMeta)
.where(CurrencyMeta.id == currency_id)
.with_for_update()
)
).scalar()
if currency is None:
raise CurrencyNotFound(f"Currency {currency_id} not found")
try:
logger.warning(f"Deleting currency {currency_id}")
stmt = delete(CurrencyMeta).where(CurrencyMeta.id == currency_id)
await session.execute(stmt)
except Exception:
await session.rollback()
raise
else:
await session.commit()
Loading
Loading