Skip to content

Commit a4b392f

Browse files
committed
Adaptation to postgresql code generation
1 parent e47dda0 commit a4b392f

File tree

6 files changed

+115
-41
lines changed

6 files changed

+115
-41
lines changed

backend/app/generator/crud/crud_gen.py

Lines changed: 80 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,100 @@
55
from sqlalchemy import Row, text
66
from sqlalchemy.ext.asyncio import AsyncSession
77

8+
from backend.core.conf import settings
9+
810

911
class CRUDGen:
1012
@staticmethod
1113
async def get_all_tables(db: AsyncSession, table_schema: str) -> Sequence[str]:
12-
stmt = text(
13-
'select table_name as table_name '
14-
'from information_schema.tables '
15-
'where table_name not like "sys_gen_%" '
16-
'and table_schema = :table_schema;'
17-
).bindparams(table_schema=table_schema)
14+
if settings.DATABASE_TYPE == 'mysql':
15+
sql = """
16+
SELECT table_name AS table_name FROM information_schema.tables
17+
WHERE table_name NOT LIKE 'sys_gen_%'
18+
AND table_schema = :table_schema;
19+
"""
20+
else:
21+
sql = """
22+
SELECT table_name AS table_name FROM information_schema.tables
23+
WHERE table_name NOT LIKE 'sys_gen_%'
24+
AND table_catalog = :table_schema
25+
AND table_schema = 'public'; -- schema 通常是 'public'
26+
"""
27+
stmt = text(sql).bindparams(table_schema=table_schema)
1828
result = await db.execute(stmt)
1929
return result.scalars().all()
2030

2131
@staticmethod
2232
async def get_table(db: AsyncSession, table_name: str) -> Row[tuple]:
23-
stmt = text(
24-
'select table_name as table_name, table_comment as table_comment '
25-
'from information_schema.tables '
26-
'where table_name not like "sys_gen_%" '
27-
'and table_name = :table_name;'
28-
).bindparams(table_name=table_name)
33+
if settings.DATABASE_TYPE == 'mysql':
34+
sql = """
35+
SELECT table_name AS table_name, table_comment AS table_comment FROM information_schema.tables
36+
WHERE table_name NOT LIKE 'sys_gen_%'
37+
AND table_name = :table_name;
38+
"""
39+
else:
40+
sql = """
41+
SELECT t.tablename AS table_name,
42+
pg_catalog.obj_description(t.tablename::regclass, 'pg_class') AS table_comment
43+
FROM pg_tables t
44+
WHERE t.tablename NOT LIKE 'sys_gen_%'
45+
AND t.tablename = :table_name
46+
AND t.schemaname = 'public'; -- schema 通常是 'public'
47+
"""
48+
stmt = text(sql).bindparams(table_name=table_name)
2949
result = await db.execute(stmt)
3050
return result.fetchone()
3151

