Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions backend/app/generator/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
class GeneratorSettings(BaseSettings):
"""代码生成配置"""

# 模版
TEMPLATE_BACKEND_DIR_NAME: str = 'py'

# 代码下载
DOWNLOAD_ZIP_FILENAME: str = 'fba_generator'

Expand Down
2 changes: 1 addition & 1 deletion backend/app/generator/schema/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ class ImportParam(SchemaBase):
"""导入参数"""

app: str = Field(description='应用名称,用于代码生成到指定 app')
table_name: str = Field(description='数据库表名')
table_schema: str = Field(description='数据库名')
table_name: str = Field(description='数据库表名')
2 changes: 1 addition & 1 deletion backend/app/generator/schema/gen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import ConfigDict, Field, field_validator

from backend.common.schema import SchemaBase
from backend.utils.type_conversion import sql_type_to_sqlalchemy
from backend.utils.generator.type_conversion import sql_type_to_sqlalchemy


class GenModelSchemaBase(SchemaBase):
Expand Down
2 changes: 1 addition & 1 deletion backend/app/generator/service/gen_model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from backend.common.enums import GenModelMySQLColumnType
from backend.common.exception import errors
from backend.database.db import async_db_session
from backend.utils.type_conversion import sql_type_to_pydantic
from backend.utils.generator.type_conversion import sql_type_to_pydantic


class GenModelService:
Expand Down
84 changes: 48 additions & 36 deletions backend/app/generator/service/gen_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from backend.common.exception import errors
from backend.core.path_conf import BASE_PATH
from backend.database.db import async_db_session
from backend.utils.gen_template import gen_template
from backend.utils.type_conversion import sql_type_to_pydantic
from backend.utils.generator.gen_template import gen_template
from backend.utils.generator.type_conversion import sql_type_to_pydantic


class GenService:
Expand Down Expand Up @@ -99,7 +99,7 @@ async def render_tpl_code(*, business: GenBusiness) -> dict[str, str]:
gen_vars = gen_template.get_vars(business, gen_models)
return {
tpl_path: await gen_template.get_template(tpl_path).render_async(**gen_vars)
for tpl_path in gen_template.get_template_paths()
for tpl_path in gen_template.get_template_files()
}

async def preview(self, *, pk: int) -> dict[str, bytes]:
Expand All @@ -115,10 +115,13 @@ async def preview(self, *, pk: int) -> dict[str, bytes]:
raise errors.NotFoundError(msg='业务不存在')

tpl_code_map = await self.render_tpl_code(business=business)
return {
tpl.replace('.jinja', '.py') if tpl.startswith('py') else ...: code.encode('utf-8')
for tpl, code in tpl_code_map.items()
}

codes = {}
for tpl, code in tpl_code_map.items():
if tpl.startswith('python'):
codes[tpl.replace('.jinja', '.py').split('/')[-1]] = code.encode('utf-8')

return codes

@staticmethod
async def get_generate_path(*, pk: int) -> list[str]:
Expand All @@ -133,9 +136,10 @@ async def get_generate_path(*, pk: int) -> list[str]:
if not business:
raise errors.NotFoundError(msg='业务不存在')

gen_path = business.gen_path or 'fba-backend-app-path'
gen_path = business.gen_path or 'fba-backend-app-dir'
target_files = gen_template.get_code_gen_paths(business)
return [os.path.join(gen_path, *target_file.split('/')[1:]) for target_file in target_files]

return [os.path.join(gen_path, *target_file.split('/')) for target_file in target_files]

async def generate(self, *, pk: int) -> None:
"""
Expand All @@ -155,32 +159,29 @@ async def generate(self, *, pk: int) -> None:
for tpl_path, code in tpl_code_map.items():
code_filepath = os.path.join(
gen_path,
*gen_template.get_code_gen_path(tpl_path, business).split('/')[1:],
*gen_template.get_code_gen_path(tpl_path, business).split('/'),
)
code_folder = Path(str(code_filepath)).parent
code_folder.mkdir(parents=True, exist_ok=True)

# 写入 init 文件
str_code_filepath = str(code_filepath)
code_folder = Path(str_code_filepath).parent
code_folder.mkdir(parents=True, exist_ok=True)

init_filepath = code_folder.joinpath('__init__.py')
if not init_filepath.exists():
async with aiofiles.open(init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
async with aiofiles.open(init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)

if 'api' in str(code_folder):
# api __init__.py
# api __init__.py
if 'api' in str_code_filepath:
api_init_filepath = code_folder.parent.joinpath('__init__.py')
if not api_init_filepath.exists():
async with aiofiles.open(api_init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
# app __init__.py
app_init_filepath = api_init_filepath.parent.joinpath('__init__.py')
if not app_init_filepath.exists():
async with aiofiles.open(app_init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
async with aiofiles.open(api_init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)

# 写入代码文件
async with aiofiles.open(code_filepath, 'w', encoding='utf-8') as f:
await f.write(code)
# app __init__.py
if 'service' in str_code_filepath:
app_init_filepath = code_folder.parent.joinpath('__init__.py')
async with aiofiles.open(app_init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)

# model init 文件补充
if code_folder.name == 'model':
Expand All @@ -190,6 +191,10 @@ async def generate(self, *, pk: int) -> None:
f'import {to_pascal(business.table_name_en)}\n',
)

# 写入代码文件
async with aiofiles.open(code_filepath, 'w', encoding='utf-8') as f:
await f.write(code)

async def download(self, *, pk: int) -> io.BytesIO:
"""
下载生成的代码
Expand All @@ -206,13 +211,12 @@ async def download(self, *, pk: int) -> io.BytesIO:
with zipfile.ZipFile(bio, 'w') as zf:
tpl_code_map = await self.render_tpl_code(business=business)
for tpl_path, code in tpl_code_map.items():
# 写入代码文件
new_code_path = gen_template.get_code_gen_path(tpl_path, business)
zf.writestr(new_code_path, code)
code_filepath = gen_template.get_code_gen_path(tpl_path, business)

