diff --git a/.env b/.env new file mode 100644 index 0000000..453d6a6 --- /dev/null +++ b/.env @@ -0,0 +1,3 @@ +SECRET_KEY=bN3hZ6LbHG7nH9YXWULCr-crcS3GAaRELbNBdAyHBuiHH5TRctd0Zbd6OuLRHHa4Fbs +SENDER_PASSWORD=TXVU2unpCAE2EtEX +KIMI_API_KEY=sk-icdiHIiv6x8XjJCaN6J6Un7uoVxm6df5WPhflq10ZVFo03D9 \ No newline at end of file diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml new file mode 100644 index 0000000..5fb8768 --- /dev/null +++ b/.github/workflows/check.yml @@ -0,0 +1,56 @@ +name: Check FastAPI Backend + +on: + pull_request: + branches: + - dev + +jobs: + test: + name: Run Tests and Check FastAPI + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:13 + env: + POSTGRES_USER: test_user + POSTGRES_PASSWORD: test_password + POSTGRES_DB: test_db + ports: + - 5432:5432 + options: >- + --health-cmd="pg_isready -U test_user" + --health-interval=10s + --health-timeout=5s + --health-retries=5 + + redis: + image: redis:7 + ports: + - 6379:6379 + options: >- + --health-cmd="redis-cli ping" + --health-interval=10s + --health-timeout=5s + --health-retries=5 + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.12' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Check FastAPI Server + run: | + uvicorn app.main:app --host 0.0.0.0 --port 8000 --log-level warning & + sleep 5 + curl -f http://localhost:8000/docs \ No newline at end of file diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 0000000..daaf716 --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,43 @@ +name: Deploy FastAPI Backend + +on: + push: + branches: + - main + - dev + +jobs: + deploy: + name: Deploy to Server + runs-on: ubuntu-22.04 + + steps: + - name: Checkout Code + uses: actions/checkout@v3 + + - name: Setup SSH + uses: webfactory/ssh-agent@v0.5.3 + with: + ssh-private-key: ${{ secrets.SERVER_SSH_KEY }} + + - name: Add server to known_hosts + run: | + ssh-keyscan -H jienote.top >> ~/.ssh/known_hosts + + - name: Sync Code to Server + run: | + rsync -avz --delete \ + --exclude '.git' \ + --exclude '.github' \ + ./ \ + ${{ secrets.REMOTE_USER }}@${{ secrets.REMOTE_HOST }}:${{ secrets.REMOTE_PATH }} + + - name: Build and Restart Docker on Server + run: | + ssh ${{ secrets.REMOTE_USER }}@${{ secrets.REMOTE_HOST }} << 'EOF' + cd ${{ secrets.REMOTE_PATH }} + cd .. + docker-compose down + docker-compose build + docker-compose up -d + EOF \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c75ecb7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +env +__pycache__ +articles +app.log \ No newline at end of file diff --git a/README.md b/README.md index 1ce8e16..f27b91c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,117 @@ -# JieNote_backend -2025春季软件工程课程团队项目JieNote项目后端 +# JieNote Backend + +This is the backend service for JieNote, built with FastAPI. + +## Features +- RESTful API endpoints +- Modular structure for scalability + +## File Structure +- `app/`: Contains the main application code. + - `main.py`: Entry point for the FastAPI application. + - `models/`: Database models and schemas. + - `core/`: Core configurations and settings. + - include database settings + - include JWT settings + - include CORS settings + - …… + - `curd/`: CRUD operations for database interactions. + - `db/`: Database connection and session management. 
+  - `schemas/`: Pydantic schemas for data validation.
+  - `static/`: Static files (e.g., images, CSS).
+  - `routers/`: API route definitions.
+- `tests/`: Contains test cases for the application.
+- `requirements.txt`: List of dependencies.
+- `README.md`: Documentation for the project.
+- `alembic/`: Database migration scripts and configurations.
+- `env/`: Virtual environment (not included in version control).
+- `img/`: Images used in the project.
+
+## Setup
+1. Create a virtual environment: ✔
+   ```bash
+   python -m venv env
+   ```
+2. Activate the virtual environment:
+   - On Windows:
+     ```bash
+     .\env\Scripts\activate
+     ```
+   - On macOS/Linux:
+     ```bash
+     source env/bin/activate
+     ```
+3. Install dependencies:
+   ```bash
+   pip install -r requirements.txt
+   ```
+4. Freeze the requirements (do this before committing):
+   ```bash
+   pip freeze > requirements.txt
+   ```
+
+## Database Migration
+
+1. Install Alembic: ✔
+   ```bash
+   pip install alembic
+   ```
+2. Initialize Alembic: ✔
+   ```bash
+   alembic init alembic
+   ```
+3. Configure Alembic: ✔
+
+   1. Edit `alembic.ini` to set the database URL.
+   2. Edit `alembic/env.py` to set up the target metadata.
+   ```python
+   from app.models import Base  # Import your models here
+   target_metadata = Base.metadata
+   ```
+4. Create a migration script (review and adjust the autogenerated script before applying it):
+   ```bash
+   alembic revision --autogenerate -m "commit message"
+   ```
+5. Apply the migration (after adjusting the script if necessary):
+   ```bash
+   alembic upgrade head
+   ```
+
+
+## Run the Application
+```bash
+uvicorn app.main:app --reload
+```
+
+## Redis
+- Redis is used for caching and session management.
+- Make sure Redis is installed and running.
+
+```bash
+cd path/to/redis
+# Start Redis server
+redis-server.exe redis.windows.conf
+```
+Attention:
+- Make sure the port is not occupied by another service.
+- If you need a different port, modify the `redis.windows.conf` file.
+- Redis must be running and reachable before the application starts.
+
+
+## Token Authentication
+- JWT (JSON Web Token) is used for authentication.
+- Refresh tokens are valid for 7 days, access tokens for 5 minutes.
+
+## Folder Structure
+- `app/`: Contains the main application code.
+- `tests/`: Contains test cases.
+- `env/`: Virtual environment (not included in version control).
+
+## OCR
+- Poppler must be installed for the OCR features to work.
+
+## ER Diagram
+
+
+## License
+MIT License
diff --git a/alembic.ini b/alembic.ini
new file mode 100644
index 0000000..c44a966
--- /dev/null
+++ b/alembic.ini
@@ -0,0 +1,119 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+# Use forward slashes (/) also on windows to provide an os agnostic path
+script_location = alembic
+
+# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
+# Uncomment the line below if you want the files to be prepended with date and time
+# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
+# for all available tokens
+# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+# sys.path path, will be prepended to sys.path if present.
+# defaults to the current working directory.
+prepend_sys_path = .
+
+# timezone to use when rendering the date within the migration file
+# as well as the filename.
+# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
+# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +version_path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = mysql+pymysql://root:oneapi@47.93.172.156:3306/JieNote + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. 
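The `sqlalchemy.url` value in `alembic.ini` above commits live database credentials to a tracked file. Since the application already loads its secrets from `.env` through `app.core.config.settings`, one option is to leave the key blank in `alembic.ini` and inject the URL inside `alembic/env.py` (the next file in this diff). This is a minimal sketch, not the project's actual setup: it assumes `settings` exposes a synchronous connection string under a hypothetical `DATABASE_URL` attribute.

```python
# alembic/env.py (sketch) -- override the URL from application settings so
# credentials stay out of alembic.ini. DATABASE_URL is a placeholder name;
# substitute whatever attribute app.core.config.settings actually defines.
from alembic import context
from app.core.config import settings

config = context.config
config.set_main_option("sqlalchemy.url", settings.DATABASE_URL)
```

Note that the default synchronous `env.py` generated below still needs a sync driver such as `mysql+pymysql`, even if the application itself connects with an async driver.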
\ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..41ef15e --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,79 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool +from app.db.base import Base + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..480b130 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git "a/alembic/versions/004c4aa2b3f3_\344\270\252\344\272\272\345\233\236\346\224\266\347\253\231\350\241\250\345\242\236\345\212\240\344\270\212\347\272\247\344\277\241\346\201\257.py" "b/alembic/versions/004c4aa2b3f3_\344\270\252\344\272\272\345\233\236\346\224\266\347\253\231\350\241\250\345\242\236\345\212\240\344\270\212\347\272\247\344\277\241\346\201\257.py" new file mode 100644 index 0000000..d7742e9 --- /dev/null +++ "b/alembic/versions/004c4aa2b3f3_\344\270\252\344\272\272\345\233\236\346\224\266\347\253\231\350\241\250\345\242\236\345\212\240\344\270\212\347\272\247\344\277\241\346\201\257.py" @@ -0,0 +1,46 @@ +"""个人回收站表增加上级信息 + +Revision ID: 004c4aa2b3f3 +Revises: d6d6ae6d9680 +Create Date: 2025-05-21 21:29:14.873544 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '004c4aa2b3f3' +down_revision: Union[str, None] = 'd6d6ae6d9680' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint('articles_ibfk_1', 'articles', type_='foreignkey') + op.create_foreign_key(None, 'articles', 'folders', ['folder_id'], ['id'], ondelete='CASCADE') + op.drop_constraint('notes_ibfk_1', 'notes', type_='foreignkey') + op.create_foreign_key(None, 'notes', 'articles', ['article_id'], ['id'], ondelete='CASCADE') + op.add_column('self_recycle_bin', sa.Column('article_id', sa.Integer(), nullable=True)) + op.add_column('self_recycle_bin', sa.Column('folder_id', sa.Integer(), nullable=True)) + op.create_foreign_key(None, 'self_recycle_bin', 'folders', ['folder_id'], ['id'], ondelete='CASCADE') + op.create_foreign_key(None, 'self_recycle_bin', 'articles', ['article_id'], ['id'], ondelete='CASCADE') + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint(None, 'self_recycle_bin', type_='foreignkey') + op.drop_constraint(None, 'self_recycle_bin', type_='foreignkey') + op.drop_column('self_recycle_bin', 'folder_id') + op.drop_column('self_recycle_bin', 'article_id') + op.drop_constraint(None, 'notes', type_='foreignkey') + op.create_foreign_key('notes_ibfk_1', 'notes', 'articles', ['article_id'], ['id']) + op.drop_constraint(None, 'articles', type_='foreignkey') + op.create_foreign_key('articles_ibfk_1', 'articles', 'folders', ['folder_id'], ['id']) + # ### end Alembic commands ### diff --git "a/alembic/versions/0c8a143b1c4d_group\345\242\236\345\212\240name\345\255\227\346\256\265_\346\224\271\345\217\230article\347\232\204name\345\255\227\346\256\265\345\256\232\344\271\211.py" "b/alembic/versions/0c8a143b1c4d_group\345\242\236\345\212\240name\345\255\227\346\256\265_\346\224\271\345\217\230article\347\232\204name\345\255\227\346\256\265\345\256\232\344\271\211.py" new file mode 100644 index 0000000..1f05188 --- /dev/null +++ "b/alembic/versions/0c8a143b1c4d_group\345\242\236\345\212\240name\345\255\227\346\256\265_\346\224\271\345\217\230article\347\232\204name\345\255\227\346\256\265\345\256\232\344\271\211.py" @@ -0,0 +1,40 @@ +"""group增加name字段, 改变article的name字段定义 + +Revision ID: 0c8a143b1c4d +Revises: f0cfac833b7d +Create Date: 2025-04-22 23:25:31.324190 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = '0c8a143b1c4d' +down_revision: Union[str, None] = 'f0cfac833b7d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('articles', 'name', + existing_type=mysql.VARCHAR(length=30), + type_=sa.Text(), + existing_nullable=False) + op.add_column('groups', sa.Column('name', sa.String(length=30), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('groups', 'name') + op.alter_column('articles', 'name', + existing_type=sa.Text(), + type_=mysql.VARCHAR(length=30), + existing_nullable=False) + # ### end Alembic commands ### diff --git "a/alembic/versions/48b09347ef95_\345\242\236\345\212\240\347\224\250\346\210\267\344\277\241\346\201\257.py" "b/alembic/versions/48b09347ef95_\345\242\236\345\212\240\347\224\250\346\210\267\344\277\241\346\201\257.py" new file mode 100644 index 0000000..27d780a --- /dev/null +++ "b/alembic/versions/48b09347ef95_\345\242\236\345\212\240\347\224\250\346\210\267\344\277\241\346\201\257.py" @@ -0,0 +1,38 @@ +"""增加用户信息 + +Revision ID: 48b09347ef95 +Revises: dfc2d1f9dad2 +Create Date: 2025-04-23 16:53:15.415162 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import func + +# revision identifiers, used by Alembic. +revision: str = '48b09347ef95' +down_revision: Union[str, None] = 'dfc2d1f9dad2' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column('users', sa.Column('create_time', sa.DateTime(), server_default=func.now(), nullable=False)) + op.add_column('users', sa.Column('address', sa.String(length=100), nullable=True)) + op.add_column('users', sa.Column('university', sa.String(length=100), nullable=True)) + op.add_column('users', sa.Column('introduction', sa.Text(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('users', 'create_time') + op.drop_column('users', 'introduction') + op.drop_column('users', 'university') + op.drop_column('users', 'address') + # ### end Alembic commands ### diff --git "a/alembic/versions/4b9d22943860_\345\256\236\347\216\260\345\244\232\351\234\200\346\261\202\345\220\210\345\271\266.py" "b/alembic/versions/4b9d22943860_\345\256\236\347\216\260\345\244\232\351\234\200\346\261\202\345\220\210\345\271\266.py" new file mode 100644 index 0000000..6e464af --- /dev/null +++ "b/alembic/versions/4b9d22943860_\345\256\236\347\216\260\345\244\232\351\234\200\346\261\202\345\220\210\345\271\266.py" @@ -0,0 +1,40 @@ +"""实现多需求合并 + +Revision ID: 4b9d22943860 +Revises: 48b09347ef95 +Create Date: 2025-04-28 23:52:46.462144 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '4b9d22943860' +down_revision: Union[str, None] = '48b09347ef95' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('enter_application', + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('group_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('user_id', 'group_id') + ) + op.add_column('groups', sa.Column('description', sa.String(length=200), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('groups', 'description') + op.drop_table('enter_application') + # ### end Alembic commands ### diff --git "a/alembic/versions/7af566a6091b_\344\277\256\346\224\271note.py" "b/alembic/versions/7af566a6091b_\344\277\256\346\224\271note.py" new file mode 100644 index 0000000..e570532 --- /dev/null +++ "b/alembic/versions/7af566a6091b_\344\277\256\346\224\271note.py" @@ -0,0 +1,32 @@ +"""修改note + +Revision ID: 7af566a6091b +Revises: 949e5fc5dfc4 +Create Date: 2025-05-11 22:56:54.523911 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '7af566a6091b' +down_revision: Union[str, None] = '949e5fc5dfc4' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + pass + # ### end Alembic commands ### diff --git "a/alembic/versions/923af5d7462a_\345\242\236\345\212\240\346\226\207\347\214\256\345\272\223\346\226\260.py" "b/alembic/versions/923af5d7462a_\345\242\236\345\212\240\346\226\207\347\214\256\345\272\223\346\226\260.py" new file mode 100644 index 0000000..a73adbe --- /dev/null +++ "b/alembic/versions/923af5d7462a_\345\242\236\345\212\240\346\226\207\347\214\256\345\272\223\346\226\260.py" @@ -0,0 +1,32 @@ +"""增加文献库新 + +Revision ID: 923af5d7462a +Revises: d7e135e9e071 +Create Date: 2025-04-22 12:02:19.315637 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '923af5d7462a' +down_revision: Union[str, None] = 'd7e135e9e071' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('articleDB', sa.Column('author', sa.String(length=100), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('articleDB', 'author') + # ### end Alembic commands ### diff --git "a/alembic/versions/949e5fc5dfc4_\345\242\236\345\212\240\347\224\250\346\210\267\344\277\241\346\201\257\346\217\217\350\277\260.py" "b/alembic/versions/949e5fc5dfc4_\345\242\236\345\212\240\347\224\250\346\210\267\344\277\241\346\201\257\346\217\217\350\277\260.py" new file mode 100644 index 0000000..8bb593e --- /dev/null +++ "b/alembic/versions/949e5fc5dfc4_\345\242\236\345\212\240\347\224\250\346\210\267\344\277\241\346\201\257\346\217\217\350\277\260.py" @@ -0,0 +1,34 @@ +"""增加用户信息描述 + +Revision ID: 949e5fc5dfc4 +Revises: a434b17f5caf +Create Date: 2025-05-11 22:31:11.511674 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '949e5fc5dfc4' +down_revision: Union[str, None] = 'a434b17f5caf' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('notes', sa.Column('creator_id', sa.Integer(), nullable=True)) + op.create_foreign_key(None, 'notes', 'users', ['creator_id'], ['id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'notes', type_='foreignkey') + op.drop_column('notes', 'creator_id') + # ### end Alembic commands ### diff --git "a/alembic/versions/a434b17f5caf_\345\242\236\345\212\240\344\275\234\350\200\205\345\220\215.py" "b/alembic/versions/a434b17f5caf_\345\242\236\345\212\240\344\275\234\350\200\205\345\220\215.py" new file mode 100644 index 0000000..6081c9f --- /dev/null +++ "b/alembic/versions/a434b17f5caf_\345\242\236\345\212\240\344\275\234\350\200\205\345\220\215.py" @@ -0,0 +1,38 @@ +"""增加作者名 + +Revision ID: a434b17f5caf +Revises: 4b9d22943860 +Create Date: 2025-04-29 17:34:19.895192 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. 
+revision: str = 'a434b17f5caf' +down_revision: Union[str, None] = '4b9d22943860' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('articleDB', 'author', + existing_type=mysql.VARCHAR(length=100), + type_=sa.String(length=300), + existing_nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('articleDB', 'author', + existing_type=sa.String(length=300), + type_=mysql.VARCHAR(length=100), + existing_nullable=False) + # ### end Alembic commands ### diff --git "a/alembic/versions/cf83488540d9_\346\226\207\347\214\256\345\222\214\346\226\207\344\273\266\345\244\271\347\232\204\345\233\236\346\224\266\347\253\231\346\224\257\346\214\201.py" "b/alembic/versions/cf83488540d9_\346\226\207\347\214\256\345\222\214\346\226\207\344\273\266\345\244\271\347\232\204\345\233\236\346\224\266\347\253\231\346\224\257\346\214\201.py" new file mode 100644 index 0000000..9fa13d9 --- /dev/null +++ "b/alembic/versions/cf83488540d9_\346\226\207\347\214\256\345\222\214\346\226\207\344\273\266\345\244\271\347\232\204\345\233\236\346\224\266\347\253\231\346\224\257\346\214\201.py" @@ -0,0 +1,34 @@ +"""文献和文件夹的回收站支持 + +Revision ID: cf83488540d9 +Revises: +Create Date: 2025-04-20 21:08:32.574488 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'cf83488540d9' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('articles', sa.Column('visible', sa.Boolean(), nullable=False)) + op.add_column('folders', sa.Column('visible', sa.Boolean(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('folders', 'visible') + op.drop_column('articles', 'visible') + # ### end Alembic commands ### diff --git "a/alembic/versions/d6d6ae6d9680_\345\242\236\345\212\240\344\270\252\344\272\272\345\233\236\346\224\266\347\253\231\350\241\250.py" "b/alembic/versions/d6d6ae6d9680_\345\242\236\345\212\240\344\270\252\344\272\272\345\233\236\346\224\266\347\253\231\350\241\250.py" new file mode 100644 index 0000000..7043f11 --- /dev/null +++ "b/alembic/versions/d6d6ae6d9680_\345\242\236\345\212\240\344\270\252\344\272\272\345\233\236\346\224\266\347\253\231\350\241\250.py" @@ -0,0 +1,40 @@ +"""增加个人回收站表 + +Revision ID: d6d6ae6d9680 +Revises: 7af566a6091b +Create Date: 2025-05-14 11:25:12.719964 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'd6d6ae6d9680' +down_revision: Union[str, None] = '7af566a6091b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('self_recycle_bin', + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('type', sa.Integer(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('create_time', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('type', 'id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('self_recycle_bin') + # ### end Alembic commands ### diff --git "a/alembic/versions/d7e135e9e071_\345\242\236\345\212\240\346\226\207\347\214\256\345\272\223.py" "b/alembic/versions/d7e135e9e071_\345\242\236\345\212\240\346\226\207\347\214\256\345\272\223.py" new file mode 100644 index 0000000..3ae1e8f --- /dev/null +++ "b/alembic/versions/d7e135e9e071_\345\242\236\345\212\240\346\226\207\347\214\256\345\272\223.py" @@ -0,0 +1,41 @@ +"""增加文献库 + +Revision ID: d7e135e9e071 +Revises: db21037668bc +Create Date: 2025-04-22 10:59:41.153493 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'd7e135e9e071' +down_revision: Union[str, None] = 'db21037668bc' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('articleDB', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('title', sa.String(length=200), nullable=False), + sa.Column('url', sa.String(length=200), nullable=False), + sa.Column('create_time', sa.DateTime(), nullable=False), + sa.Column('update_time', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_articleDB_id'), 'articleDB', ['id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_articleDB_id'), table_name='articleDB') + op.drop_table('articleDB') + # ### end Alembic commands ### diff --git "a/alembic/versions/db21037668bc_note\345\242\236\345\212\240\345\233\236\346\224\266\347\253\231\345\255\227\346\256\265.py" "b/alembic/versions/db21037668bc_note\345\242\236\345\212\240\345\233\236\346\224\266\347\253\231\345\255\227\346\256\265.py" new file mode 100644 index 0000000..ff1c9d6 --- /dev/null +++ "b/alembic/versions/db21037668bc_note\345\242\236\345\212\240\345\233\236\346\224\266\347\253\231\345\255\227\346\256\265.py" @@ -0,0 +1,32 @@ +"""Note增加回收站字段 + +Revision ID: db21037668bc +Revises: cf83488540d9 +Create Date: 2025-04-21 14:37:10.169767 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'db21037668bc' +down_revision: Union[str, None] = 'cf83488540d9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('notes', sa.Column('visible', sa.Boolean(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('notes', 'visible') + # ### end Alembic commands ### diff --git "a/alembic/versions/dfc2d1f9dad2_article\350\241\250\347\232\204folder_lazy_selectin.py" "b/alembic/versions/dfc2d1f9dad2_article\350\241\250\347\232\204folder_lazy_selectin.py" new file mode 100644 index 0000000..b499b4c --- /dev/null +++ "b/alembic/versions/dfc2d1f9dad2_article\350\241\250\347\232\204folder_lazy_selectin.py" @@ -0,0 +1,32 @@ +"""Article表的folder: lazy=selectin + +Revision ID: dfc2d1f9dad2 +Revises: 0c8a143b1c4d +Create Date: 2025-04-23 11:12:58.159106 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'dfc2d1f9dad2' +down_revision: Union[str, None] = '0c8a143b1c4d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git "a/alembic/versions/f0cfac833b7d_\344\270\272\347\254\224\350\256\260\345\242\236\345\212\240\346\240\207\351\242\230.py" "b/alembic/versions/f0cfac833b7d_\344\270\272\347\254\224\350\256\260\345\242\236\345\212\240\346\240\207\351\242\230.py" new file mode 100644 index 0000000..8dfce4d --- /dev/null +++ "b/alembic/versions/f0cfac833b7d_\344\270\272\347\254\224\350\256\260\345\242\236\345\212\240\346\240\207\351\242\230.py" @@ -0,0 +1,32 @@ +"""为笔记增加标题 + +Revision ID: f0cfac833b7d +Revises: 923af5d7462a +Create Date: 2025-04-22 14:50:36.985617 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'f0cfac833b7d' +down_revision: Union[str, None] = '923af5d7462a' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('notes', sa.Column('title', sa.String(length=100), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('notes', 'title') + # ### end Alembic commands ### diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/v1/__init__.py b/app/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/v1/endpoints/__init__.py b/app/api/v1/endpoints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/v1/endpoints/aichat.py b/app/api/v1/endpoints/aichat.py new file mode 100644 index 0000000..f4748ff --- /dev/null +++ b/app/api/v1/endpoints/aichat.py @@ -0,0 +1,50 @@ +from fastapi import Depends, Request +from fastapi.responses import StreamingResponse +from app.utils.aichat import kimi_chat_stream +from app.utils.redis import get_redis_client +from app.utils.auth import get_current_user +import json +from fastapi import APIRouter +from app.schemas.aichat import NoteInput + + +router = APIRouter() +redis_client = get_redis_client() + +@router.post("/note", response_model=dict) +async def generate_notes( + input: NoteInput, + current_user: dict = Depends(get_current_user) +): + user_id = current_user["id"] + redis_key = f"aichat:{user_id}" + + # 1. 读取历史对话 + history = redis_client.get(redis_key) + if history: + messages = json.loads(history) + else: + # 首轮对话可加 system prompt + messages = [{"role": "system", "content": "你是一个智能笔记助手。"}] + + # 2. 追加用户输入 + messages.append({"role": "user", "content": input.input}) + + async def ai_stream(): + full_reply = "" + async for content in kimi_chat_stream(messages): + full_reply += content + yield f"data: {json.dumps({'content': content}, ensure_ascii=False)}\n\n" + messages.append({"role": "assistant", "content": full_reply}) + redis_client.set(redis_key, json.dumps(messages), ex=3600) + + return StreamingResponse(ai_stream(), media_type="text/event-stream") + +@router.get("/clear", response_model=dict) +async def clear_notes( + current_user: dict = Depends(get_current_user) +): + user_id = current_user["id"] + redis_key = f"aichat:{user_id}" + redis_client.delete(redis_key) + return {"msg": "clear successfully"} \ No newline at end of file diff --git a/app/api/v1/endpoints/article.py b/app/api/v1/endpoints/article.py new file mode 100644 index 0000000..9159aba --- /dev/null +++ b/app/api/v1/endpoints/article.py @@ -0,0 +1,222 @@ +from fastapi import APIRouter, UploadFile, File, Query, Depends, HTTPException, Body +from fastapi.responses import FileResponse +from fastapi import BackgroundTasks +from sqlalchemy.ext.asyncio import AsyncSession +from typing import Optional, List +import os +import io +from zipfile import ZipFile +import zipfile +import tempfile + +from app.utils.get_db import get_db +from app.utils.auth import get_current_user +from app.curd.article import crud_upload_to_self_folder, crud_get_self_folders, crud_get_articles_in_folder, crud_self_create_folder, crud_self_article_to_recycle_bin, crud_self_folder_to_recycle_bin, crud_read_article, crud_import_self_folder, crud_export_self_folder,crud_create_tag, crud_delete_tag, crud_get_article_tags, crud_all_tags_order, crud_change_folder_name, crud_change_article_name, crud_article_statistic, crud_self_tree, crud_self_article_statistic, crud_items_in_recycle_bin, crud_delete_forever, crud_recover +from app.schemas.article import SelfCreateFolder + +router = APIRouter() + +@router.post("/uploadToSelfFolder", response_model="dict") +async def upload_to_self_folder(folder_id: int = 
Query(...), article: UploadFile = File(...), db: AsyncSession = Depends(get_db)): + # 检查上传的必须为 PDF + head = await article.read(5) # 读取文件的前 5 个字节,用于魔数检测 + if not head.startswith(b"%PDF-"): + raise HTTPException(status_code=405, detail="File uploaded must be a PDF.") + await article.seek(0) # 重置文件指针位置 + + # 用文件名(不带扩展名)作为 Article 名称 + name = os.path.splitext(article.filename)[0] + + # 新建 Article 记录 + article_id = await crud_upload_to_self_folder(name, folder_id, db) + + # 存储到云存储位置 + os.makedirs("/lhcos-data", exist_ok=True) + save_path = os.path.join("/lhcos-data", f"{article_id}.pdf") + with open(save_path, "wb") as f: + content = await article.read() + f.write(content) + + return {"msg": "Article created successfully.", "article_id": article_id} + +@router.get("/getSelfFolders", response_model="dict") +async def get_self_folders(page_number: Optional[int] = Query(None, ge=1), page_size: Optional[int] = Query(None, ge=1), + db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + # 获取用户id + user_id = user.get("id") + + total_num, folders = await crud_get_self_folders(user_id, page_number, page_size, db) + result = [{"folder_id": folder.id, "folder_name": folder.name} for folder in folders] + return {"total_num": total_num, "result": result} + +@router.get("/getArticlesInFolder", response_model="dict") +async def get_articles_in_folder(folder_id: int = Query(...), page_number: Optional[int] = Query(None, ge=1), page_size: Optional[int] = Query(None, ge=1), + db: AsyncSession = Depends(get_db)): + total_num, articles = await crud_get_articles_in_folder(folder_id, page_number, page_size, db) + result = [{"article_id": article.id, "article_name": article.name} for article in articles] + return {"total_num": total_num, "result": result} + +@router.post("/selfCreateFolder", response_model="dict") +async def self_create_folder(model: SelfCreateFolder, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + folder_name = model.folder_name + if folder_name == "" or len(folder_name) > 30: + raise HTTPException(status_code=405, detail="Invalid folder name, empty or longer than 30") + + # 获取用户id + user_id = user.get("id") + + # 数据库插入 + folder_id = await crud_self_create_folder(folder_name, user_id, db) + + # 返回结果 + return {"msg": "User Folder Created Successfully", "folder_id": folder_id} + +@router.delete("/selfArticleToRecycleBin", response_model="dict") +async def self_article_to_recycle_bin(article_id: int = Query(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + user_id = user.get("id") + await crud_self_article_to_recycle_bin(article_id, user_id, db) + return {"msg": "Article is moved to recycle bin"} + +@router.delete("/selfFolderToRecycleBin", response_model="dict") +async def self_folder_to_recycle_bin(folder_id: int = Query(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + user_id = user.get("id") + await crud_self_folder_to_recycle_bin(folder_id, user_id, db) + return {"msg": "Folder is moved to recycle bin"} + +@router.post("/annotateSelfArticle", response_model="dict") +async def annotate_self_article(article_id: int = Query(...), article: UploadFile = File(...)): + # 将新文件存储到云存储位置 + os.makedirs("/lhcos-data", exist_ok=True) + save_path = os.path.join("/lhcos-data", f"{article_id}.pdf") + with open(save_path, "wb") as f: + content = await article.read() + f.write(content) + + return {"msg": "Article annotated successfully."} + +@router.get("/readArticle", 
response_class=FileResponse) +async def read_article(article_id: int = Query(...), db: AsyncSession = Depends(get_db)): + + file_path = f"/lhcos-data/{article_id}.pdf" + + # 查询文件名 + article_name = await crud_read_article(article_id, db) + + # 返回结果 + return FileResponse(path=file_path, filename=f"{article_name}.pdf", media_type='application/pdf') + +@router.post("/importSelfFolder", response_model="dict") +async def import_self_folder(folder_name: str = Query(...), zip: UploadFile = File(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + if folder_name == "" or len(folder_name) > 30: + raise HTTPException(status_code=405, detail="Invalid folder name, empty or longer than 30") + + # 获取用户id + user_id = user.get("id") + + # 获取压缩包中的所有文献名(去掉.pdf) + zip_bytes = await zip.read() + zip_file = ZipFile(io.BytesIO(zip_bytes)) + article_names = [os.path.splitext(os.path.basename(name))[0] for name in zip_file.namelist() if name.endswith('.pdf')] + + # 记入数据库 + result = await crud_import_self_folder(folder_name, article_names, user_id, db) + + # 存储文献到云存储 + os.makedirs("/lhcos-data", exist_ok=True) + for i in range(0, len(result), 2): + article_id = result[i] + article_name = result[i + 1] + pdf_filename_in_zip = f"{article_name}.pdf" + with zip_file.open(pdf_filename_in_zip) as source_file: + target_path = os.path.join("/lhcos-data", f"{article_id}.pdf") + with open(target_path, "wb") as out_file: + out_file.write(source_file.read()) + + return {"msg": "Successfully import articles"} + +@router.get("/exportSelfFolder", response_class=FileResponse) +async def export_self_folder(background_tasks: BackgroundTasks, folder_id: int = Query(...), db: AsyncSession = Depends(get_db)): + zip_name, article_ids, article_names = await crud_export_self_folder(folder_id, db) + + tmp_dir = tempfile.gettempdir() + zip_path = os.path.join(tmp_dir, f"{zip_name}.zip") + + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf: + for article_id, article_name in zip(article_ids, article_names): + pdf_path = os.path.join("/lhcos-data", f"{article_id}.pdf") + arcname = f"{article_name}.pdf" # 压缩包内的文件名 + zipf.write(pdf_path, arcname=arcname) + + background_tasks.add_task(os.remove, zip_path) + + return FileResponse( + path=zip_path, + filename=f"{zip_name}.zip", + media_type="application/zip" + ) + +@router.post("/createTag", response_model="dict") +async def create_tag(article_id: int = Body(...), content: str = Body(...), db: AsyncSession = Depends(get_db)): + if len(content) > 30: + raise HTTPException(status_code=405, detail="Invalid tag content, longer than 30") + await crud_create_tag(article_id, content, db) + return {"msg": "Tag Created Successfully"} + +@router.delete("/deleteTag", response_model="dict") +async def delete_tag(tag_id: int = Query(...), db: AsyncSession = Depends(get_db)): + await crud_delete_tag(tag_id, db) + return {"msg": "Tag deleted successfully"} + +@router.get("/getArticleTags", response_model="dict") +async def get_article_tags(article_id: int = Query(...), db: AsyncSession = Depends(get_db)): + tags = await crud_get_article_tags(article_id, db) + result = [{"tag_id": tag.id, "tag_content": tag.content} for tag in tags] + return {"result": result} + +@router.post("/allTagsOrder", response_model="dict") +async def all_tags_order(article_id: int = Body(...), tag_contents: List[str] = Body(...), db: AsyncSession = Depends(get_db)): + for tag_content in tag_contents: + if len(tag_content) > 30: + raise HTTPException(status_code=405, detail="Invalid tag 
content existed, longer than 30") + await crud_all_tags_order(article_id, tag_contents, db) + return {"msg": "Tags and order changed successfully"} + +@router.post("/changeFolderName", response_model="dict") +async def change_folder_name(folder_id: int = Body(...), folder_name: str = Body(...), db: AsyncSession = Depends(get_db)): + if folder_name == "" or len(folder_name) > 30: + raise HTTPException(status_code=405, detail="Invalid folder name, empty or longer than 30") + await crud_change_folder_name(folder_id, folder_name, db) + return {"msg": "Folder name changed successfully"} + +@router.post("/changeArticleName", response_model="dict") +async def change_article_name(article_id: int = Body(...), article_name: str = Body(...), db: AsyncSession = Depends(get_db)): + await crud_change_article_name(article_id, article_name, db) + return {"msg": "Article name changed successfully"} + +@router.get("/selfTree", response_model="dict") +async def self_tree(page_number: Optional[int] = Query(None, ge=1), page_size: Optional[int] = Query(None, ge=1), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + user_id = user.get("id") + total_folder_num, folders = await crud_self_tree(user_id, page_number, page_size, db) + return {"total_folder_num": total_folder_num, "folders": folders} + +@router.get("/selfArticleStatistic", response_model=dict) +async def self_article_statistic(db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + user_id = user.get("id") + article_total_num, articles = await crud_self_article_statistic(user_id, db) + return {"article_total_num": article_total_num, "articles": articles} + +@router.get("/itemsInRecycleBin", response_model=dict) +async def items_in_recycle_bin(page_number: Optional[int] = Query(None, ge=1), page_size: Optional[int] = Query(None, ge=1), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + user_id = user.get("id") + items = await crud_items_in_recycle_bin(user_id, page_number, page_size, db) + return {"items": items} + +@router.delete("/deleteForever", response_model=dict) +async def delete_forever(type: int = Query(...), id: int = Query(...), db: AsyncSession = Depends(get_db)): + await crud_delete_forever(type, id, db) + return {"msg": "Item and its child nodes deleted forever successfully"} + +@router.post("/recover", response_model=dict) +async def recover(type: int = Body(...), id: int = Body(...), db: AsyncSession = Depends(get_db)): + return_value = await crud_recover(type, id, db) + return return_value \ No newline at end of file diff --git a/app/api/v1/endpoints/articleDB.py b/app/api/v1/endpoints/articleDB.py new file mode 100644 index 0000000..9b0e030 --- /dev/null +++ b/app/api/v1/endpoints/articleDB.py @@ -0,0 +1,108 @@ +from fastapi import APIRouter, HTTPException, Depends, UploadFile, Form, File +from sqlalchemy.ext.asyncio import AsyncSession +from app.utils.get_db import get_db +from app.schemas.articleDB import UploadArticle, GetArticle, DeLArticle, GetResponse +from app.curd.articleDB import create_article_in_db, get_article_in_db, get_article_in_db_by_id, get_article_info_in_db_by_id +from app.core.config import settings +import os +import uuid +from fastapi.responses import FileResponse +from urllib.parse import quote +from app.curd.article import crud_upload_to_self_folder +router = APIRouter() + +@router.put("/upload", response_model=dict) +async def upload_article( + title: str = Form(None), + author: str = Form(None), + url: str = Form(None), + file: 
UploadFile = File(...), + db: AsyncSession = Depends(get_db) +): + """ + Upload an article to the database. + """ + # 将文件保存到指定目录 + if not os.path.exists(settings.UPLOAD_FOLDER): + os.makedirs(settings.UPLOAD_FOLDER) + + # 生成文件名,可以使用 UUID 或者其他方式来确保文件名唯一 + file_name = f"{uuid.uuid4()}.pdf" + file_path = os.path.join(settings.UPLOAD_FOLDER, file_name) + try: + with open(file_path, "wb") as f: + while chunk := await file.read(1024): # 每次读取 1024 字节 + f.write(chunk) + + await create_article_in_db(db=db, upload_article=UploadArticle(title=title, author=author, url=url, file_path=file_path)) + return {"msg": "Article uploaded successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.get("/get", response_model=dict) +async def get_article(get_article: GetArticle = Depends(), db: AsyncSession = Depends(get_db)): + """ + Get an article from the database. + """ + articles, total_count = await get_article_in_db(db=db, get_article=get_article) + return { + "pagination": { + "page": get_article.page, + "page_size": get_article.page_size, + "total_count": total_count + }, + "articles": [articles.model_dump() for articles in articles] + } + +@router.get("/download/{article_id}", response_model=dict) +async def download_article(article_id: int, db: AsyncSession = Depends(get_db)): + """ + Download an article file by its ID. + """ + # 根据 ID 查询文章信息 + article = await get_article_in_db_by_id(db=db, article_id=article_id) + if not article or not article.file_path: + raise HTTPException(status_code=404, detail="File not found") + + if not os.path.exists(article.file_path): + raise HTTPException(status_code=404, detail="File not found on server") + + # 从文件路径获取文件名 + file_name = os.path.basename(article.file_path) + + # 设置原始文件名,如果有标题,使用标题作为文件名 + download_filename = f"{article.title}.pdf" if article.title else file_name + + # 返回文件,并设置文件名(使用 quote 处理中文文件名) + return FileResponse( + path=article.file_path, + filename=quote(download_filename), + media_type="application/pdf" + ) + +@router.put("/copy", response_model=dict) +async def copy_article(folder_id: int, article_id: int, db: AsyncSession = Depends(get_db)): + """ + Copy an article file by its ID to a specified directory. 
+ """ + # 根据 ID 查询文章信息 + file_path, title = await get_article_info_in_db_by_id(db=db, article_id=article_id) + if not file_path: + raise HTTPException(status_code=404, detail="File not found") + + new_article_id = await crud_upload_to_self_folder(name=title, folder_id=folder_id, db=db) + + # 复制文件到新的目录 + old_file_path = file_path + new_file_path = os.path.join("/lhcos-data", f"{new_article_id}.pdf") + try: + with open(old_file_path, "rb") as source_file: + with open(new_file_path, "wb") as dest_file: + dest_file.write(source_file.read()) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + return {"msg": "Article copied successfully", "new_article_id": new_article_id} + + + + diff --git a/app/api/v1/endpoints/auth.py b/app/api/v1/endpoints/auth.py new file mode 100644 index 0000000..52c779d --- /dev/null +++ b/app/api/v1/endpoints/auth.py @@ -0,0 +1,162 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.ext.asyncio import AsyncSession +from passlib.context import CryptContext +from datetime import datetime, timedelta +from jose import jwt, JWTError, ExpiredSignatureError +import aiosmtplib +from email.mime.text import MIMEText +from email.header import Header +import random +import time +from email.utils import formataddr + +from app.schemas.auth import UserCreate, UserLogin, UserSendCode, ReFreshToken +from app.core.config import settings +from app.curd.user import get_user_by_email, create_user +from app.curd.article import crud_self_create_folder, crud_article_statistic +from app.utils.get_db import get_db +from app.utils.redis import get_redis_client +from app.curd.note import find_recent_notes_in_db + +router = APIRouter() + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # 使用 bcrypt 加密算法 +SECRET_KEY = settings.SECRET_KEY +ALGORITHM = settings.ALGORITHM +ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES + +# 配置 Redis 连接 +redis_client = get_redis_client() + +async def create_access_token(data: dict, expires_delta: timedelta = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.now() + expires_delta + else: + expire = datetime.now() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + +async def create_refresh_token(data: dict, expires_delta: timedelta = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.now() + expires_delta + else: + expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + to_encode.update({"exp": expire, "type": "refresh"}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + +@router.post("/register", response_model=dict) +async def register(user: UserCreate, db: AsyncSession = Depends(get_db)): + existing_user = await get_user_by_email(db, user.email) + if redis_client.exists(f"email:{user.email}:code"): + code = redis_client.get(f"email:{user.email}:code") + if user.code != code: + raise HTTPException(status_code=400, detail="Invalid verification code") + else: + raise HTTPException(status_code=400, detail="Verification code expired or not sent") + + if existing_user: + raise HTTPException(status_code=400, detail="Email already registered") + hashed_password = pwd_context.hash(user.password) + new_user = await create_user(db, user.email, user.username, hashed_password) + await crud_self_create_folder("", new_user.id, db) + return {"msg": 
"User registered successfully"} + +@router.post("/login", response_model=dict) +async def login(user: UserLogin, db: AsyncSession = Depends(get_db)): + db_user = await get_user_by_email(db, user.email) + if not db_user or not pwd_context.verify(user.password, db_user.password): + raise HTTPException(status_code=401, detail="Invalid email or password") + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + refresh_token_expires = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + access_token = await create_access_token( + data={"sub": db_user.email, "id": db_user.id}, expires_delta=access_token_expires + ) + refresh_token = await create_refresh_token( + data={"sub": db_user.email, "id": db_user.id}, expires_delta=refresh_token_expires + ) + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + "user_id": db_user.id, + "email": db_user.email, + "username": db_user.username, + "avatar": db_user.avatar + } + +@router.post("/refresh", response_model=dict) +async def refresh_token(refresh_token: ReFreshToken): + try: + payload = jwt.decode(refresh_token.refresh_token, SECRET_KEY, algorithms=[ALGORITHM]) + if payload.get("type") != "refresh": + raise HTTPException(status_code=401, detail="Invalid refresh token type") + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = await create_access_token( + data={"sub": payload["sub"], "id": payload["id"]}, expires_delta=access_token_expires + ) + return {"access_token": access_token} + except ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Refresh token expired") + except JWTError: + raise HTTPException(status_code=401, detail="Invalid refresh token") + +# 发送验证码 +@router.post("/send_code", response_model=dict) +async def send_code(user_send_code: UserSendCode): + if redis_client.exists(f"email:{user_send_code.email}:time"): + raise HTTPException(status_code=429, detail="You can only request a verification code once every 5 minutes.") + + # 生成随机验证码 + code = str(random.randint(100000, 999999)) + + # SMTP 配置 + smtp_server = settings.SMTP_SERVER + smtp_port = settings.SMTP_PORT + sender_email = settings.SENDER_EMAIL + sender_password = settings.SENDER_PASSWORD + + # 邮件内容 + subject = "验证码" + body = f"欢迎使用JieNote,很开心遇见您,您的验证码是:{code},请在5分钟内使用。" + + # 创建MIMEText对象时需要显式指定子类型和编码 + message = MIMEText(_text=body, _subtype='plain', _charset='utf-8') + message["From"] = formataddr(("JieNote团队", "jienote_buaa@163.com")) + message["To"] = user_send_code.email + message["Subject"] = Header(subject, 'utf-8').encode() + # 添加必要的内容传输编码头 + message.add_header('Content-Transfer-Encoding', 'base64') + + try: + await aiosmtplib.send( + message, + hostname=smtp_server, + port=smtp_port, + username=sender_email, + password=sender_password, + use_tls=True, + ) + + redis_client.setex(f"email:{user_send_code.email}:code", settings.CODE_EXPIRATION_TIME, code) + redis_client.setex(f"email:{user_send_code.email}:time", settings.CODE_EXPIRATION_TIME, int(time.time())) + + return {"msg": "Verification code sent"} + + except aiosmtplib.SMTPException as e: + raise HTTPException(status_code=500, detail=f"Failed to send email: {str(e)}") + +@router.get("/articleStatistic", response_model="dict") +async def article_statistic(db: AsyncSession = Depends(get_db)): + articles = await crud_article_statistic(db) + return {"articles": articles} + +@router.get("/recent", response_model=dict) +async def get_recent_notes(db: AsyncSession = Depends(get_db)): + notes = 
await find_recent_notes_in_db(db) + return { + "notes": notes + } \ No newline at end of file diff --git a/app/api/v1/endpoints/group.py b/app/api/v1/endpoints/group.py new file mode 100644 index 0000000..56f850e --- /dev/null +++ b/app/api/v1/endpoints/group.py @@ -0,0 +1,46 @@ +from fastapi import APIRouter, Query, Body, UploadFile, File, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +import os + +from app.utils.get_db import get_db +from app.utils.auth import get_current_user +from app.curd.group import crud_create, crud_apply_to_enter, crud_get_applications, crud_reply_to_enter +from app.schemas.group import ApplyToEnter + +router = APIRouter() + +@router.post("/create", response_model=dict) +async def create(group_name: str = Query(...), group_desc: str = Query(...), group_avatar: UploadFile | None = File(None) + , db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + if len(group_name) > 30: + raise HTTPException(status_code=405, detail="Invalid group name, longer than 30") + if len(group_desc) > 200: + raise HTTPException(status_code=405, detail="Invalid group description, longer than 200") + group_id = await crud_create(user.get("id"), group_name, group_desc, db) + if group_avatar: + os.makedirs("/lhcos-data/group-avatar", exist_ok=True) + ext = os.path.splitext(group_avatar.filename)[1] + path = os.path.join("/lhcos-data/group-avatar", f"{group_id}{ext}") + with open(path, "wb") as f: + content = await group_avatar.read() + f.write(content) + return {"msg": "Group created successfully"} + +@router.post("/applyToEnter", response_model=dict) +async def apply_to_enter(model: ApplyToEnter, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + group_id = model.group_id + user_id = user.get("id") + await crud_apply_to_enter(user_id, group_id, db) + return {"msg": "Application sent successfully"} + +@router.get("/getApplications", response_model=dict) +async def get_applications(group_id: int = Query(...), db: AsyncSession = Depends(get_db)): + users = await crud_get_applications(group_id, db) + return {"users": users} + +@router.post("/replyToEnter", response_model=dict) +async def reply_to_enter(user_id: int = Body(...), group_id: int = Body(...), reply: int = Body(...), db: AsyncSession = Depends(get_db)): + if reply != 0 and reply != 1: + raise HTTPException(status_code=405, detail="Wrong parameter, reply should be either 0 or 1") + msg = await crud_reply_to_enter(user_id, group_id, reply, db) + return {"msg": msg} \ No newline at end of file diff --git a/app/api/v1/endpoints/note.py b/app/api/v1/endpoints/note.py new file mode 100644 index 0000000..56cb98c --- /dev/null +++ b/app/api/v1/endpoints/note.py @@ -0,0 +1,72 @@ +from fastapi import APIRouter, HTTPException, Depends +from sqlalchemy.ext.asyncio import AsyncSession +from app.schemas.note import NoteCreate, NoteUpdate, NoteFind +from app.utils.get_db import get_db +from app.curd.note import create_note_in_db, delete_note_in_db, update_note_in_db, find_notes_in_db, find_notes_title_in_db, find_self_notes_count_in_db, find_self_recent_notes_in_db +from typing import Optional +from app.utils.auth import get_current_user +router = APIRouter() + +@router.post("/create", response_model=dict) +async def create_note(note: NoteCreate, db: AsyncSession = Depends(get_db), current_user: dict = Depends(get_current_user)): + user_id = current_user["id"] + new_note = await create_note_in_db(note, db, user_id) + return {"msg": "Note created successfully", "note_id": 
new_note.id} + +@router.delete("/{note_id}", response_model=dict) +async def delete_note(note_id: int, db: AsyncSession = Depends(get_db), current_user: dict = Depends(get_current_user)): + user_id = current_user["id"] + note = await delete_note_in_db(note_id, user_id, db) + if not note: + raise HTTPException(status_code=404, detail="Note not found") + return {"msg": "Note deleted successfully"} + +@router.put("/{note_id}", response_model=dict) +async def update_note(note_id: int, content: Optional[str] = None, title: Optional[str] = None,db: AsyncSession = Depends(get_db)): + if content is None and title is None: + raise HTTPException(status_code=400, detail="At least one field must be provided for update") + note = NoteUpdate(id=note_id, content=content, title=title) + updated_note = await update_note_in_db(note_id, note, db) + if not updated_note: + raise HTTPException(status_code=404, detail="Note not found") + return {"msg": "Note updated successfully", "note_id": updated_note.id} + +@router.get("/get", response_model=dict) +async def get_notes(note_find: NoteFind = Depends(), db: AsyncSession = Depends(get_db)): + notes, total_count = await find_notes_in_db(note_find, db) + return { + "pagination": { + "total_count": total_count, + "page": note_find.page, + "page_size": note_find.page_size + }, + "notes": [note.model_dump() for note in notes] + } + +@router.get("/title", response_model=dict) +async def get_notes_title(note_find: NoteFind = Depends(), db: AsyncSession = Depends(get_db)): + notes, total_count = await find_notes_title_in_db(note_find, db) + return { + "pagination": { + "total_count": total_count, + "page": note_find.page, + "page_size": note_find.page_size + }, + "notes": notes + } + +@router.get("/count", response_model=dict) +async def get_notes_count(db: AsyncSession = Depends(get_db), current_user: dict = Depends(get_current_user)): + user_id = current_user["id"] + count = await find_self_notes_count_in_db(db, user_id) + return { + "count": count + } + +@router.get("/count/recent", response_model=dict) +async def get_recent_notes_count(db: AsyncSession = Depends(get_db), current_user: dict = Depends(get_current_user)): + user_id = current_user["id"] + notes = await find_self_recent_notes_in_db(db, user_id) + return { + "notes": notes + } diff --git a/app/api/v1/endpoints/user.py b/app/api/v1/endpoints/user.py new file mode 100644 index 0000000..d5330fc --- /dev/null +++ b/app/api/v1/endpoints/user.py @@ -0,0 +1,96 @@ +from fastapi import APIRouter, HTTPException, Depends, UploadFile, Form, File +from app.schemas.user import UserUpdate, PasswordUpdate +from app.curd.user import update_user_in_db, get_user_by_email, update_user_password +from sqlalchemy.ext.asyncio import AsyncSession +from app.utils.get_db import get_db +from app.utils.auth import get_current_user +from passlib.context import CryptContext +import os +from uuid import uuid4 +from typing import Optional +router = APIRouter() + +# update current user +@router.put("/update", response_model=dict) +async def update_current_user( + username: Optional[str] = Form(None), + avatar: Optional[UploadFile] = File(None), + address: Optional[str] = Form(None), + university: Optional[str] = Form(None), + introduction: Optional[str] = Form(None), + db: AsyncSession = Depends(get_db), + current_user: dict = Depends(get_current_user) +): + """ + Update the current user's information. 
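+    Fields left as None keep their current values; a newly uploaded avatar is saved under /lhcos-data/avatar and the previous non-default avatar file is removed.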
+ """ + db_user = await get_user_by_email(db, current_user["email"]) + if not db_user: + raise HTTPException(status_code=404, detail="User not found") + try: + avatar_url = None + if avatar: + avatar_file: UploadFile = avatar + file_extension = os.path.splitext(avatar_file.filename)[1] + unique_filename = f"{uuid4()}{file_extension}" + avatar_path = os.path.join("/lhcos-data/avatar", unique_filename) + + # 确保以二进制模式写入文件,避免编码问题 + with open(avatar_path, "wb") as f: + f.write(await avatar_file.read()) + + # 生成 URL 路径 + avatar_url = f"/lhcos-data/avatar/{unique_filename}" + + # 删除旧的头像文件 + if db_user.avatar and db_user.avatar != "/lhcos-data/avatar/default.png": + if os.path.exists(db_user.avatar): + os.remove(db_user.avatar) + + update_user_response = UserUpdate( + username=username, + avatar=avatar_url if avatar_url else db_user.avatar, + address=address, + university=university, + introduction=introduction + ) + await update_user_in_db(db, update_user_response, db_user.id) + return {"msg": "User updated successfully"} + + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +@router.post("/password", response_model=dict) +async def change_password( + password_update: PasswordUpdate, + db: AsyncSession = Depends(get_db), + current_user: dict = Depends(get_current_user) +): + db_user = await get_user_by_email(db, current_user["email"]) + if not db_user: + raise HTTPException(status_code=404, detail="User not found") + if not pwd_context.verify(password_update.old_password, db_user.password): + raise HTTPException(status_code=400, detail="Old password is incorrect") + + await update_user_password(db, db_user.id, pwd_context.hash(password_update.new_password)) + return {"msg": "Password changed successfully"} + +@router.get("/get", response_model=dict) +async def get_user_id( + db: AsyncSession = Depends(get_db), + current_user: dict = Depends(get_current_user) +): + db_user = await get_user_by_email(db, current_user["email"]) +# 返回用户所有信息 + return { + "id": db_user.id, + "username": db_user.username, + "email": db_user.email, + "avatar": db_user.avatar, + "address": db_user.address, + "university": db_user.university, + "introduction": db_user.introduction, + "create_time": db_user.create_time + } diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000..b80e7ba --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,23 @@ +import os +from dotenv import load_dotenv + +load_dotenv() + +class Settings: + PROJECT_NAME: str = "JieNote Backend" # 项目名称 + VERSION: str = "1.0.0" # 项目版本 + SQLALCHEMY_DATABASE_URL = "mysql+asyncmy://root:oneapi@47.93.172.156:3306/JieNote" # 替换为实际的用户名、密码和数据库名称 + SECRET_KEY: str = os.getenv("SECRET_KEY", "default_secret_key") # JWT密钥 + ALGORITHM: str = "HS256" # JWT算法 + ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # token过期时间 + REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 刷新token过期时间7天 + SMTP_SERVER: str = "smtp.163.com" # SMTP服务器 + SMTP_PORT: int = 465 # SMTP端口 + SENDER_EMAIL : str = "jienote_buaa@163.com" + SENDER_PASSWORD: str = os.getenv("SENDER_PASSWORD", "default_password") # 发件人邮箱密码 + KIMI_API_KEY: str = os.getenv("KIMI_API_KEY", "default_kimi_api_key") # KIMI API密钥 + UPLOAD_FOLDER: str = "/lhcos-data/acticleDB" + CODE_EXPIRATION_TIME: int = 300 # 验证码过期时间(秒) + + +settings = Settings() \ No newline at end of file diff --git a/app/curd/__init__.py 
b/app/curd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/curd/article.py b/app/curd/article.py new file mode 100644 index 0000000..5656475 --- /dev/null +++ b/app/curd/article.py @@ -0,0 +1,373 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, delete, insert, desc +from sqlalchemy import func, cast, Date +from datetime import datetime, timedelta +from app.models.model import User, Group, Folder, Article, Note, Tag, user_group, self_recycle_bin + +async def crud_upload_to_self_folder(name: str, folder_id: int, db: AsyncSession): + new_article = Article(name=name, folder_id=folder_id) + db.add(new_article) + await db.commit() + await db.refresh(new_article) + return new_article.id + +async def crud_get_self_folders(user_id: int, page_number: int, page_size: int, db: AsyncSession): + query = select(Folder).where(Folder.user_id == user_id, Folder.visible == True).order_by(Folder.id.desc()) + count_query = select(func.count()).select_from(query.subquery()) + count_result = await db.execute(count_query) + total_num = count_result.scalar() + + if page_number and page_size: + offset = (page_number - 1) * page_size + query = query.offset(offset).limit(page_size) + result = await db.execute(query) + folders = result.scalars().all() + + return total_num, folders + +async def crud_get_articles_in_folder(folder_id: int, page_number: int, page_size: int, db: AsyncSession): + query = select(Article).where(Article.folder_id == folder_id, Article.visible == True).order_by(Article.id.desc()) + count_query = select(func.count()).select_from(query.subquery()) + count_result = await db.execute(count_query) + total_num = count_result.scalar() + + if page_number and page_size: + offset = (page_number - 1) * page_size + query = query.offset(offset).limit(page_size) + result = await db.execute(query) + articles = result.scalars().all() + + return total_num, articles + +async def crud_self_create_folder(name: str, user_id: int, db: AsyncSession): + new_folder = Folder(name=name, user_id=user_id) + db.add(new_folder) + await db.commit() + await db.refresh(new_folder) + return new_folder.id + +async def crud_self_article_to_recycle_bin(article_id: int, user_id: int, db: AsyncSession): + # 维护 article 表 + query = select(Article).where(Article.id == article_id) + result = await db.execute(query) + article = result.scalar_one_or_none() + article.visible = False + + # 维护 self_recycle_bin 表 + recycle = insert(self_recycle_bin).values(user_id=user_id, type=2, id=article_id, name=article.name, folder_id=article.folder_id) + await db.execute(recycle) + + await db.commit() + await db.refresh(article) + +async def crud_self_folder_to_recycle_bin(folder_id: int, user_id: int, db: AsyncSession): + # 维护 folder 表 + query = select(Folder).where(Folder.id == folder_id) + result = await db.execute(query) + folder = result.scalar_one_or_none() + folder.visible = False + + # 维护 self_recycle_bin 表 + recycle = insert(self_recycle_bin).values(user_id=user_id, type=1, id=folder_id, name=folder.name) + await db.execute(recycle) + + await db.commit() + await db.refresh(folder) + +async def crud_read_article(article_id: int, db: AsyncSession): + query = select(Article).where(Article.id == article_id) + result = await db.execute(query) + article = result.scalar_one_or_none() + return article.name + +async def crud_import_self_folder(folder_name: str, article_names, user_id: int, db: AsyncSession): + result = [] + + # 新建文件夹 + new_folder = Folder(name=folder_name, user_id=user_id) + 
db.add(new_folder) + await db.commit() + await db.refresh(new_folder) + + # 新建文献 + new_articles = [Article(name=article_name, folder_id=new_folder.id) for article_name in article_names] + db.add_all(new_articles) + await db.commit() + for new_article in new_articles: + await db.refresh(new_article) + result.append(new_article.id) + result.append(new_article.name) + + return result + +async def crud_export_self_folder(folder_id: int, db: AsyncSession): + query = select(Folder).where(Folder.id == folder_id) + result = await db.execute(query) + folder = result.scalar_one_or_none() + folder_name = folder.name + + query = select(Article).where(Article.folder_id == folder_id, Article.visible == True).order_by(Article.id.desc()) + result = await db.execute(query) + articles = result.scalars().all() + article_id = [] + article_name = [] + for article in articles: + article_id.append(article.id) + article_name.append(article.name) + + return folder_name, article_id, article_name + +async def crud_create_tag(article_id: int, content: str, db: AsyncSession): + new_tag = Tag(article_id=article_id, content=content) + db.add(new_tag) + await db.commit() + await db.refresh(new_tag) + +async def crud_delete_tag(tag_id: int, db: AsyncSession): + query = select(Tag).filter(Tag.id == tag_id) + result = await db.execute(query) + tag = result.scalar_one_or_none() + await db.delete(tag) + await db.commit() + +async def crud_get_article_tags(article_id: int, db: AsyncSession): + query = select(Tag).where(Tag.article_id == article_id).order_by(Tag.id.asc()) + result = await db.execute(query) + tags = result.scalars().all() + return tags + +async def crud_all_tags_order(article_id: int, tag_contents, db: AsyncSession): + query = delete(Tag).where(Tag.article_id == article_id) + await db.execute(query) + await db.commit() + + new_tags = [] + for i in range(0, len(tag_contents)): + new_tags.append(Tag(content=tag_contents[i], article_id=article_id)) + db.add_all(new_tags) + await db.commit() + for i in range(0, len(new_tags)): + await db.refresh(new_tags[i]) + +async def crud_change_folder_name(folder_id: int, folder_name: str, db: AsyncSession): + query = select(Folder).where(Folder.id == folder_id) + result = await db.execute(query) + folder = result.scalar_one_or_none() + + folder.name = folder_name + await db.commit() + await db.refresh(folder) + +async def crud_change_article_name(article_id: int, article_name: str, db: AsyncSession): + query = select(Article).where(Article.id == article_id) + result = await db.execute(query) + article = result.scalar_one_or_none() + + article.name = article_name + await db.commit() + await db.refresh(article) + +async def crud_article_statistic(db: AsyncSession): + # 获取明天日期和7天前的日期 + tomorrow = datetime.now().date() + timedelta(days=1) + seven_days_ago = datetime.now().date() - timedelta(days=6) + + # 查询近7天内的文献数目,按日期分组 + query = ( + select( + cast(Article.create_time, Date).label("date"), # 按日期分组 + func.count(Article.id).label("count") # 统计每日期的文献数 + ) + .where( + Article.create_time >= seven_days_ago, # 大于等于7天前的0点 + Article.create_time < tomorrow # 小于明天0点 + ) + .group_by(cast(Article.create_time, Date)) # 按日期分组 + .order_by(cast(Article.create_time, Date)) # 按日期排序 + ) + + # 执行查询 + result = await db.execute(query) + data = result.fetchall() + + # 格式化结果为字典列表 + articles = [{"date": row.date, "count": row.count} for row in data] + + # 若某日期没有记录,则为0 + for i in range(0, 7): + if i == len(articles) or articles[i].get("date") != seven_days_ago + timedelta(days=i): + articles.insert(i, 
{"date": seven_days_ago + timedelta(days=i), "count": 0}) + + return articles + +async def crud_self_tree(user_id: int, page_number: int, page_size: int, db: AsyncSession): + query = select(Folder).where(Folder.user_id == user_id, Folder.visible == True).order_by(Folder.id.desc()) + count_query = select(func.count()).select_from(query.subquery()) + count_result = await db.execute(count_query) + total_num = count_result.scalar() + + if page_number and page_size: + offset = (page_number - 1) * page_size + query = query.offset(offset).limit(page_size) + result = await db.execute(query) + folders = result.scalars().all() + + folder_array = [{"folder_id": folder.id, "folder_name": folder.name, "articles": []} for folder in folders] + for i in range(len(folder_array)): + query = select(Article).where(Article.folder_id == folder_array[i].get("folder_id"), Article.visible == True).order_by(Article.id.desc()) + result = await db.execute(query) + articles = result.scalars().all() + article_array = [{"article_id": article.id, "article_name": article.name, "tags": [], "notes": []} for article in articles] + folder_array[i]["articles"] = article_array + for j in range(len(article_array)): + # 查找所有tag + query = select(Tag).where(Tag.article_id == article_array[j].get("article_id")).order_by(Tag.id.asc()) + result = await db.execute(query) + tags = result.scalars().all() + tag_array = [{"tag_id": tag.id, "tag_content": tag.content} for tag in tags] + article_array[j]["tags"] = tag_array + # 查找所有note + query = select(Note).where(Note.article_id == article_array[j].get("article_id"), Note.visible == True).order_by(Note.id.desc()) + result = await db.execute(query) + notes = result.scalars().all() + note_array = [{"note_id": note.id, "note_title": note.title} for note in notes] + article_array[j]["notes"] = note_array + + return total_num, folder_array + +async def crud_self_article_statistic(user_id: int, db: AsyncSession): + # 查询个人拥有的、未被删除的文献总数 + query = ( + select(func.count(Article.id)) + .join(Folder, Article.folder_id == Folder.id) + .where(Folder.user_id == user_id, Folder.visible == True, Article.visible == True) + ) + result = await db.execute(query) + article_total_num = result.scalar_one_or_none() + + # 获取明天日期和7天前的日期 + tomorrow = datetime.now().date() + timedelta(days=1) + seven_days_ago = datetime.now().date() - timedelta(days=6) + + # 查询近7天内的文献数目,按日期分组 + query = ( + select( + cast(Article.create_time, Date).label("date"), # 按日期分组 + func.count(Article.id).label("count") # 统计每日期的文献数 + ) + .join(Folder, Article.folder_id == Folder.id) + .where( + Folder.user_id == user_id, + Folder.visible == True, + Article.visible == True, + Article.create_time >= seven_days_ago, # 大于等于7天前的0点 + Article.create_time < tomorrow, # 小于明天0点 + ) + .group_by(cast(Article.create_time, Date)) # 按日期分组 + .order_by(cast(Article.create_time, Date)) # 按日期排序 + ) + + # 执行查询 + result = await db.execute(query) + data = result.fetchall() + + # 格式化结果为字典列表 + articles = [{"date": row.date, "count": row.count} for row in data] + + # 若某日期没有记录,则为0 + for i in range(0, 7): + if i == len(articles) or articles[i].get("date") != seven_days_ago + timedelta(days=i): + articles.insert(i, {"date": seven_days_ago + timedelta(days=i), "count": 0}) + + return article_total_num, articles + +async def crud_items_in_recycle_bin(user_id: int, page_number: int, page_size: int, db: AsyncSession): + query = select( + self_recycle_bin.c.type, + self_recycle_bin.c.id, + self_recycle_bin.c.name, + self_recycle_bin.c.create_time + 
).where(self_recycle_bin.c.user_id == user_id).order_by(desc(self_recycle_bin.c.create_time)) + + if page_number and page_size: + offset = (page_number - 1) * page_size + query = query.offset(offset).limit(page_size) + + result = await db.execute(query) + items = result.fetchall() + + return [{"type": item.type, "id": item.id, "name": item.name, "time": item.create_time.strftime("%Y-%m-%d %H:%M:%S")} for item in items] + +async def crud_delete_forever(type: int, id: int, db: AsyncSession): + query = delete(self_recycle_bin).where(self_recycle_bin.c.type == type, self_recycle_bin.c.id == id) + await db.execute(query) + if type == 1: + query = delete(Folder).where(Folder.id==id) + elif type == 2: + query = delete(Article).where(Article.id==id) + else: + query = delete(Note).where(Note.id==id) + await db.execute(query) + await db.commit() + +async def crud_recover(type: int, id: int, db: AsyncSession): + query = select(self_recycle_bin).where(self_recycle_bin.c.type == type, self_recycle_bin.c.id == id) + result = await db.execute(query) + item = result.first() + if type == 3: + # 检查上级文献存在性 + query = select(Article).where(Article.id == item.article_id) + result = await db.execute(query) + article = result.scalar_one_or_none() + article_name = article.name + article_visible = article.visible + # 检查上级文件夹存在性 + query = select(Folder).where(Folder.id == item.folder_id) + result = await db.execute(query) + folder = result.scalar_one_or_none() + folder_name = folder.name + folder_visible = folder.visible + # 若上级不存在,则给用户以提示信息,请用户先恢复相应的文件夹和文献 + if not article_visible or not folder_visible: + return {"info": "Note recovered failed, please check its upper-level node", "folder_name": folder_name, "article_name": article_name} + # 若上级存在,则正常恢复即可,在回收站表中删除该表项,并将Note表中visible改为True + query = delete(self_recycle_bin).where(self_recycle_bin.c.type == type, self_recycle_bin.c.id == id) + await db.execute(query) + query = select(Note).where(Note.id == id) + result = await db.execute(query) + note = result.scalar_one_or_none() + note.visible = True + await db.commit() + await db.refresh(note) + return {"info": "Note recovered successfully"} + if type == 2: + # 检查上级文件夹存在性 + query = select(Folder).where(Folder.id == item.folder_id) + result = await db.execute(query) + folder = result.scalar_one_or_none() + folder_name = folder.name + folder_visible = folder.visible + # 若上级不存在,则给用户以提示信息,请用户先恢复相应的文件夹 + if not folder_visible: + return {"info": "Article recovered failed, please check its upper-level node", "folder_name": folder_name} + # 若上级存在,则正常恢复即可,在回收站表中删除该表项,并将Article表中visible改为True + query = delete(self_recycle_bin).where(self_recycle_bin.c.type == type, self_recycle_bin.c.id == id) + await db.execute(query) + query = select(Article).where(Article.id == id) + result = await db.execute(query) + article = result.scalar_one_or_none() + article.visible = True + await db.commit() + await db.refresh(article) + return {"info": "Article recovered successfully"} + if type == 1: + # 正常恢复即可,在回收站表中删除该表项,并将Folder表中visible改为True + query = delete(self_recycle_bin).where(self_recycle_bin.c.type == type, self_recycle_bin.c.id == id) + await db.execute(query) + query = select(Folder).where(Folder.id == id) + result = await db.execute(query) + folder = result.scalar_one_or_none() + folder.visible = True + await db.commit() + await db.refresh(folder) + return {"info": "Folder recovered successfully"} \ No newline at end of file diff --git a/app/curd/articleDB.py b/app/curd/articleDB.py new file mode 100644 index 0000000..9af3af5 --- 
/dev/null +++ b/app/curd/articleDB.py @@ -0,0 +1,55 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import func +from app.models.model import ArticleDB +from app.schemas.articleDB import UploadArticle, GetArticle, DeLArticle, GetResponse + +async def create_article_in_db(db: AsyncSession, upload_article: UploadArticle): + """ + Create a new article in the database. + """ + article = ArticleDB(title=upload_article.title, url=upload_article.url, author=upload_article.author, file_path=upload_article.file_path) + db.add(article) + await db.commit() + await db.refresh(article) + return article + +async def get_article_in_db(db: AsyncSession, get_article: GetArticle): + + if get_article.id: + result = await db.execute(select(ArticleDB).where(ArticleDB.id == get_article.id)) + articles = result.scalars().first() + articles = [articles] if articles else [] + total_count = len(articles) + elif get_article.page and get_article.page_size: + count_result = await db.execute(select(func.count()).select_from(ArticleDB)) + total_count = count_result.scalar() # total number of articles + # paginated article query + result = await db.execute( + select(ArticleDB) + .offset((get_article.page - 1) * get_article.page_size) + .limit(get_article.page_size) + ) + articles = result.scalars().all() + else: + result = await db.execute(select(ArticleDB)) + articles = result.scalars().all() + total_count = len(articles) + + return [GetResponse.model_validate(article) for article in articles], total_count + +async def get_article_in_db_by_id(db: AsyncSession, article_id: int): + """ + Get an article by its ID. + """ + result = await db.execute(select(ArticleDB).where(ArticleDB.id == article_id)) + article = result.scalars().first() + return article + +async def get_article_info_in_db_by_id(db: AsyncSession, article_id: int): + """ + Get an article by its ID.
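+    Returns the stored file_path and title rather than the full ORM object.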
+ """ + result = await db.execute(select(ArticleDB).where(ArticleDB.id == article_id)) + article = result.scalars().first() + return article.file_path, article.title \ No newline at end of file diff --git a/app/curd/group.py b/app/curd/group.py new file mode 100644 index 0000000..4c8382d --- /dev/null +++ b/app/curd/group.py @@ -0,0 +1,58 @@ +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.exc import IntegrityError +from sqlalchemy import select, insert, delete +from app.models.model import User, Group, Folder, Article, Note, Tag, user_group, enter_application + +async def crud_create(leader: int, name: str, description: str, db: AsyncSession): + new_group = Group(leader=leader, name=name, description=description) + db.add(new_group) + await db.commit() + await db.refresh(new_group) + return new_group.id + +async def crud_apply_to_enter(user_id: int, group_id: int, db: AsyncSession): + # 是否已经在组织中 + query = select(user_group).where(user_group.c.user_id == user_id, user_group.c.group_id == group_id) + result = await db.execute(query) + existing = result.first() + if existing: + raise HTTPException(status_code=405, detail="Already in the group") + query = select(Group).where(Group.id == group_id) + result = await db.execute(query) + group = result.scalar_one_or_none() + if group.leader == user_id: + raise HTTPException(status_code=405, detail="Already in the group") + + # 插入申请表,若已存在申请则抛出异常 + query = insert(enter_application).values(user_id=user_id, group_id=group_id) + try: + await db.execute(query) + await db.commit() + except IntegrityError: + await db.rollback() + raise HTTPException(status_code=405, detail="Don't apply repeatedly") + +async def crud_get_applications(group_id: int, db: AsyncSession): + query = select(User.id, User.username).where(User.id.in_( + select(enter_application.c.user_id).where(enter_application.c.group_id == group_id) + )) + result = await db.execute(query) + users = result.all() + return [{"user_id": user.id, "user_name": user.username} for user in users] + +async def crud_reply_to_enter(user_id: int, group_id: int, reply: int, db: AsyncSession): + # 答复后,需要从待处理申请的表中删除表项 + query = delete(enter_application).where(enter_application.c.user_id == user_id, enter_application.c.group_id == group_id) + result = await db.execute(query) + if result.rowcount == 0: # 如果没有删除任何行,说明不存在该项 + raise HTTPException(status_code=405, detail="Application is not existed or already handled") + await db.commit() + + if reply == 1: + new_relation = insert(user_group).values(user_id=user_id, group_id=group_id) + await db.execute(new_relation) + await db.commit() + return "Add new member successfully" + + return "Refuse the application successfully" diff --git a/app/curd/note.py b/app/curd/note.py new file mode 100644 index 0000000..92f007b --- /dev/null +++ b/app/curd/note.py @@ -0,0 +1,180 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import func, cast, Date, insert +from datetime import datetime, timedelta +from app.models.model import Note, self_recycle_bin, Article, Folder +from app.schemas.note import NoteCreate, NoteUpdate, NoteFind, NoteResponse + +async def create_note_in_db(note: NoteCreate, db: AsyncSession, user_id: int): + new_note = Note(content=note.content, article_id=note.article_id, title=note.title, creator_id=user_id) + db.add(new_note) + await db.commit() + await db.refresh(new_note) + return new_note + +async def delete_note_in_db(note_id: int, user_id: int, 
db: AsyncSession): + stmt = select(Note).where(Note.id == note_id) + result = await db.execute(stmt) + note = result.scalar_one_or_none() + if note: + # 将 visible 设置为 False,表示删除 + note.visible = False + # 找 folder_id + stmt = select(Article).where(Article.id == note.article_id) + result = await db.execute(stmt) + article = result.scalar_one_or_none() + # 插入 self_recycle_bin 表 + recycle = insert(self_recycle_bin).values(user_id=user_id, type=3, id=note_id, name=note.title, article_id=note.article_id, folder_id=article.folder_id) + await db.execute(recycle) + await db.commit() + return note + +async def update_note_in_db(note_id: int, note: NoteUpdate, db: AsyncSession): + stmt = select(Note).where(Note.id == note_id) + result = await db.execute(stmt) + existing_note = result.scalar_one_or_none() + if existing_note: + if note.title is not None: + existing_note.title = note.title + if note.content is not None: + existing_note.content = note.content + await db.commit() + await db.refresh(existing_note) + return existing_note + +async def find_notes_in_db(note_find: NoteFind, db: AsyncSession): + stmt = select(Note).where(Note.visible == True) # 只查询可见的笔记 + + if note_find.id is not None: + stmt = stmt.where(Note.id == note_find.id) + elif note_find.article_id is not None: + stmt = stmt.where(Note.article_id == note_find.article_id) + + total_count_stmt = select(func.count()).select_from(stmt) + total_count_result = await db.execute(total_count_stmt) + total_count = total_count_result.scalar() + + if note_find.page is not None and note_find.page_size is not None: + offset = (note_find.page - 1) * note_find.page_size + stmt = stmt.offset(offset).limit(note_find.page_size) + + result = await db.execute(stmt) + notes = [NoteResponse.model_validate(note) for note in result.scalars().all()] + return notes, total_count + +async def find_notes_title_in_db(note_find: NoteFind, db: AsyncSession): + stmt = select(Note.title).where(Note.visible == True) # 只查询可见的笔记 + + if note_find.id is not None: + stmt = stmt.where(Note.id == note_find.id) + elif note_find.article_id is not None: + stmt = stmt.where(Note.article_id == note_find.article_id) + + total_count_stmt = select(func.count()).select_from(stmt.subquery()) + total_count_result = await db.execute(total_count_stmt) + total_count = total_count_result.scalar() + + if note_find.page is not None and note_find.page_size is not None: + offset = (note_find.page - 1) * note_find.page_size + stmt = stmt.offset(offset).limit(note_find.page_size) + + result = await db.execute(stmt) + notes = [row[0] for row in result.fetchall()] + return notes, total_count + +async def find_recent_notes_in_db(db: AsyncSession): + """ + 返回近7天内创建的笔记的数目和对应日期 + """ + # 获取当前日期和7天前的日期 + tomorrow = datetime.now().date() + timedelta(days=1) + seven_days_ago = datetime.now().date() - timedelta(days=6) + + # 查询近7天内的笔记数目,按日期分组 + stmt = ( + select( + cast(Note.create_time, Date).label("date"), # 按日期分组 + func.count(Note.id).label("count") # 统计每日期的笔记数 + ) + .where( + Note.create_time >= seven_days_ago, # 筛选近7天的笔记 + Note.create_time < tomorrow # 包括今天 + ) + .group_by(cast(Note.create_time, Date)) # 按日期分组 + .order_by(cast(Note.create_time, Date)) # 按日期排序 + ) + + # 执行查询 + result = await db.execute(stmt) + data = result.fetchall() + + # 格式化结果为字典列表 + recent_notes = [{"date": row.date, "count": row.count} for row in data] + + # 若某日期没有记录,则为0 + for i in range(0, 7): + if i == len(recent_notes) or recent_notes[i].get("date") != seven_days_ago + timedelta(days=i): + recent_notes.insert(i, {"date": 
seven_days_ago + timedelta(days=i), "count": 0}) + + return recent_notes + +async def find_self_recent_notes_in_db(db: AsyncSession, user_id: int): + """ + 返回近7天内创建的笔记的数目和对应日期 + """ + # 获取当前日期和7天前的日期 + tomorrow = datetime.now().date() + timedelta(days=1) + seven_days_ago = datetime.now().date() - timedelta(days=6) + + # 查询近7天内的笔记数目,按日期分组 + stmt = ( + select( + cast(Note.create_time, Date).label("date"), # 按日期分组 + func.count(Note.id).label("count") # 统计每日期的笔记数 + ) + .join(Article, Note.article_id == Article.id) + .join(Folder, Article.folder_id == Folder.id) + .where( + Note.visible == True, + Article.visible == True, + Folder.visible == True, + Note.create_time >= seven_days_ago, # 筛选近7天的笔记 + Note.create_time < tomorrow, # 包括今天 + Note.creator_id == user_id # 筛选特定用户的笔记 + ) + .group_by(cast(Note.create_time, Date)) # 按日期分组 + .order_by(cast(Note.create_time, Date)) # 按日期排序 + ) + + # 执行查询 + result = await db.execute(stmt) + data = result.fetchall() + + # 格式化结果为字典列表 + recent_notes = [{"date": row.date, "count": row.count} for row in data] + + # 若某日期没有记录,则为0 + for i in range(0, 7): + if i == len(recent_notes) or recent_notes[i].get("date") != seven_days_ago + timedelta(days=i): + recent_notes.insert(i, {"date": seven_days_ago + timedelta(days=i), "count": 0}) + + return recent_notes + +async def find_self_notes_count_in_db(db: AsyncSession, user_id: int): + """ + 返回用户的笔记数目 + """ + stmt = ( + select(func.count(Note.id)) + .join(Article, Note.article_id == Article.id) + .join(Folder, Article.folder_id == Folder.id) + .where( + Note.creator_id == user_id, + Note.visible == True, + Article.visible == True, + Folder.visible == True + ) + ) + result = await db.execute(stmt) + count = result.scalar_one_or_none() + return count \ No newline at end of file diff --git a/app/curd/user.py b/app/curd/user.py new file mode 100644 index 0000000..6545c34 --- /dev/null +++ b/app/curd/user.py @@ -0,0 +1,44 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from app.models.model import User +from app.schemas.user import UserUpdate + +async def get_user_by_email(db: AsyncSession, email: str): + stmt = select(User).where(User.email == email) + result = await db.execute(stmt) + return result.scalar_one_or_none() + +async def create_user(db: AsyncSession, email: str, username: str, hashed_password: str): + new_user = User(email=email, username=username, password=hashed_password, avatar="/static/avatar/default.png") + db.add(new_user) + await db.commit() + await db.refresh(new_user) + return new_user + +async def update_user_in_db(db: AsyncSession, user_update: UserUpdate, user_id: int): + stmt = select(User).where(User.id == user_id) + result = await db.execute(stmt) + user = result.scalar_one_or_none() + if user: + if user_update.username: + user.username = user_update.username + if user_update.address: + user.address = user_update.address + if user_update.university: + user.university = user_update.university + if user_update.introduction: + user.introduction = user_update.introduction + user.avatar = user_update.avatar + await db.commit() + await db.refresh(user) + return user + +async def update_user_password(db: AsyncSession, user_id: int, hashed_password: str): + stmt = select(User).where(User.id == user_id) + result = await db.execute(stmt) + user = result.scalar_one_or_none() + if user: + user.password = hashed_password + await db.commit() + await db.refresh(user) + return user diff --git a/app/db/base.py b/app/db/base.py new file mode 100644 index 0000000..d885a92 --- 
/dev/null +++ b/app/db/base.py @@ -0,0 +1,2 @@ +from app.db.base_class import Base # 导入创建的base类 +from app.models.model import User \ No newline at end of file diff --git a/app/db/base_class.py b/app/db/base_class.py new file mode 100644 index 0000000..d692f60 --- /dev/null +++ b/app/db/base_class.py @@ -0,0 +1,13 @@ +from typing import Any + +from sqlalchemy.ext.declarative import as_declarative, declared_attr + + +@as_declarative() +class Base: + id: Any + __name__: str + # Generate __tablename__ automatically + @declared_attr + def __tablename__(cls) -> str: + return cls.__name__.lower() \ No newline at end of file diff --git a/app/db/session.py b/app/db/session.py new file mode 100644 index 0000000..68701ee --- /dev/null +++ b/app/db/session.py @@ -0,0 +1,19 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession + +from app.core.config import settings + +engine = create_engine(settings.SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) #连接mysql +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 创建异步引擎 +async_engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) + +# 创建异步会话 +async_session = sessionmaker( + bind=async_engine, + class_=AsyncSession, + autocommit=False, + autoflush=False +) \ No newline at end of file diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..e4e1c3c --- /dev/null +++ b/app/main.py @@ -0,0 +1,44 @@ +from fastapi import FastAPI, Request +from app.routers.router import include_routers +from fastapi_pagination import add_pagination +from loguru import logger +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + +app = FastAPI() + +@app.get("/") +def read_root(): + return {"Hello": "World"} + +@app.get("/items/{item_id}") +def read_item(item_id: int, q: str = None): + return {"item_id": item_id, "q": q} + +# 注册路由 +include_routers(app) + +# 注册分页功能 +add_pagination(app) + +# 设置日志配置 +logger.add("app.log", rotation="1 MB", retention="7 days", level="INFO") + +@app.middleware("http") +async def log_requests(request: Request, call_next): + logger.info(f"Request: {request.method} {request.url}") + response = await call_next(request) + logger.info(f"Response status: {response.status_code}") + return response + +# 配置 CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 允许的前端来源 + allow_credentials=True, # 允许发送凭据(如 Cookies 或 Authorization 头) + allow_methods=["*"], # 允许的 HTTP 方法 + allow_headers=["*"], # 允许的请求头 +) + +# 挂载静态文件目录 +app.mount("/static", StaticFiles(directory="app/static"), name="static") \ No newline at end of file diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000..026426c --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1 @@ +from .model import User \ No newline at end of file diff --git a/app/models/base.py b/app/models/base.py new file mode 100644 index 0000000..860e542 --- /dev/null +++ b/app/models/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() diff --git a/app/models/model.py b/app/models/model.py new file mode 100644 index 0000000..8c89021 --- /dev/null +++ b/app/models/model.py @@ -0,0 +1,134 @@ +from sqlalchemy import Column, Integer, String, Boolean, Table, ForeignKey, UniqueConstraint, CheckConstraint, Text, DateTime +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func +from app.db.base_class 
import Base + +# 多对多关系表 +user_group = Table( + 'user_group', Base.metadata, + Column('user_id', Integer, ForeignKey('users.id'), primary_key=True), + Column('group_id', Integer, ForeignKey('groups.id'), primary_key=True), + Column('is_admin', Boolean, default=False) +) + +enter_application = Table( + 'enter_application', Base.metadata, + Column('user_id', Integer, ForeignKey('users.id'), primary_key=True), + Column('group_id', Integer, ForeignKey('groups.id'), primary_key=True), +) + +self_recycle_bin = Table( + 'self_recycle_bin', Base.metadata, + Column('user_id', Integer, ForeignKey('users.id')), + Column('type', Integer, primary_key=True), # 1: folder 2: article 3: note + Column('id', Integer, primary_key=True), + Column('name', Text, nullable=False), # 回收站显示 + Column('create_time', DateTime, default=func.now(), nullable=False), # 加入回收站的时间 + Column('article_id', Integer, ForeignKey('articles.id', ondelete="CASCADE")), + Column('folder_id', Integer, ForeignKey('folders.id', ondelete="CASCADE")) + # 最后两列为有上级时的上级节点信息,用于恢复时检查是否有上级节点在回收站中,和彻底删除时的级联删除 +) + +class User(Base): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + email = Column(String(30), unique=True, index=True, nullable=False) + username = Column(String(30), index=True, nullable=False) + password = Column(String(60), nullable=False) + avatar = Column(String(100)) + address = Column(String(100)) + university = Column(String(100)) + introduction = Column(Text) + create_time = Column(DateTime, default=func.now(), nullable=False) # 创建时间 + groups = relationship('Group', secondary=user_group, back_populates='users') + folders = relationship('Folder', back_populates='user') + +class Group(Base): + __tablename__ = 'groups' + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + leader = Column(Integer) + name = Column(String(30), nullable=False) + description = Column(String(200), nullable=False) + create_time = Column(DateTime, default=func.now(), nullable=False) # 创建时间 + update_time = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) # 更新时间 + users = relationship('User', secondary=user_group, back_populates='groups') + folders = relationship('Folder', back_populates='group') + +class Folder(Base): + __tablename__ = 'folders' + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + name = Column(String(30), nullable=False) + + user_id = Column(Integer, ForeignKey('users.id')) + group_id = Column(Integer, ForeignKey('groups.id')) + + create_time = Column(DateTime, default=func.now(), nullable=False) # 创建时间 + update_time = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) # 更新时间 + + visible = Column(Boolean, default=True, nullable=False) # 是否可见 False表示在回收站中 + + # 关系定义 + user = relationship('User', back_populates='folders') + group = relationship('Group', back_populates='folders') + articles = relationship('Article', back_populates='folder', cascade="all, delete-orphan") + + __table_args__ = ( + # 不能同时为空 + UniqueConstraint('user_id', 'group_id', name='uq_user_group_folder'), # SQL中认为null 和 null 不相等 + CheckConstraint('user_id IS NOT NULL OR group_id IS NOT NULL', name='check_user_or_group'), + ) + +class Article(Base): + __tablename__ = 'articles' + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + name = Column(Text, nullable=False) + folder_id = Column(Integer, ForeignKey('folders.id', ondelete="CASCADE")) + create_time = Column(DateTime, default=func.now(), 
nullable=False) # 创建时间 + update_time = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) # 更新时间 + + visible = Column(Boolean, default=True, nullable=False) # 是否可见 False表示在回收站中 + + folder = relationship('Folder', back_populates='articles', lazy='selectin') + notes = relationship('Note', back_populates='article', cascade="all, delete-orphan") + tags = relationship('Tag', back_populates='article') + +class Note(Base): + __tablename__ = 'notes' + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + title = Column(String(100), nullable=False) + content = Column(Text) # 将 content 字段类型改为 Text,以支持存储大量文本 + article_id = Column(Integer, ForeignKey('articles.id', ondelete="CASCADE")) + create_time = Column(DateTime, default=func.now(), nullable=False) # 创建时间 + update_time = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) # 更新时间 + creator_id = Column(Integer, ForeignKey('users.id')) # 创建者ID + visible = Column(Boolean, default=True, nullable=False) # 是否可见 False表示在回收站中 + + article = relationship('Article', back_populates='notes') + +class Tag(Base): + __tablename__ = 'tags' + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + content = Column(String(30)) + article_id = Column(Integer, ForeignKey('articles.id')) + create_time = Column(DateTime, default=func.now(), nullable=False) # 创建时间 + update_time = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) # 更新时间 + article = relationship('Article', back_populates='tags') + +class ArticleDB(Base): + __tablename__ = 'articleDB' + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + + title = Column(String(200), nullable=False) + url = Column(String(200), nullable=False) + author = Column(String(300), nullable=False) + file_path = Column(String(200), nullable=False) + + create_time = Column(DateTime, default=func.now(), nullable=False) # 创建时间 + update_time = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) # 更新时间 \ No newline at end of file diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/routers/router.py b/app/routers/router.py new file mode 100644 index 0000000..eee09a0 --- /dev/null +++ b/app/routers/router.py @@ -0,0 +1,39 @@ +from fastapi import Depends +from app.utils.auth import get_current_user +from app.api.v1.endpoints.auth import router as auth_router +from app.api.v1.endpoints.note import router as note_router +from app.api.v1.endpoints.user import router as user_router +from app.api.v1.endpoints.aichat import router as aichat_router +from app.api.v1.endpoints.article import router as article_router +from app.api.v1.endpoints.articleDB import router as articleDB_router +from app.api.v1.endpoints.group import router as group_router + +def include_auth_router(app): + app.include_router(auth_router, prefix="/public", tags=["auth"]) + +def include_note_router(app): + app.include_router(note_router, prefix="/notes", tags=["note"], dependencies=[Depends(get_current_user)]) + +def include_user_router(app): + app.include_router(user_router, prefix="/user", tags=["user"], dependencies=[Depends(get_current_user)]) + +def include_aichat_router(app): + app.include_router(aichat_router, prefix="/chat", tags=["aichat"], dependencies=[Depends(get_current_user)]) + +def include_article_router(app): + app.include_router(article_router, prefix="/article", tags=["article"], dependencies=[Depends(get_current_user)]) + +def 
include_articleDB_router(app): + app.include_router(articleDB_router, prefix="/database", tags=["articleDB"], dependencies=[Depends(get_current_user)]) + +def include_group_router(app): + app.include_router(group_router, prefix="/group", tags=["group"], dependencies=[Depends(get_current_user)]) + +def include_routers(app): + include_auth_router(app) + include_note_router(app) + include_user_router(app) + include_aichat_router(app) + include_article_router(app) + include_articleDB_router(app) + include_group_router(app) \ No newline at end of file diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/schemas/aichat.py b/app/schemas/aichat.py new file mode 100644 index 0000000..c6045c7 --- /dev/null +++ b/app/schemas/aichat.py @@ -0,0 +1,4 @@ +from pydantic import BaseModel + +class NoteInput(BaseModel): + input: str \ No newline at end of file diff --git a/app/schemas/article.py b/app/schemas/article.py new file mode 100644 index 0000000..3586f25 --- /dev/null +++ b/app/schemas/article.py @@ -0,0 +1,4 @@ +from pydantic import BaseModel + +class SelfCreateFolder(BaseModel): + folder_name: str \ No newline at end of file diff --git a/app/schemas/articleDB.py b/app/schemas/articleDB.py new file mode 100644 index 0000000..688188e --- /dev/null +++ b/app/schemas/articleDB.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel +from datetime import datetime + +class UploadArticle(BaseModel): + title: str + author: str + url: str + file_path: str + +class GetArticle(BaseModel): + id: int | None = None + page: int | None = None + page_size: int | None = None + +class DeLArticle(BaseModel): + id: int + +class GetResponse(BaseModel): + id: int + title: str + url: str + create_time: datetime + update_time: datetime + file_path: str + + class Config: + from_attributes = True \ No newline at end of file diff --git a/app/schemas/auth.py b/app/schemas/auth.py new file mode 100644 index 0000000..e0826cb --- /dev/null +++ b/app/schemas/auth.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel, EmailStr + +class UserCreate(BaseModel): + email: EmailStr + username: str + password: str + code: str + +class UserLogin(BaseModel): + email: EmailStr + password: str + +class UserSendCode(BaseModel): + email: EmailStr + +class ReFreshToken(BaseModel): + refresh_token: str \ No newline at end of file diff --git a/app/schemas/group.py b/app/schemas/group.py new file mode 100644 index 0000000..dd69fe7 --- /dev/null +++ b/app/schemas/group.py @@ -0,0 +1,4 @@ +from pydantic import BaseModel + +class ApplyToEnter(BaseModel): + group_id: int \ No newline at end of file diff --git a/app/schemas/note.py b/app/schemas/note.py new file mode 100644 index 0000000..4800eb9 --- /dev/null +++ b/app/schemas/note.py @@ -0,0 +1,32 @@ +from datetime import datetime +from pydantic import BaseModel + +class NoteCreate(BaseModel): + article_id: int + content: str + title: str + +class NoteDelete(BaseModel): + id: int + +class NoteUpdate(BaseModel): + id: int + content: str | None = None + title: str | None = None + +class NoteFind(BaseModel): + id: int | None = None + article_id: int | None = None + page: int | None = None + page_size: int | None = None + +class NoteResponse(BaseModel): + id: int + title: str + content: str + article_id: int + create_time: datetime + update_time: datetime + + class Config: + from_attributes = True \ No newline at end of file diff --git a/app/schemas/user.py b/app/schemas/user.py new file mode 100644 index 0000000..c41a9b1 --- /dev/null +++ 
b/app/schemas/user.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + +class UserUpdate(BaseModel): + username: str | None = None + avatar: str | None = None + address: str | None = None + university: str | None = None + introduction: str | None = None + +class PasswordUpdate(BaseModel): + old_password: str + new_password: str + diff --git a/app/static/__init__.py b/app/static/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/static/avatar/06914b55-d613-4f7d-854d-4442c1d7782e.png b/app/static/avatar/06914b55-d613-4f7d-854d-4442c1d7782e.png new file mode 100644 index 0000000..639c438 Binary files /dev/null and b/app/static/avatar/06914b55-d613-4f7d-854d-4442c1d7782e.png differ diff --git a/app/static/avatar/default.png b/app/static/avatar/default.png new file mode 100644 index 0000000..232972a Binary files /dev/null and b/app/static/avatar/default.png differ diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/utils/aichat.py b/app/utils/aichat.py new file mode 100644 index 0000000..c3e7f0c --- /dev/null +++ b/app/utils/aichat.py @@ -0,0 +1,24 @@ +from openai import AsyncOpenAI +from app.core.config import settings + +client = AsyncOpenAI( + api_key=settings.KIMI_API_KEY, + base_url="https://api.moonshot.cn/v1", +) + +async def kimi_chat_stream(messages, model="moonshot-v1-8k", temperature=0.3): + """ + 异步AI流式对话工具方法,传入消息列表,流式返回AI回复内容。 + :param messages: List[dict] + :yield: str,AI回复内容片段 + """ + stream = await client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + stream=True + ) + async for chunk in stream: + content = getattr(chunk.choices[0].delta, "content", None) + if content: + yield content \ No newline at end of file diff --git a/app/utils/auth.py b/app/utils/auth.py new file mode 100644 index 0000000..b79818e --- /dev/null +++ b/app/utils/auth.py @@ -0,0 +1,22 @@ +from fastapi.security import OAuth2PasswordBearer +from fastapi import Depends, HTTPException +from jose import JWTError, jwt # 用 jose 替代 jwt +from app.core.config import settings + +# 配置 OAuth2PasswordBearer +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") + +async def get_current_user(token: str = Depends(oauth2_scheme)): + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + email: str = payload.get("sub") + user_id: int = payload.get("id") + if email is None or user_id is None: + raise HTTPException( + status_code=401, detail="Invalid authentication credentials" + ) + return {"email": email, "id": user_id} + except JWTError: + raise HTTPException( + status_code=401, detail="Invalid authentication credentials" + ) \ No newline at end of file diff --git a/app/utils/get_db.py b/app/utils/get_db.py new file mode 100644 index 0000000..2165389 --- /dev/null +++ b/app/utils/get_db.py @@ -0,0 +1,5 @@ +from app.db.session import async_session + +async def get_db(): + async with async_session() as db: + yield db \ No newline at end of file diff --git a/app/utils/middleware.py b/app/utils/middleware.py new file mode 100644 index 0000000..e69de29 diff --git a/app/utils/ocr.py b/app/utils/ocr.py new file mode 100644 index 0000000..fee5e19 --- /dev/null +++ b/app/utils/ocr.py @@ -0,0 +1,38 @@ +from paddleocr import PaddleOCR +from pdf2image import convert_from_path +import numpy as np + +def pdf_to_text(pdf_path): + """ + 使用 PaddleOCR 将 PDF 文件转换为文字。 + + :param pdf_path: PDF 文件路径 + :param output_dir: 可选,保存中间图像文件的目录(如果需要) + :return: 提取的文字内容 + """ + # 
初始化 PaddleOCR + ocr = PaddleOCR(use_angle_cls=True, lang='ch') # 支持中文 + + # 将 PDF 转换为图像 + images = convert_from_path(pdf_path) + + extracted_text = [] + + for i, image in enumerate(images): # 解包 enumerate 返回的元组 + # 将 PIL 图像转换为 OCR 可处理的格式 + image_np = np.array(image) + + # 使用 PaddleOCR 进行文字识别 + result = ocr.ocr(image_np, cls=True) + + # 提取文字部分 + for line in result[0]: + extracted_text.append(line[1][0]) # line[1][0] 是识别的文字 + + return "\n".join(extracted_text) + + +if __name__ == "__main__": + pdf_path = "example.pdf" + text = pdf_to_text(pdf_path) + print(text) \ No newline at end of file diff --git a/app/utils/redis.py b/app/utils/redis.py new file mode 100644 index 0000000..17c9e7c --- /dev/null +++ b/app/utils/redis.py @@ -0,0 +1,30 @@ +import redis +import time +import os + +redis_client = None # 全局 Redis 客户端变量 + +def get_redis_client(): + """ + 初始化并返回 Redis 客户端。 + 如果 Redis 客户端已存在,则直接返回。 + """ + global redis_client + if redis_client is None: + while True: + try: + print("Connecting to Redis...") + redis_client = redis.StrictRedis( + host=os.getenv("REDIS_HOST", "localhost"), + port=int(os.getenv("REDIS_PORT", 6379)), + password=os.getenv("REDIS_PASSWORD", None), + db=0, + decode_responses=True + ) + redis_client.ping() + print("Connected to Redis successfully.") + break + except redis.ConnectionError: + print("Redis connection failed, retrying...") + time.sleep(1) + return redis_client diff --git a/img/er_diagram.jpg b/img/er_diagram.jpg new file mode 100644 index 0000000..fd353ee Binary files /dev/null and b/img/er_diagram.jpg differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..91c416d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,98 @@ +aiosmtplib==4.0.0 +albucore==0.0.23 +albumentations==2.0.5 +alembic==1.15.2 +annotated-types==0.7.0 +anyio==4.9.0 +astor==0.8.1 +async-timeout==5.0.1 +asyncmy==0.2.10 +bcrypt==4.3.0 +beautifulsoup4==4.13.4 +certifi==2025.1.31 +cffi==1.17.1 +charset-normalizer==3.4.1 +click==8.1.8 +colorama==0.4.6 +cryptography==44.0.2 +Cython==3.0.12 +decorator==5.2.1 +distro==1.9.0 +dnspython==2.7.0 +dotenv==0.9.9 +ecdsa==0.19.1 +email_validator==2.2.0 +fastapi==0.115.12 +fastapi-pagination==0.13.0 +fire==0.7.0 +fonttools==4.57.0 +greenlet==3.1.1 +h11==0.14.0 +httpcore==1.0.8 +httpx==0.28.1 +idna==3.10 +imageio==2.37.0 +iniconfig==2.1.0 +jiter==0.9.0 +jwt==1.3.1 +lazy_loader==0.4 +lmdb==1.6.2 +loguru==0.7.3 +lxml==5.3.2 +Mako==1.3.9 +MarkupSafe==3.0.2 +networkx==3.4.2 +numpy==2.2.4 +openai==1.75.0 +opencv-contrib-python==4.11.0.86 +opencv-python==4.11.0.86 +opencv-python-headless==4.11.0.86 +opt-einsum==3.3.0 +packaging==25.0 +paddleocr==2.10.0 +paddlepaddle==3.0.0 +pandas==2.2.3 +passlib==1.7.4 +pdf2image==1.17.0 +pillow==11.2.1 +pluggy==1.5.0 +protobuf==6.30.2 +pyasn1==0.4.8 +pyclipper==1.3.0.post6 +pycparser==2.22 +pydantic==2.11.2 +pydantic_core==2.33.1 +PyJWT==2.10.1 +PyMySQL==1.1.1 +pytest==8.3.5 +python-dateutil==2.9.0.post0 +python-docx==1.1.2 +python-dotenv==1.1.0 +python-jose==3.4.0 +python-multipart==0.0.20 +pytz==2025.2 +PyYAML==6.0.2 +RapidFuzz==3.13.0 +redis==5.2.1 +requests==2.32.3 +rsa==4.9.1 +scikit-image==0.25.2 +scipy==1.15.2 +setuptools==79.0.0 +shapely==2.1.0 +simsimd==6.2.1 +six==1.17.0 +sniffio==1.3.1 +soupsieve==2.7 +SQLAlchemy==2.0.40 +starlette==0.46.1 +stringzilla==3.12.5 +termcolor==3.0.1 +tifffile==2025.3.30 +tqdm==4.67.1 +typing-inspection==0.4.0 +typing_extensions==4.13.1 +tzdata==2025.2 +urllib3==2.4.0 +uvicorn==0.34.0 +win32_setctime==1.2.0 diff --git a/tests/__init__.py b/tests/__init__.py new 
file mode 100644 index 0000000..e69de29
diff --git a/tests/in/test.pdf b/tests/in/test.pdf
new file mode 100644
index 0000000..d19e7c6
Binary files /dev/null and b/tests/in/test.pdf differ
diff --git a/tests/test_article.py b/tests/test_article.py
new file mode 100644
index 0000000..4027add
--- /dev/null
+++ b/tests/test_article.py
@@ -0,0 +1,79 @@
+from fastapi.testclient import TestClient
+import sys
+import os
+
+# Locate the project root directory
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+# Add the project root to sys.path so the app package is importable
+sys.path.insert(0, project_root)
+
+print(sys.path)
+
+from app.main import app
+from app.db.session import SessionLocal
+
+client = TestClient(app)
+# Global request headers shared by the tests in this module
+Headers = {}
+
+def setup_headers():
+    global Headers
+    # Log in to obtain an access token
+    login_response = client.post("/public/login", json={
+        "email": "22371147@buaa.edu.cn",
+        "password": "123456"
+    })
+    assert login_response.status_code == 200
+    token = login_response.json().get('access_token')
+    if token is None:
+        raise ValueError("Failed to get access_token from login response")
+    Headers = {"Authorization": f"Bearer {token}"}
+
+
+def test_article_case1():
+    db = SessionLocal()
+    db.begin()
+    try:
+        global Headers
+        # Make sure the auth headers have been set
+        if not Headers:
+            setup_headers()
+
+        # Create a folder
+        create_folder_response = client.post("/article/selfCreateFolder", headers=Headers, json={
+            "folder_name": "测试文件夹1"
+        })
+        assert create_folder_response.status_code == 200
+
+        # Fetch the folder list
+        get_folder_list_response = client.get("/article/getSelfFolders", headers=Headers)
+        assert get_folder_list_response.status_code == 200
+        folders = get_folder_list_response.json().get('result')
+        folder_id = None
+        for folder in folders:
+            if folder.get('folder_name') == '测试文件夹1':
+                folder_id = folder.get('folder_id')
+                break
+
+        assert folder_id is not None
+
+        # Fetch the list of articles in that folder
+        get_file_list_response = client.get("/article/getArticlesInFolder", headers=Headers,
+                                            params={"folder_id": folder_id})
+        assert get_file_list_response.status_code == 200
+        files = get_file_list_response.json().get('result')
+        if len(files) == 0:
+            # Upload a file to that folder; pass an open file object, not a path string
+            with open(os.path.join(project_root, "tests/in/test.pdf"), "rb") as pdf_file:
+                upload_file_response = client.post("/article/uploadToSelfFolder", params={"folder_id": folder_id},
+                                                   headers=Headers, files={"article": pdf_file})
+            assert upload_file_response.status_code == 200
+    except Exception as e:
+        db.rollback()
+        raise e
+    finally:
+        db.close()
+
+
diff --git a/tests/test_note.py b/tests/test_note.py
new file mode 100644
index 0000000..b07b366
--- /dev/null
+++ b/tests/test_note.py
@@ -0,0 +1,116 @@
+from fastapi.testclient import TestClient
+import sys
+import os
+
+# Locate the project root directory
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+# Add the project root to sys.path so the app package is importable
+sys.path.insert(0, project_root)
+
+print(sys.path)
+
+from app.main import app
+from app.db.session import SessionLocal
+
+client = TestClient(app)
+# Global request headers shared by the tests in this module
+Headers = {}
+
+
+def setup_headers():
+    global Headers
+    # Log in to obtain an access token
+    login_response = client.post("/public/login", json={
+        "email": "22371147@buaa.edu.cn",
+        "password": "123456"
+    })
+    assert login_response.status_code == 200
+    token = login_response.json().get('access_token')
+    if token is None:
+        raise ValueError("Failed to get access_token from login response")
+    Headers = {"Authorization": f"Bearer {token}"}
+
+
+# Create a note, then delete it
+def test_note_case1():
+    db = SessionLocal()
+    db.begin()
+    try:
+        global Headers
+        # Make sure the auth headers have been set
+        if not Headers:
+            setup_headers()
+
+        # Create a note
+        create_response = client.post("/notes", json={
+            "article_id": 1,
+            "content": "\n12\n",
+            "title": "test"
+        }, headers=Headers)
+        assert create_response.status_code == 200
+        note_id = create_response.json().get('note_id')
+        if note_id is None:
+            raise ValueError("Failed to get note_id from create note response")
+
+        # Delete the note
+        delete_response = client.delete(f"/notes/{note_id}", headers=Headers)
+        assert delete_response.status_code == 200
+
+        # Delete the same (now nonexistent) note again
+        double_delete_response = client.delete(f"/notes/{note_id}", headers=Headers)
+        assert double_delete_response.status_code == 200
+    except Exception as e:
+        db.rollback()
+        raise e
+    finally:
+        db.close()
+
+
+def test_note_case2():
+    db = SessionLocal()
+    db.begin()
+    try:
+        global Headers
+        if not Headers:
+            setup_headers()
+
+        # Create a note
+        create_response = client.post("/notes", json={
+            "article_id": 1,
+            "content": "\n12\n",
+            "title": "test"
+        }, headers=Headers)
+        assert create_response.status_code == 200
+        note_id = create_response.json().get('note_id')
+        if note_id is None:
+            raise ValueError("Failed to get note_id from create note response")
+
+        print(note_id)
+        # Update the note
+        update_response = client.put(f"/notes/{note_id}", params={
+            "content": "\n123\n",
+            "title": "test2"
+        }, headers=Headers)
+        print(update_response.json())
+        assert update_response.status_code == 200
+
+        # Fetch the note
+        get_response = client.get("/notes", params={
+            "id": note_id
+        }, headers=Headers)
+        assert get_response.status_code == 200
+        assert get_response.json().get('notes')[0].get('id') == note_id
+        assert get_response.json().get('notes')[0].get('content') == "\n123\n"
+
+        # Delete the note
+        delete_response = client.delete(f"/notes/{note_id}", headers=Headers)
+        assert delete_response.status_code == 200
+    except Exception as e:
+        db.rollback()
+        raise e
+    finally:
+        db.close()
diff --git a/tests/test_user.py b/tests/test_user.py
new file mode 100644
index 0000000..6c580b0
--- /dev/null
+++ b/tests/test_user.py
@@ -0,0 +1,120 @@
+from fastapi.testclient import TestClient
+import sys
+import os
+
+# Locate the project root directory
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+# Add the project root to sys.path so the app package is importable
+sys.path.insert(0, project_root)
+
+print(sys.path)
+
+from app.main import app
+from app.db.session import SessionLocal
+
+client = TestClient(app)
+# Global request headers shared by the tests in this module
+Headers = {}
+
+
+def setup_headers():
+    global Headers
+    # Log in to obtain an access token
+    login_response = client.post("/public/login", json={
+        "email": "22371147@buaa.edu.cn",
+        "password": "123456"
+    })
+    assert login_response.status_code == 200
+    token = login_response.json().get('access_token')
+    if token is None:
+        raise ValueError("Failed to get access_token from login response")
+    Headers = {"Authorization": f"Bearer {token}"}
+
+
+def test_user_case1():
+    db = SessionLocal()
+    db.begin()
+    try:
+        global Headers
+        if not Headers:
+            setup_headers()
+        # Save the current profile, then modify it
+        save_response = client.get("/user", headers=Headers)
+        assert save_response.status_code == 200
+        save_data = save_response.json()
+
+        advise_response = client.put("/user", data={"username": "李国庆test",
+                                                    "address": "北京市海淀区中关村",
+                                                    "university": "北京大学",
+                                                    }, headers=Headers)
+        assert advise_response.status_code == 200
+
+        # Fetch the profile again and check the changes
+        adviser_response = client.get("/user", headers=Headers)
+        assert adviser_response.status_code == 200
+        assert adviser_response.json().get('address') == "北京市海淀区中关村"
+        assert adviser_response.json().get('university') == "北京大学"
+        assert adviser_response.json().get('username') == "李国庆test"
+
+        # Restore the original profile
+        restore_response = client.put("/user", data={
+            "username": save_data.get('username'),
+            "address": save_data.get('address'),
+            "university": save_data.get('university')
+        }, headers=Headers)
+        assert restore_response.status_code == 200
+    except Exception as e:
+        db.rollback()
+        raise e
+    finally:
+        db.close()
+
+def test_user_case2():
+    db = SessionLocal()
+    db.begin()
+    try:
+        global Headers
+        if not Headers:
+            setup_headers()
+
+        # Change the password
+        password_response = client.post("/user/password", json={
+            "old_password": "123456",
+            "new_password": "654321"
+        }, headers=Headers)
+        assert password_response.status_code == 200
+        # Verify that the password was changed
+        login_response = client.post("/public/login", json={
+            "email": "22371147@buaa.edu.cn",
+            "password": "123456"
+        })
+        assert login_response.status_code == 401
+        login_response = client.post("/public/login", json={
+            "email": "22371147@buaa.edu.cn",
+            "password": "654321"
+        })
+        assert login_response.status_code == 200
+
+        token = login_response.json().get('access_token')
+        if token is None:
+            raise ValueError("Failed to get access_token from login response")
+        Headers = {"Authorization": f"Bearer {token}"}
+        # Restore the original password
+        restore_response = client.post("/user/password", json={
+            "old_password": "654321",
+            "new_password": "123456"
+        }, headers=Headers)
+        assert restore_response.status_code == 200
+        # Verify that the password was restored
+        login_response = client.post("/public/login", json={
+            "email": "22371147@buaa.edu.cn",
+            "password": "123456"
+        })
+        assert login_response.status_code == 200
+    except Exception as e:
+        db.rollback()
+        raise e
+    finally:
+        db.close()