11from hashlib import sha256
22from base64 import b64decode
33from uuid import UUID , uuid4
4- from typing import Dict , Union , Optional , cast
4+ from typing import Dict , Union , Optional
55
66from anyio import open_file
77from httpx import AsyncClient
8- from sqlmodel import JSON , Field , Column , select
8+ from sqlalchemy import JSON , Uuid , select
9+ from sqlalchemy .orm import Mapped , mapped_column
910from nonebot .adapters .onebot .v12 .exception import DatabaseError
1011from nonebot_plugin_datastore import create_session , get_plugin_data
1112
@@ -17,15 +18,15 @@ def get_sha256(data: bytes) -> str:
1718 return sha256 (data ).hexdigest ()
1819
1920
20- class File (Model , table = True ):
21- id : UUID = Field ( default_factory = uuid4 , primary_key = True )
22- name : str
23- src : Optional [str ] = None
24- src_id : Optional [str ] = None
25- url : Optional [str ] = None
26- headers : Optional [Dict [str , str ]] = Field ( default = None , sa_column = Column ( JSON ) )
27- path : Optional [str ] = None
28- sha256 : Optional [str ] = None
21+ class File (Model ):
22+ id : Mapped [ UUID ] = mapped_column ( Uuid , primary_key = True , default = uuid4 )
23+ name : Mapped [ str ]
24+ src : Mapped [ Optional [str ]]
25+ src_id : Mapped [ Optional [str ]]
26+ url : Mapped [ Optional [str ]]
27+ headers : Mapped [ Optional [Dict [str , str ]]] = mapped_column ( JSON )
28+ path : Mapped [ Optional [str ]]
29+ sha256 : Mapped [ Optional [str ]]
2930
3031
3132DATA_PATH = plugin_data .data_dir
@@ -36,10 +37,9 @@ class File(Model, table=True):
3637async def get_file (file_id : str , src : Optional [str ] = None ) -> File :
3738 async with create_session () as session :
3839 file = (
39- await session .execute (select (File ).where (File .id == file_id ))
40+ await session .scalars (select (File ).where (File .id == UUID ( file_id ) ))
4041 ).one_or_none ()
4142 if file :
42- file = cast (File , file [0 ])
4343 if src is None :
4444 if file .sha256 :
4545 return file
@@ -49,13 +49,13 @@ async def get_file(file_id: str, src: Optional[str] = None) -> File:
4949 else :
5050 if file .sha256 :
5151 if file_ := (
52- await session .execute (
52+ await session .scalars (
5353 select (File ).where (
5454 File .sha256 == file .sha256 , File .src == src
5555 )
5656 )
5757 ).first ():
58- return file_ [ 0 ]
58+ return file_
5959 else :
6060 file .src = src
6161 file .src_id = None
@@ -77,26 +77,28 @@ async def upload_file(
7777 if src and src_id :
7878 async with create_session () as session :
7979 if file := (
80- await session .execute (
80+ await session .scalars (
8181 select (File ).where (File .src == src ).where (File .src_id == src_id )
8282 )
8383 ).first ():
84- return file [ 0 ] .id .hex
84+ return file .id .hex
8585 if sha256 :
86- async with create_session () as session , session . begin () :
86+ async with create_session () as session :
8787 if file := (
88- await session .execute (select (File ).where (File .sha256 == sha256 ))
88+ await session .scalars (select (File ).where (File .sha256 == sha256 ))
8989 ).first ():
9090 file = File (
9191 name = name ,
9292 src = src ,
9393 src_id = src_id ,
94- url = file [ 0 ] .url ,
95- headers = file [ 0 ] .headers ,
96- path = file [ 0 ] .path ,
94+ url = file .url ,
95+ headers = file .headers ,
96+ path = file .path ,
9797 sha256 = sha256 ,
9898 )
9999 session .add (file )
100+ await session .commit ()
101+ await session .refresh (file )
100102 return file .id .hex
101103
102104 if path :
@@ -124,6 +126,8 @@ async def upload_file(
124126 path = path ,
125127 sha256 = sha256 ,
126128 )
127- async with create_session () as session , session . begin () :
129+ async with create_session () as session :
128130 session .add (file )
131+ await session .commit ()
132+ await session .refresh (file )
129133 return file .id .hex
0 commit comments