diff --git "a/alembic/versions/4b9d22943860_\345\256\236\347\216\260\345\244\232\351\234\200\346\261\202\345\220\210\345\271\266.py" "b/alembic/versions/4b9d22943860_\345\256\236\347\216\260\345\244\232\351\234\200\346\261\202\345\220\210\345\271\266.py" new file mode 100644 index 0000000..6e464af --- /dev/null +++ "b/alembic/versions/4b9d22943860_\345\256\236\347\216\260\345\244\232\351\234\200\346\261\202\345\220\210\345\271\266.py" @@ -0,0 +1,40 @@ +"""实现多需求合并 + +Revision ID: 4b9d22943860 +Revises: 48b09347ef95 +Create Date: 2025-04-28 23:52:46.462144 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '4b9d22943860' +down_revision: Union[str, None] = '48b09347ef95' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('enter_application', + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('group_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('user_id', 'group_id') + ) + op.add_column('groups', sa.Column('description', sa.String(length=200), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('groups', 'description') + op.drop_table('enter_application') + # ### end Alembic commands ### diff --git a/app/api/v1/endpoints/articleDB.py b/app/api/v1/endpoints/articleDB.py index 79e1a4f..1db5986 100644 --- a/app/api/v1/endpoints/articleDB.py +++ b/app/api/v1/endpoints/articleDB.py @@ -2,8 +2,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.utils.get_db import get_db from app.schemas.articleDB import UploadArticle, GetArticle, DeLArticle, GetResponse -from app.curd.articleDB import create_article_in_db, get_article_in_db - +from app.curd.articleDB import create_article_in_db, get_article_in_db, get_article_in_db_by_id +from app.core.config import settings +import os +import uuid +from fastapi.responses import FileResponse +from urllib.parse import quote router = APIRouter() @router.put("/upload", response_model=dict) @@ -17,8 +21,19 @@ async def upload_article( """ Upload an article to the database. """ + # 将文件保存到指定目录 + if not os.path.exists(settings.UPLOAD_FOLDER): + os.makedirs(settings.UPLOAD_FOLDER) + + # 生成文件名,可以使用 UUID 或者其他方式来确保文件名唯一 + file_name = f"{uuid.uuid4()}.pdf" + file_path = os.path.join(settings.UPLOAD_FOLDER, file_name) try: - await create_article_in_db(db=db, upload_article=UploadArticle(title=title, author=author, url=url)) + with open(file_path, "wb") as f: + while chunk := await file.read(1024): # 每次读取 1024 字节 + f.write(chunk) + + await create_article_in_db(db=db, upload_article=UploadArticle(title=title, author=author, url=url, file_path=file_path)) return {"msg": "Article uploaded successfully"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -36,4 +51,30 @@ async def get_article(get_article: GetArticle = Depends(), db: AsyncSession = De "total_count": total_count }, "articles": [articles.model_dump() for articles in articles] - } \ No newline at end of file + } + +@router.get("/download/{article_id}", response_model=dict) +async def download_article(article_id: int, db: AsyncSession = Depends(get_db)): + """ + Download an article file by its ID. + """ + # 根据 ID 查询文章信息 + article = await get_article_in_db_by_id(db=db, article_id=article_id) + if not article or not article.file_path: + raise HTTPException(status_code=404, detail="File not found") + + if not os.path.exists(article.file_path): + raise HTTPException(status_code=404, detail="File not found on server") + + # 从文件路径获取文件名 + file_name = os.path.basename(article.file_path) + + # 设置原始文件名,如果有标题,使用标题作为文件名 + download_filename = f"{article.title}.pdf" if article.title else file_name + + # 返回文件,并设置文件名(使用 quote 处理中文文件名) + return FileResponse( + path=article.file_path, + filename=quote(download_filename), + media_type="application/pdf" + ) \ No newline at end of file diff --git a/app/api/v1/endpoints/auth.py b/app/api/v1/endpoints/auth.py index f393a42..d154c5f 100644 --- a/app/api/v1/endpoints/auth.py +++ b/app/api/v1/endpoints/auth.py @@ -16,6 +16,7 @@ from app.curd.article import crud_self_create_folder, crud_article_statistic from app.utils.get_db import get_db from app.utils.redis import get_redis_client +from app.curd.note import find_recent_notes_in_db router = APIRouter() @@ -151,4 +152,11 @@ async def send_code(user_send_code: UserSendCode): @router.get("/articleStatistic", response_model="dict") async def article_statistic(db: AsyncSession = Depends(get_db)): articles = await crud_article_statistic(db) - return {"articles": articles} \ No newline at end of file + return {"articles": articles} + +@router.get("/recent", response_model=dict) +async def get_recent_notes(db: AsyncSession = Depends(get_db)): + notes = await find_recent_notes_in_db(db) + return { + "notes": notes + } \ No newline at end of file diff --git a/app/api/v1/endpoints/note.py b/app/api/v1/endpoints/note.py index fa4e3c7..1f9f984 100644 --- a/app/api/v1/endpoints/note.py +++ b/app/api/v1/endpoints/note.py @@ -2,7 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.schemas.note import NoteCreate, NoteUpdate, NoteFind from app.utils.get_db import get_db -from app.curd.note import create_note_in_db, delete_note_in_db, update_note_in_db, find_notes_in_db, find_notes_title_in_db, find_recent_notes_in_db +from app.curd.note import create_note_in_db, delete_note_in_db, update_note_in_db, find_notes_in_db, find_notes_title_in_db from typing import Optional router = APIRouter() @@ -52,11 +52,3 @@ async def get_notes_title(note_find: NoteFind = Depends(), db: AsyncSession = De }, "notes": notes } - - -@router.get("/recent", response_model=dict) -async def get_recent_notes(db: AsyncSession = Depends(get_db)): - notes = await find_recent_notes_in_db(db) - return { - "notes": notes - } \ No newline at end of file diff --git a/app/core/config.py b/app/core/config.py index b881797..1fb9f8f 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -16,6 +16,7 @@ class Settings: SENDER_EMAIL : str = "jienote_buaa@163.com" SENDER_PASSWORD: str = os.getenv("SENDER_PASSWORD", "default_password") # 发件人邮箱密码 KIMI_API_KEY: str = os.getenv("KIMI_API_KEY", "default_kimi_api_key") # KIMI API密钥 + UPLOAD_FOLDER: str = "/lhcos-data/acticleDB" settings = Settings() \ No newline at end of file diff --git a/app/curd/articleDB.py b/app/curd/articleDB.py index 77424b3..9080ec0 100644 --- a/app/curd/articleDB.py +++ b/app/curd/articleDB.py @@ -8,7 +8,7 @@ async def create_article_in_db(db: AsyncSession, upload_article: UploadArticle): """ Create a new article in the database. """ - article =ArticleDB(title=upload_article.title, url=upload_article.url, author=upload_article.author) + article =ArticleDB(title=upload_article.title, url=upload_article.url, author=upload_article.author, file_path=upload_article.file_path) db.add(article) await db.commit() await db.refresh(article) @@ -37,3 +37,10 @@ async def get_article_in_db(db: AsyncSession, get_article: GetArticle): return [GetResponse.model_validate(article) for article in articles], total_count +async def get_article_in_db_by_id(db: AsyncSession, article_id: int): + """ + Get an article by its ID. + """ + result = await db.execute(select(ArticleDB).where(ArticleDB.id == article_id)) + article = result.scalars().first() + return article \ No newline at end of file diff --git a/app/main.py b/app/main.py index c9aac38..840f0ba 100644 --- a/app/main.py +++ b/app/main.py @@ -2,6 +2,7 @@ from app.routers.router import include_routers from fastapi_pagination import add_pagination from loguru import logger +from fastapi.middleware.cors import CORSMiddleware app = FastAPI() @@ -27,4 +28,13 @@ async def log_requests(request: Request, call_next): logger.info(f"Request: {request.method} {request.url}") response = await call_next(request) logger.info(f"Response status: {response.status_code}") - return response \ No newline at end of file + return response + +# 配置 CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 允许的前端来源 + allow_credentials=True, # 允许发送凭据(如 Cookies 或 Authorization 头) + allow_methods=["*"], # 允许的 HTTP 方法 + allow_headers=["*"], # 允许的请求头 +) \ No newline at end of file diff --git a/app/models/model.py b/app/models/model.py index be9e0d2..6a5a7c2 100644 --- a/app/models/model.py +++ b/app/models/model.py @@ -11,6 +11,12 @@ Column('is_admin', Boolean, default=False) ) +enter_application = Table( + 'enter_application', Base.metadata, + Column('user_id', Integer, ForeignKey('users.id'), primary_key=True), + Column('group_id', Integer, ForeignKey('groups.id'), primary_key=True), +) + class User(Base): __tablename__ = 'users' @@ -32,6 +38,7 @@ class Group(Base): id = Column(Integer, primary_key=True, index=True, autoincrement=True) leader = Column(Integer) name = Column(String(30), nullable=False) + description = Column(String(200), nullable=False) create_time = Column(DateTime, default=func.now(), nullable=False) # 创建时间 update_time = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) # 更新时间 users = relationship('User', secondary=user_group, back_populates='groups') @@ -109,6 +116,7 @@ class ArticleDB(Base): title = Column(String(200), nullable=False) url = Column(String(200), nullable=False) author = Column(String(100), nullable=False) + file_path = Column(String(200), nullable=False) create_time = Column(DateTime, default=func.now(), nullable=False) # 创建时间 update_time = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) # 更新时间 \ No newline at end of file diff --git a/app/schemas/articleDB.py b/app/schemas/articleDB.py index d0d6bb0..688188e 100644 --- a/app/schemas/articleDB.py +++ b/app/schemas/articleDB.py @@ -5,6 +5,7 @@ class UploadArticle(BaseModel): title: str author: str url: str + file_path: str class GetArticle(BaseModel): id: int | None = None @@ -20,6 +21,7 @@ class GetResponse(BaseModel): url: str create_time: datetime update_time: datetime + file_path: str class Config: from_attributes = True \ No newline at end of file