diff --git a/backend/plugin/code_generator/plugin.toml b/backend/plugin/code_generator/plugin.toml index 4004a2b3c..071db61d1 100644 --- a/backend/plugin/code_generator/plugin.toml +++ b/backend/plugin/code_generator/plugin.toml @@ -1,6 +1,6 @@ [plugin] summary = '代码生成' -version = '0.0.3' +version = '0.0.4' description = '生成通用业务代码' author = 'wu-clan' diff --git a/backend/plugin/code_generator/service/code_service.py b/backend/plugin/code_generator/service/code_service.py index f93fe0535..2e5eb13af 100644 --- a/backend/plugin/code_generator/service/code_service.py +++ b/backend/plugin/code_generator/service/code_service.py @@ -125,21 +125,21 @@ async def preview(self, *, pk: int) -> dict[str, bytes]: tpl_code_map = await self.render_tpl_code(business=business) codes = {} - for tpl, code in tpl_code_map.items(): - if tpl.startswith('python'): + for tpl_path, code in tpl_code_map.items(): + if tpl_path.startswith('python'): rootpath = f'fastapi_best_architecture/backend/app/{business.app_name}' - template_name = tpl.split('/')[-1] + template_name = tpl_path.split('/')[-1] match template_name: case 'api.jinja': - filepath = f'{rootpath}/api/{business.api_version}/{business.app_name}.py' + filepath = f'{rootpath}/api/{business.api_version}/{business.filename}.py' case 'crud.jinja': - filepath = f'{rootpath}/crud/crud_{business.app_name}.py' + filepath = f'{rootpath}/crud/crud_{business.filename}.py' case 'model.jinja': - filepath = f'{rootpath}/model/{business.app_name}.py' + filepath = f'{rootpath}/model/{business.filename}.py' case 'schema.jinja': - filepath = f'{rootpath}/schema/{business.app_name}.py' + filepath = f'{rootpath}/schema/{business.filename}.py' case 'service.jinja': - filepath = f'{rootpath}/service/{business.app_name}_service.py' + filepath = f'{rootpath}/service/{business.filename}_service.py' codes[filepath] = code.encode('utf-8') return codes @@ -157,7 +157,7 @@ async def get_generate_path(*, pk: int) -> list[str]: if not business: raise errors.NotFoundError(msg='业务不存在') - gen_path = business.gen_path or 'fba-backend-app-dir' + gen_path = business.gen_path or '.../backend/app/' target_files = gen_template.get_code_gen_paths(business) return [os.path.join(gen_path, *target_file.split('/')) for target_file in target_files] diff --git a/backend/plugin/code_generator/templates/python/crud.jinja b/backend/plugin/code_generator/templates/python/crud.jinja index 4f9b33b99..11bcef27d 100644 --- a/backend/plugin/code_generator/templates/python/crud.jinja +++ b/backend/plugin/code_generator/templates/python/crud.jinja @@ -66,4 +66,4 @@ class CRUD{{ class_name }}(CRUDPlus[{{ schema_name }}]): return await self.delete_model_by_column(db, allow_multiple=True, id__in=pks) -{{ instance_name }}_dao: CRUD{{ class_name }} = CRUD{{ class_name }}({{ class_name }}) +{{ table_name }}_dao: CRUD{{ class_name }} = CRUD{{ class_name }}({{ class_name }}) diff --git a/backend/plugin/code_generator/templates/python/model.jinja b/backend/plugin/code_generator/templates/python/model.jinja index 745495a4b..054922b12 100644 --- a/backend/plugin/code_generator/templates/python/model.jinja +++ b/backend/plugin/code_generator/templates/python/model.jinja @@ -1,6 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +{% if default_datetime_column %} from datetime import datetime + +{% endif %} from uuid import UUID import sqlalchemy as sa @@ -12,10 +15,10 @@ from sqlalchemy.dialects import postgresql {% endif -%} from sqlalchemy.orm import Mapped, mapped_column -from backend.common.model import {% if default_datetime_column %}Base{% else %}MappedBase{% endif %}, id_key +from backend.common.model import {% if default_datetime_column %}Base{% else %}DataClassBase{% endif %}, id_key -class {{ class_name }}({% if default_datetime_column %}Base{% else %}MappedBase{% endif %}): +class {{ class_name }}({% if default_datetime_column %}Base{% else %}DataClassBase{% endif %}): """{{ table_comment }}""" __tablename__ = '{{ table_name }}' diff --git a/backend/plugin/code_generator/templates/python/schema.jinja b/backend/plugin/code_generator/templates/python/schema.jinja index e4e81dd0d..c51d23dbb 100644 --- a/backend/plugin/code_generator/templates/python/schema.jinja +++ b/backend/plugin/code_generator/templates/python/schema.jinja @@ -1,7 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +{% if default_datetime_column %} from datetime import datetime +{% endif %} from pydantic import ConfigDict, Field from backend.common.schema import SchemaBase diff --git a/backend/plugin/code_generator/templates/python/service.jinja b/backend/plugin/code_generator/templates/python/service.jinja index c00888790..fa74c67af 100644 --- a/backend/plugin/code_generator/templates/python/service.jinja +++ b/backend/plugin/code_generator/templates/python/service.jinja @@ -75,4 +75,4 @@ class {{ class_name }}Service: return count -{{ instance_name }}_service: {{ class_name }}Service = {{ class_name }}Service() +{{ table_name }}_service: {{ class_name }}Service = {{ class_name }}Service() diff --git a/backend/plugin/code_generator/utils/code_template.py b/backend/plugin/code_generator/utils/code_template.py index f34e3d98a..898bcb571 100644 --- a/backend/plugin/code_generator/utils/code_template.py +++ b/backend/plugin/code_generator/utils/code_template.py @@ -3,7 +3,6 @@ from typing import Sequence from jinja2 import Environment, FileSystemLoader, Template, select_autoescape -from pydantic.alias_generators import to_pascal, to_snake from backend.core.conf import settings from backend.plugin.code_generator.model import GenBusiness, GenColumn @@ -39,13 +38,13 @@ def get_template_files() -> list[str]: :return: """ - files = [] - - # python - python_template_path = JINJA2_TEMPLATE_DIR / 'python' - files.extend([f'python/{file.name}' for file in python_template_path.iterdir() if file.is_file()]) - - return files + return [ + 'python/api.jinja', + 'python/crud.jinja', + 'python/model.jinja', + 'python/schema.jinja', + 'python/service.jinja', + ] @staticmethod def get_code_gen_paths(business: GenBusiness) -> list[str]: @@ -73,8 +72,7 @@ def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str: :param business: 代码生成业务对象 :return: """ - target_files = self.get_code_gen_paths(business) - code_gen_path_mapping = dict(zip(self.get_template_files(), target_files)) + code_gen_path_mapping = dict(zip(self.get_template_files(), self.get_code_gen_paths(business))) return code_gen_path_mapping[tpl_path] @staticmethod @@ -88,12 +86,11 @@ def get_vars(business: GenBusiness, models: Sequence[GenColumn]) -> dict[str, st """ return { 'app_name': business.app_name, - 'table_name': to_snake(business.table_name), + 'table_name': business.table_name, 'doc_comment': business.doc_comment, 'table_comment': business.table_comment, - 'class_name': to_pascal(business.class_name), - 'instance_name': to_snake(business.class_name), - 'schema_name': to_pascal(business.schema_name), + 'class_name': business.class_name, + 'schema_name': business.schema_name, 'default_datetime_column': business.default_datetime_column, 'permission': str(business.table_name.replace('_', ':')), 'database_type': settings.DATABASE_TYPE,