3252
@staticmethod
3353
async def get_all_columns(db: AsyncSession, table_schema: str, table_name: str) -> Sequence[Row[tuple]]:
34-
stmt = text(
35-
'select column_name AS column_name, '
36-
'case when column_key = "PRI" then 1 else 0 end as is_pk, '
37-
'case when is_nullable = "NO" or column_key = "PRI" then 0 else 1 end as is_nullable, '
38-
'ordinal_position as sort, '
39-
'column_comment as column_comment, '
40-
'column_type as column_type '
41-
'from information_schema.columns '
42-
'where table_schema = :table_schema '
43-
'and table_name = :table_name '
44-
'and column_name != "id" '
45-
'and column_name != "created_time" '
46-
'and column_name != "updated_time" '
47-
'order by sort;'
48-
).bindparams(table_schema=table_schema, table_name=table_name)
54+
if settings.DATABASE_TYPE == 'mysql':
55+
sql = """
56+
SELECT column_name AS column_name,
57+
CASE WHEN column_key = 'PRI' THEN 1 ELSE 0 END AS is_pk,
58+
CASE WHEN is_nullable = 'NO' OR column_key = 'PRI' THEN 0 ELSE 1 END AS is_nullable,
59+
ordinal_position AS sort, column_comment AS column_comment,
60+
column_type AS column_type FROM information_schema.columns
61+
WHERE table_schema = :table_schema
62+
AND table_name = :table_name
63+
AND column_name <> 'id'
64+
AND column_name <> 'created_time'
65+
AND column_name <> 'updated_time'
66+
ORDER BY sort;
67+
"""
68+
stmt = text(sql).bindparams(table_schema=table_schema, table_name=table_name)
69+
else:
70+
sql = """
71+
SELECT a.attname AS column_name,
72+
CASE WHEN EXISTS (
73+
SELECT 1
74+
FROM pg_constraint c
75+
WHERE c.conrelid = t.oid
76+
AND c.contype = 'p'
77+
AND a.attnum = ANY(c.conkey)
78+
) THEN 1 ELSE 0 END AS is_pk,
79+
CASE WHEN a.attnotnull OR EXISTS (
80+
SELECT 1
81+
FROM pg_constraint c
82+
WHERE c.conrelid = t.oid
83+
AND c.contype = 'p'
84+
AND a.attnum = ANY(c.conkey)
85+
) THEN 0 ELSE 1 END AS is_nullable,
86+
a.attnum AS sort,
87+
col_description(t.oid, a.attnum) AS column_comment,
88+
pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type
89+
FROM pg_attribute a
90+
JOIN pg_class t ON a.attrelid = t.oid
91+
JOIN pg_namespace n ON n.oid = t.relnamespace
92+
WHERE n.nspname = 'public' -- 根据你的实际情况修改 schema 名称,通常是 'public'
93+
AND t.relname = :table_name
94+
AND a.attnum > 0
95+
AND NOT a.attisdropped
96+
AND a.attname <> 'id'
97+
AND a.attname <> 'created_time'
98+
AND a.attname <> 'updated_time'
99+
ORDER BY sort;
100+
"""
101+
stmt = text(sql).bindparams(table_name=table_name)
49102
result = await db.execute(stmt)
50103
return result.fetchall()
51104

