Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions backend/app/generator/service/gen_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ async def generate(self, *, pk: int) -> None:
if not init_filepath.exists():
async with aiofiles.open(init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
if 'api' in str(code_folder):
# api __init__.py
api_init_filepath = code_folder.parent.joinpath('__init__.py')
if not api_init_filepath.exists():
async with aiofiles.open(api_init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
# app __init__.py
app_init_filepath = api_init_filepath.parent.joinpath('__init__.py')
if not app_init_filepath:
async with aiofiles.open(app_init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
# 写入代码文件呢
async with aiofiles.open(code_filepath, 'w', encoding='utf-8') as f:
await f.write(code)
Expand Down Expand Up @@ -161,6 +172,10 @@ async def download(self, *, pk: int) -> io.BytesIO:
f'from backend.app.{business.app_name}.model.{business.table_name_en} '
f'import {to_pascal(business.table_name_en)}\n',
)
if 'api' in new_code_path:
# api __init__.py
api_init_filepath = os.path.join(*new_code_path.split('/')[:-2], '__init__.py')
zf.writestr(api_init_filepath, gen_template.init_content)
zf.close()
bio.seek(0)
return bio
Expand Down
3 changes: 1 addition & 2 deletions backend/templates/py/api.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# -*- coding: utf-8 -*-
from typing import Annotated

from fastapi import APIRouter, Depends, Path, Query

from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Get{{ schema_name }}ListDetails, Update{{ schema_name }}Param
from backend.app.{{ app_name }}.service.{{ table_name_en }}_service import {{ table_name_en }}_service
from backend.common.pagination import DependsPagination, paging_data
Expand All @@ -12,6 +10,7 @@ from backend.common.security.jwt import DependsJwtAuth
from backend.common.security.permission import RequestPermission
from backend.common.security.rbac import DependsRBAC
from backend.database.db_mysql import CurrentSession
from fastapi import APIRouter, Depends, Path, Query

router = APIRouter()

Expand Down
15 changes: 11 additions & 4 deletions backend/templates/py/crud.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
# -*- coding: utf-8 -*-
from typing import Sequence

from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus

from backend.app.{{ app_name }}.model import {{ table_name_class }}
from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus


class CRUD{{ table_name_class }}(CRUDPlus[{{ schema_name }}]):
Expand All @@ -21,6 +20,14 @@ class CRUD{{ table_name_class }}(CRUDPlus[{{ schema_name }}]):
"""
return await self.select_model(db, pk)

async def get_list(self) -> Select:
"""
获取 {{ schema_name }} 列表

:return:
"""
return await self.select_order('created_time', 'desc')

async def get_all(self, db: AsyncSession) -> Sequence[{{ table_name_class }}]:
"""
获取所有 {{ schema_name }}
Expand Down
6 changes: 3 additions & 3 deletions backend/templates/py/model.jinja
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from uuid import UUID

import sqlalchemy as sa
from sqlalchemy.dialects import mysql

from sqlalchemy.orm import Mapped, mapped_column

from backend.common.model import {% if have_datetime_column %}Base{% else %}MappedBase{% endif %}, id_key
from sqlalchemy.dialects import mysql
from sqlalchemy.orm import Mapped, mapped_column


class {{ table_name_class }}({% if have_datetime_column %}Base{% else %}MappedBase{% endif %}):
Expand Down
3 changes: 1 addition & 2 deletions backend/templates/py/schema.jinja
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
{% if have_datetime_column %}
from datetime import datetime
{% endif %}

from pydantic import ConfigDict

from backend.common.schema import SchemaBase
Expand Down
8 changes: 5 additions & 3 deletions backend/templates/py/service.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from backend.app.{{ app_name }}.model import {{ table_name_class }}
from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param
from backend.common.exception import errors
from backend.database.db_mysql import async_db_session
from sqlalchemy import Select


class {{ table_name_class }}Service:
Expand All @@ -18,6 +19,10 @@ class {{ table_name_class }}Service:
raise errors.NotFoundError(msg='{{ table_simple_name_zh }}不存在')
return {{ table_name_en }}

@staticmethod
async def get_select() -> Select:
return await {{ table_name_en }}_dao.get_list()

@staticmethod
async def get_all() -> Sequence[{{ table_name_class }}]:
async with async_db_session() as db:
Expand All @@ -27,9 +32,6 @@ class {{ table_name_class }}Service:
@staticmethod
async def create(*, obj: Create{{ schema_name }}Param) -> None:
async with async_db_session.begin() as db:
{{ table_name_en }} = await {{ table_name_en }}_dao.get_by_name(db, obj.name)
if {{ table_name_en }}:
raise errors.ForbiddenError(msg='{{ table_simple_name_zh }}已存在')
await {{ table_name_en }}_dao.create(db, obj)

@staticmethod
Expand Down