# 写入 init 文件
init_filepath = os.path.join(*new_code_path.split('/')[:-1], '__init__.py')
if 'model' not in new_code_path.split('/'):
code_dir = os.path.dirname(code_filepath)
init_filepath = os.path.join(code_dir, '__init__.py')
if 'model' not in code_filepath.split('/'):
zf.writestr(init_filepath, gen_template.init_content)
else:
zf.writestr(
Expand All @@ -222,11 +226,19 @@ async def download(self, *, pk: int) -> io.BytesIO:
f'import {to_pascal(business.table_name_en)}\n',
)

if 'api' in new_code_path:
# api __init__.py
api_init_filepath = os.path.join(*new_code_path.split('/')[:-2], '__init__.py')
# api __init__.py
if 'api' in code_dir:
api_init_filepath = os.path.join(os.path.dirname(code_dir), '__init__.py')
zf.writestr(api_init_filepath, gen_template.init_content)

# app __init__.py
if 'service' in code_dir:
app_init_filepath = os.path.join(os.path.dirname(code_dir), '__init__.py')
zf.writestr(app_init_filepath, gen_template.init_content)

# 写入代码文件
zf.writestr(code_filepath, code)

bio.seek(0)
return bio

Expand Down
2 changes: 1 addition & 1 deletion backend/core/path_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
UPLOAD_DIR = STATIC_DIR / 'upload'

# jinja2 模版文件路径
JINJA2_TEMPLATE_DIR = BASE_PATH / 'templates'
JINJA2_TEMPLATE_DIR = BASE_PATH / 'templates' / 'generator'

# 插件目录
PLUGIN_DIR = BASE_PATH / 'plugin'
Expand Down
2 changes: 2 additions & 0 deletions backend/utils/generator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from jinja2 import Environment, FileSystemLoader, Template, select_autoescape
from pydantic.alias_generators import to_pascal, to_snake

from backend.app.generator.conf import generator_settings
from backend.app.generator.model import GenBusiness, GenModel
from backend.core.conf import settings
from backend.core.path_conf import JINJA2_TEMPLATE_DIR
Expand Down Expand Up @@ -34,19 +33,19 @@ def get_template(self, jinja_file: str) -> Template:
return self.env.get_template(jinja_file)

@staticmethod
def get_template_paths() -> list[str]:
def get_template_files() -> list[str]:
"""
获取模板文件路径列表
获取模板文件列表

:return:
"""
return [
f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/api.jinja',
f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/crud.jinja',
f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/model.jinja',
f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/schema.jinja',
f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/service.jinja',
]
files = []

# python
python_template_path = JINJA2_TEMPLATE_DIR / 'python'
files.extend([f'python/{file.name}' for file in python_template_path.iterdir() if file.is_file()])

return files

@staticmethod
def get_code_gen_paths(business: GenBusiness) -> list[str]:
Expand All @@ -59,11 +58,11 @@ def get_code_gen_paths(business: GenBusiness) -> list[str]:
app_name = business.app_name
module_name = business.table_name_en
return [
f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/api/{business.api_version}/{module_name}.py',
f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/crud/crud_{module_name}.py',
f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/model/{module_name}.py',
f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/schema/{module_name}.py',
f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/service/{module_name}_service.py',
f'{app_name}/api/{business.api_version}/{module_name}.py',
f'{app_name}/crud/crud_{module_name}.py',
f'{app_name}/model/{module_name}.py',
f'{app_name}/schema/{module_name}.py',
f'{app_name}/service/{module_name}_service.py',
]

def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str:
Expand All @@ -75,7 +74,7 @@ def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str:
:return:
"""
target_files = self.get_code_gen_paths(business)
code_gen_path_mapping = dict(zip(self.get_template_paths(), target_files))
code_gen_path_mapping = dict(zip(self.get_template_files(), target_files))
return code_gen_path_mapping[tpl_path]

@staticmethod
Expand Down