backend/templates/py/api.jinja

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ async def get_pagination_{{ table_name_en }}(db: CurrentSession) -> ResponseMode
3939
'',
4040
summary='创建{{ table_simple_name_zh }}',
4141
dependencies=[
42-
Depends(RequestPermission('{{ permission_sign }}:add')),
42+
Depends(RequestPermission('{{ permission }}:add')),
4343
DependsRBAC,
4444
],
4545
)
@@ -52,7 +52,7 @@ async def create_{{ table_name_en }}(obj: Create{{ schema_name }}Param) -> Respo
5252
'/{pk}',
5353
summary='更新{{ table_simple_name_zh }}',
5454
dependencies=[
55-
Depends(RequestPermission('{{ permission_sign }}:edit')),
55+
Depends(RequestPermission('{{ permission }}:edit')),
5656
DependsRBAC,
5757
],
5858
)
@@ -67,7 +67,7 @@ async def update_{{ table_name_en }}(pk: Annotated[int, Path(...)], obj: Update{
6767
'',
6868
summary='(批量)删除{{ table_simple_name_zh }}',
6969
dependencies=[
70-
Depends(RequestPermission('{{ permission_sign }}:del')),
70+
Depends(RequestPermission('{{ permission }}:del')),
7171
DependsRBAC,
7272
],
7373
)

backend/templates/py/model.jinja

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@ from uuid import UUID
55

66
import sqlalchemy as sa
77

8-
from backend.common.model import {% if have_datetime_column %}Base{% else %}MappedBase{% endif %}, id_key
8+
from backend.common.model import {% if default_datetime_column %}Base{% else %}MappedBase{% endif %}, id_key
9+
{% if database_type == 'mysql' -%}
910
from sqlalchemy.dialects import mysql
11+
{% else -%}
12+
from sqlalchemy.dialects import postgresql
13+
{% endif -%}
1014
from sqlalchemy.orm import Mapped, mapped_column
1115

1216

13-
class {{ table_name_class }}({% if have_datetime_column %}Base{% else %}MappedBase{% endif %}):
17+
class {{ table_name_class }}({% if default_datetime_column %}Base{% else %}MappedBase{% endif %}):
1418
"""{{ table_name_zh }}"""
1519

1620
__tablename__ = '{{ table_name_en }}'
@@ -23,9 +27,14 @@ class {{ table_name_class }}({% if have_datetime_column %}Base{% else %}MappedBa
2327
{%- endif %} = mapped_column(
2428
{%- if model.type in ['NVARCHAR', 'String', 'Unicode', 'VARCHAR'] -%}
2529
sa.String({{ model.length }})
26-
{%- elif model.type in ['BIT', 'ENUM', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', 'MEDIUMTEXT', 'SET',
27-
'TINYBLOB', 'TINYINT', 'TINYTEXT', 'YEAR'] -%}
30+
{%- elif database_type == 'mysql' and model.type in ['BIT', 'ENUM', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB',
31+
'MEDIUMINT', 'MEDIUMTEXT', 'SET', 'TINYBLOB', 'TINYINT', 'TINYTEXT', 'YEAR'] -%}
2832
mysql.{{ model.type }}()
33+
{%- elif database_type == 'postgresql' and model.type in [
34+
'ARRAY', 'BIT', 'BYTEA', 'CIDR', 'CITEXT', 'DATEMULTIRANGE', 'DATERANGE', 'DOMAIN', 'ENUM', 'HSTORE', 'INET',
35+
'INT4MULTIRANGE', 'INT4RANGE', 'INT8MULTIRANGE', 'INT8RANGE', 'INTERVAL', 'JSONB', 'JSONPATH', 'MACADDR',
36+
'MACADDR8', 'MONEY', 'NUMMULTIRANGE', 'NUMRANGE', 'OID', 'REGCLASS', 'REGCONFIG', 'TSMULTIRANGE', 'TSQUERY',
37+
'TSRANGE', 'TSTZMULTIRANGE', 'TSTZRANGE', 'TSVECTOR'] -%}
2938
{%- else -%}
3039
sa.{{ model.type }}()
3140
{%- endif -%}, default=

backend/templates/py/schema.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Get{{ schema_name }}ListDetails({{ schema_name }}SchemaBase):
2626
model_config = ConfigDict(from_attributes=True)
2727

2828
id: int
29-
{% if have_datetime_column %}
29+
{% if default_datetime_column %}
3030
created_time: datetime
3131
updated_time: datetime | None = None
3232
{% endif %}

backend/utils/gen_template.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from backend.app.generator.conf import generator_settings
99
from backend.app.generator.model import GenBusiness, GenModel
10+
from backend.core.conf import settings
1011
from backend.core.path_conf import JINJA2_TEMPLATE_DIR
1112

1213

@@ -95,8 +96,9 @@ def get_vars(business: GenBusiness, models: Sequence[GenModel]) -> dict:
9596
'table_simple_name_zh': business.table_simple_name_zh,
9697
'table_comment': business.table_comment,
9798
'schema_name': to_pascal(business.schema_name),
98-
'have_datetime_column': business.default_datetime_column,
99-
'permission_sign': str(business.table_name_en.replace('_', ':')),
99+
'default_datetime_column': business.default_datetime_column,
100+
'permission': str(business.table_name_en.replace('_', ':')),
101+
'database_type': settings.DATABASE_TYPE,
100102
'models': models,
101103
}
102104

backend/utils/type_conversion.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from backend.common.enums import GenModelMySQLColumnType
3+
from backend.common.enums import GenModelMySQLColumnType, GenModelPostgreSQLColumnType
4+
from backend.core.conf import settings
45

56

67
def sql_type_to_sqlalchemy(typing: str) -> str:
@@ -10,8 +11,12 @@ def sql_type_to_sqlalchemy(typing: str) -> str:
1011
:param typing:
1112
:return:
1213
"""
13-
if typing in GenModelMySQLColumnType.get_member_keys():
14-
return typing
14+
if settings.DATABASE_TYPE == 'mysql':
15+
if typing in GenModelMySQLColumnType.get_member_keys():
16+
return typing
17+
else:
18+
if typing in GenModelPostgreSQLColumnType.get_member_keys():
19+
return typing
1520
return 'String'
1621

1722

@@ -23,6 +28,11 @@ def sql_type_to_pydantic(typing: str) -> str:
2328
:return:
2429
"""
2530
try:
26-
return GenModelMySQLColumnType[typing].value
31+
if settings.DATABASE_TYPE == 'mysql':
32+
return GenModelMySQLColumnType[typing].value
33+
else:
34+
if typing == 'CHARACTER VARYING': # postgresql 中 DDL VARCHAR 的别名
35+
return 'str'
36+
return GenModelPostgreSQLColumnType[typing].value
2737
except KeyError:
2838
return 'str'

0 commit comments

Comments
 (0)