Skip to content

Commit 576ed77

Browse files
authored
Fix code generation file missing (#457)
1 parent 06cbe56 commit 576ed77

File tree

6 files changed

+36
-14
lines changed

6 files changed

+36
-14
lines changed

backend/app/generator/service/gen_service.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,17 @@ async def generate(self, *, pk: int) -> None:
127127
if not init_filepath.exists():
128128
async with aiofiles.open(init_filepath, 'w', encoding='utf-8') as f:
129129
await f.write(gen_template.init_content)
130+
if 'api' in str(code_folder):
131+
# api __init__.py
132+
api_init_filepath = code_folder.parent.joinpath('__init__.py')
133+
if not api_init_filepath.exists():
134+
async with aiofiles.open(api_init_filepath, 'w', encoding='utf-8') as f:
135+
await f.write(gen_template.init_content)
136+
# app __init__.py
137+
app_init_filepath = api_init_filepath.parent.joinpath('__init__.py')
138+
if not app_init_filepath:
139+
async with aiofiles.open(app_init_filepath, 'w', encoding='utf-8') as f:
140+
await f.write(gen_template.init_content)
130141
# 写入代码文件呢
131142
async with aiofiles.open(code_filepath, 'w', encoding='utf-8') as f:
132143
await f.write(code)
@@ -161,6 +172,10 @@ async def download(self, *, pk: int) -> io.BytesIO:
161172
f'from backend.app.{business.app_name}.model.{business.table_name_en} '
162173
f'import {to_pascal(business.table_name_en)}\n',
163174
)
175+
if 'api' in new_code_path:
176+
# api __init__.py
177+
api_init_filepath = os.path.join(*new_code_path.split('/')[:-2], '__init__.py')
178+
zf.writestr(api_init_filepath, gen_template.init_content)
164179
zf.close()
165180
bio.seek(0)
166181
return bio

backend/templates/py/api.jinja

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
# -*- coding: utf-8 -*-
33
from typing import Annotated
44

5-
from fastapi import APIRouter, Depends, Path, Query
6-
75
from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Get{{ schema_name }}ListDetails, Update{{ schema_name }}Param
86
from backend.app.{{ app_name }}.service.{{ table_name_en }}_service import {{ table_name_en }}_service
97
from backend.common.pagination import DependsPagination, paging_data
@@ -12,6 +10,7 @@ from backend.common.security.jwt import DependsJwtAuth
1210
from backend.common.security.permission import RequestPermission
1311
from backend.common.security.rbac import DependsRBAC
1412
from backend.database.db_mysql import CurrentSession
13+
from fastapi import APIRouter, Depends, Path, Query
1514

1615
router = APIRouter()
1716

backend/templates/py/crud.jinja

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
# -*- coding: utf-8 -*-
33
from typing import Sequence
44

5-
from sqlalchemy import delete
6-
from sqlalchemy.ext.asyncio import AsyncSession
7-
from sqlalchemy_crud_plus import CRUDPlus
8-
95
from backend.app.{{ app_name }}.model import {{ table_name_class }}
106
from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param
7+
from sqlalchemy import Select
8+
from sqlalchemy.ext.asyncio import AsyncSession
9+
from sqlalchemy_crud_plus import CRUDPlus
1110

1211

1312
class CRUD{{ table_name_class }}(CRUDPlus[{{ schema_name }}]):
@@ -21,6 +20,14 @@ class CRUD{{ table_name_class }}(CRUDPlus[{{ schema_name }}]):
2120
"""
2221
return await self.select_model(db, pk)
2322

23+
async def get_list(self) -> Select:
24+
"""
25+
获取 {{ schema_name }} 列表
26+
27+
:return:
28+
"""
29+
return await self.select_order('created_time', 'desc')
30+
2431
async def get_all(self, db: AsyncSession) -> Sequence[{{ table_name_class }}]:
2532
"""
2633
获取所有 {{ schema_name }}

backend/templates/py/model.jinja

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3+
from datetime import datetime
34
from uuid import UUID
45

56
import sqlalchemy as sa
6-
from sqlalchemy.dialects import mysql
7-
8-
from sqlalchemy.orm import Mapped, mapped_column
97

108
from backend.common.model import {% if have_datetime_column %}Base{% else %}MappedBase{% endif %}, id_key
9+
from sqlalchemy.dialects import mysql
10+
from sqlalchemy.orm import Mapped, mapped_column
1111

1212

1313
class {{ table_name_class }}({% if have_datetime_column %}Base{% else %}MappedBase{% endif %}):

backend/templates/py/schema.jinja

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
{% if have_datetime_column %}
43
from datetime import datetime
5-
{% endif %}
4+
65
from pydantic import ConfigDict
76

87
from backend.common.schema import SchemaBase

backend/templates/py/service.jinja

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ from backend.app.{{ app_name }}.model import {{ table_name_class }}
77
from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param
88
from backend.common.exception import errors
99
from backend.database.db_mysql import async_db_session
10+
from sqlalchemy import Select
1011

1112

1213
class {{ table_name_class }}Service:
@@ -18,6 +19,10 @@ class {{ table_name_class }}Service:
1819
raise errors.NotFoundError(msg='{{ table_simple_name_zh }}不存在')
1920
return {{ table_name_en }}
2021

22+
@staticmethod
23+
async def get_select() -> Select:
24+
return await {{ table_name_en }}_dao.get_list()
25+
2126
@staticmethod
2227
async def get_all() -> Sequence[{{ table_name_class }}]:
2328
async with async_db_session() as db:
@@ -27,9 +32,6 @@ class {{ table_name_class }}Service:
2732
@staticmethod
2833
async def create(*, obj: Create{{ schema_name }}Param) -> None:
2934
async with async_db_session.begin() as db:
30-
{{ table_name_en }} = await {{ table_name_en }}_dao.get_by_name(db, obj.name)
31-
if {{ table_name_en }}:
32-
raise errors.ForbiddenError(msg='{{ table_simple_name_zh }}已存在')
3335
await {{ table_name_en }}_dao.create(db, obj)
3436

3537
@staticmethod

0 commit comments

Comments
 (0)