Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/plugin/code_generator/plugin.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[plugin]
summary = '代码生成'
version = '0.0.3'
version = '0.0.4'
description = '生成通用业务代码'
author = 'wu-clan'

Expand Down
18 changes: 9 additions & 9 deletions backend/plugin/code_generator/service/code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,21 @@ async def preview(self, *, pk: int) -> dict[str, bytes]:
tpl_code_map = await self.render_tpl_code(business=business)

codes = {}
for tpl, code in tpl_code_map.items():
if tpl.startswith('python'):
for tpl_path, code in tpl_code_map.items():
if tpl_path.startswith('python'):
rootpath = f'fastapi_best_architecture/backend/app/{business.app_name}'
template_name = tpl.split('/')[-1]
template_name = tpl_path.split('/')[-1]
match template_name:
case 'api.jinja':
filepath = f'{rootpath}/api/{business.api_version}/{business.app_name}.py'
filepath = f'{rootpath}/api/{business.api_version}/{business.filename}.py'
case 'crud.jinja':
filepath = f'{rootpath}/crud/crud_{business.app_name}.py'
filepath = f'{rootpath}/crud/crud_{business.filename}.py'
case 'model.jinja':
filepath = f'{rootpath}/model/{business.app_name}.py'
filepath = f'{rootpath}/model/{business.filename}.py'
case 'schema.jinja':
filepath = f'{rootpath}/schema/{business.app_name}.py'
filepath = f'{rootpath}/schema/{business.filename}.py'
case 'service.jinja':
filepath = f'{rootpath}/service/{business.app_name}_service.py'
filepath = f'{rootpath}/service/{business.filename}_service.py'
codes[filepath] = code.encode('utf-8')

return codes
Expand All @@ -157,7 +157,7 @@ async def get_generate_path(*, pk: int) -> list[str]:
if not business:
raise errors.NotFoundError(msg='业务不存在')

gen_path = business.gen_path or 'fba-backend-app-dir'
gen_path = business.gen_path or '.../backend/app/'
target_files = gen_template.get_code_gen_paths(business)

return [os.path.join(gen_path, *target_file.split('/')) for target_file in target_files]
Expand Down
2 changes: 1 addition & 1 deletion backend/plugin/code_generator/templates/python/crud.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ class CRUD{{ class_name }}(CRUDPlus[{{ schema_name }}]):
return await self.delete_model_by_column(db, allow_multiple=True, id__in=pks)


{{ instance_name }}_dao: CRUD{{ class_name }} = CRUD{{ class_name }}({{ class_name }})
{{ table_name }}_dao: CRUD{{ class_name }} = CRUD{{ class_name }}({{ class_name }})
7 changes: 5 additions & 2 deletions backend/plugin/code_generator/templates/python/model.jinja
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
{% if default_datetime_column %}
from datetime import datetime

{% endif %}
from uuid import UUID

import sqlalchemy as sa
Expand All @@ -12,10 +15,10 @@ from sqlalchemy.dialects import postgresql
{% endif -%}
from sqlalchemy.orm import Mapped, mapped_column

from backend.common.model import {% if default_datetime_column %}Base{% else %}MappedBase{% endif %}, id_key
from backend.common.model import {% if default_datetime_column %}Base{% else %}DataClassBase{% endif %}, id_key


class {{ class_name }}({% if default_datetime_column %}Base{% else %}MappedBase{% endif %}):
class {{ class_name }}({% if default_datetime_column %}Base{% else %}DataClassBase{% endif %}):
"""{{ table_comment }}"""

__tablename__ = '{{ table_name }}'
Expand Down
2 changes: 2 additions & 0 deletions backend/plugin/code_generator/templates/python/schema.jinja
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
{% if default_datetime_column %}
from datetime import datetime

{% endif %}
from pydantic import ConfigDict, Field

from backend.common.schema import SchemaBase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ class {{ class_name }}Service:
return count


{{ instance_name }}_service: {{ class_name }}Service = {{ class_name }}Service()
{{ table_name }}_service: {{ class_name }}Service = {{ class_name }}Service()
25 changes: 11 additions & 14 deletions backend/plugin/code_generator/utils/code_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Sequence

from jinja2 import Environment, FileSystemLoader, Template, select_autoescape
from pydantic.alias_generators import to_pascal, to_snake

from backend.core.conf import settings
from backend.plugin.code_generator.model import GenBusiness, GenColumn
Expand Down Expand Up @@ -39,13 +38,13 @@ def get_template_files() -> list[str]:

:return:
"""
files = []

# python
python_template_path = JINJA2_TEMPLATE_DIR / 'python'
files.extend([f'python/{file.name}' for file in python_template_path.iterdir() if file.is_file()])

return files
return [
'python/api.jinja',
'python/crud.jinja',
'python/model.jinja',
'python/schema.jinja',
'python/service.jinja',
]

@staticmethod
def get_code_gen_paths(business: GenBusiness) -> list[str]:
Expand Down Expand Up @@ -73,8 +72,7 @@ def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str:
:param business: 代码生成业务对象
:return:
"""
target_files = self.get_code_gen_paths(business)
code_gen_path_mapping = dict(zip(self.get_template_files(), target_files))
code_gen_path_mapping = dict(zip(self.get_template_files(), self.get_code_gen_paths(business)))
return code_gen_path_mapping[tpl_path]

@staticmethod
Expand All @@ -88,12 +86,11 @@ def get_vars(business: GenBusiness, models: Sequence[GenColumn]) -> dict[str, st
"""
return {
'app_name': business.app_name,
'table_name': to_snake(business.table_name),
'table_name': business.table_name,
'doc_comment': business.doc_comment,
'table_comment': business.table_comment,
'class_name': to_pascal(business.class_name),
'instance_name': to_snake(business.class_name),
'schema_name': to_pascal(business.schema_name),
'class_name': business.class_name,
'schema_name': business.schema_name,
'default_datetime_column': business.default_datetime_column,
'permission': str(business.table_name.replace('_', ':')),
'database_type': settings.DATABASE_TYPE,
Expand Down