Skip to content

Commit ab4495c

Browse files
authored
Optimize and normalize the code generator (#430)
1 parent 1d3b0e7 commit ab4495c

File tree

11 files changed

+231
-194
lines changed

11 files changed

+231
-194
lines changed

backend/app/generator/api/v1/gen.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def get_all_businesses() -> ResponseModel:
3333

3434
@router.get('/businesses/{pk}', summary='获取代码生成业务详情', dependencies=[DependsJwtAuth])
3535
async def get_business(pk: Annotated[int, Path(...)]) -> ResponseModel:
36-
business = await gen_service.get_business_with_model(pk=pk)
36+
business = await gen_business_service.get(pk=pk)
3737
data = GetGenBusinessListDetails(**select_as_dict(business))
3838
return response_base.success(data=data)
3939

@@ -89,6 +89,12 @@ async def delete_business(pk: Annotated[int, Path(...)]) -> ResponseModel:
8989
return response_base.fail()
9090

9191

92+
@router.get('/models/types', summary='获取代码生成模型列类型', dependencies=[DependsJwtAuth])
93+
async def get_model_types() -> ResponseModel:
94+
model_types = await gen_model_service.get_types()
95+
return response_base.success(data=model_types)
96+
97+
9298
@router.get('/models/{pk}', summary='获取代码生成模型详情', dependencies=[DependsJwtAuth])
9399
async def get_model(pk: Annotated[int, Path(...)]) -> ResponseModel:
94100
model = await gen_model_service.get(pk=pk)

backend/app/generator/crud/crud_gen.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,11 @@
22
# -*- coding: utf-8 -*-
33
from typing import Sequence
44

5-
from sqlalchemy import Row, select, text
5+
from sqlalchemy import Row, text
66
from sqlalchemy.ext.asyncio import AsyncSession
7-
from sqlalchemy.orm import selectinload
8-
9-
from backend.app.generator.model import GenBusiness
107

118

129
class CRUDGen:
13-
@staticmethod
14-
async def get_business_with_model(db: AsyncSession, business_id: int) -> GenBusiness:
15-
stmt = select(GenBusiness).options(selectinload(GenBusiness.gen_model)).where(GenBusiness.id == business_id)
16-
result = await db.execute(stmt)
17-
data = result.scalars().first()
18-
return data
19-
2010
@staticmethod
2111
async def get_all_tables(db: AsyncSession, table_schema: str) -> Sequence[str]:
2212
stmt = text(

backend/app/generator/crud/crud_gen_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# -*- coding: utf-8 -*-
33
from typing import Sequence
44

5+
from sqlalchemy import update
56
from sqlalchemy.ext.asyncio import AsyncSession
67
from sqlalchemy_crud_plus import CRUDPlus
78

@@ -34,6 +35,7 @@ async def create(self, db: AsyncSession, obj_in: CreateGenModelParam, pd_type: s
3435
3536
:param db:
3637
:param obj_in:
38+
:param pd_type:
3739
:return:
3840
"""
3941
await self.create_model(db, obj_in, pd_type=pd_type)
@@ -45,9 +47,12 @@ async def update(self, db: AsyncSession, pk: int, obj_in: UpdateGenModelParam, p
4547
:param db:
4648
:param pk:
4749
:param obj_in:
50+
:param pd_type:
4851
:return:
4952
"""
50-
return await self.update_model_by_column(db, obj_in, id=pk, pd_type=pd_type)
53+
stmt = update(self.model).where(self.model.id == pk).values(**obj_in.model_dump(), pd_type=pd_type)
54+
result = await db.execute(stmt)
55+
return result.rowcount
5156

5257
async def delete(self, db: AsyncSession, pk: int) -> int:
5358
"""

backend/app/generator/schema/gen_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
# -*- coding: utf-8 -*-
33
from pydantic import ConfigDict, Field, field_validator
44

5-
from backend.common.enums import GenModelColumnType
65
from backend.common.schema import SchemaBase
76
from backend.utils.type_conversion import sql_type_to_sqlalchemy
87

98

109
class GenModelSchemaBase(SchemaBase):
1110
name: str
1211
comment: str | None = None
13-
type: GenModelColumnType = Field(GenModelColumnType.VARCHAR)
12+
type: str
1413
default: str | None = None
1514
sort: int
1615
length: int

backend/app/generator/service/gen_model_service.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from backend.app.generator.crud.crud_gen_model import gen_model_dao
66
from backend.app.generator.model import GenModel
77
from backend.app.generator.schema.gen_model import CreateGenModelParam, UpdateGenModelParam
8+
from backend.common.enums import GenModelMySQLColumnType
89
from backend.common.exception import errors
910
from backend.database.db_mysql import async_db_session
1011
from backend.utils.type_conversion import sql_type_to_pydantic
@@ -17,6 +18,12 @@ async def get(*, pk: int) -> GenModel:
1718
gen_model = await gen_model_dao.get(db, pk)
1819
return gen_model
1920

21+
@staticmethod
22+
async def get_types() -> list[str]:
23+
types = GenModelMySQLColumnType.get_member_keys()
24+
types.sort()
25+
return types
26+
2027
@staticmethod
2128
async def get_by_business(*, business_id: int) -> Sequence[GenModel]:
2229
async with async_db_session() as db:

backend/app/generator/service/gen_service.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,14 @@
1818
from backend.app.generator.schema.gen_business import CreateGenBusinessParam
1919
from backend.app.generator.schema.gen_model import CreateGenModelParam
2020
from backend.app.generator.service.gen_model_service import gen_model_service
21-
from backend.common.enums import GenModelColumnType
2221
from backend.common.exception import errors
2322
from backend.core.path_conf import BasePath
2423
from backend.database.db_mysql import async_db_session
2524
from backend.utils.gen_template import gen_template
25+
from backend.utils.type_conversion import sql_type_to_pydantic
2626

2727

2828
class GenService:
29-
@staticmethod
30-
async def get_business_with_model(*, pk: int) -> GenBusiness:
31-
async with async_db_session() as db:
32-
business = await gen_dao.get_business_with_model(db, pk)
33-
return business
34-
3529
@staticmethod
3630
async def get_tables(*, table_schema: str) -> Sequence[str]:
3731
async with async_db_session() as db:
@@ -60,19 +54,18 @@ async def import_business_and_model(*, app: str, table_schema: str, table_name:
6054
column_info = await gen_dao.get_all_columns(db, table_schema, table_name)
6155
for column in column_info:
6256
column_type = column[-1].split('(')[0].upper()
57+
pd_type = sql_type_to_pydantic(column_type)
6358
model_data = {
6459
'name': column[0],
6560
'comment': column[-2],
6661
'type': column_type,
6762
'sort': column[-3],
68-
'length': column[-1].split('(')[1][:-1]
69-
if column_type == GenModelColumnType.CHAR or column_type == GenModelColumnType.VARCHAR
70-
else 0,
63+
'length': column[-1].split('(')[1][:-1] if pd_type == 'str' and '(' in column[-1] else 0,
7164
'is_pk': column[1],
7265
'is_nullable': column[2],
7366
'gen_business_id': new_business.id,
7467
}
75-
await gen_model_dao.create(db, CreateGenModelParam(**model_data))
68+
await gen_model_dao.create(db, CreateGenModelParam(**model_data), pd_type=pd_type)
7669

7770
@staticmethod
7871
async def render_tpl_code(*, business: GenBusiness) -> dict:

backend/common/enums.py

Lines changed: 148 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -90,50 +90,151 @@ class UserSocialType(StrEnum):
9090
linuxdo = 'LinuxDo'
9191

9292

93-
class GenModelColumnType(StrEnum):
94-
"""代码生成模型列类型"""
95-
96-
BIGINT = 'BIGINT'
97-
BINARY = 'BINARY'
98-
BIT = 'BIT'
99-
BLOB = 'BLOB'
100-
BOOL = 'BOOL'
101-
BOOLEAN = 'BOOLEAN'
102-
CHAR = 'CHAR'
103-
DATE = 'DATE'
104-
DATETIME = 'DATETIME'
105-
DECIMAL = 'DECIMAL'
106-
DOUBLE = 'DOUBLE'
107-
DOUBLE_PRECISION = 'DOUBLE PRECISION'
108-
ENUM = 'ENUM'
109-
FLOAT = 'FLOAT'
110-
GEOMETRY = 'GEOMETRY'
111-
GEOMETRYCOLLECTION = 'GEOMETRYCOLLECTION'
112-
INT = 'INT'
113-
INTEGER = 'INTEGER'
114-
JSON = 'JSON'
115-
LINESTRING = 'LINESTRING'
116-
LONGBLOB = 'LONGBLOB'
117-
LONGTEXT = 'LONGTEXT'
118-
MEDIUMBLOB = 'MEDIUMBLOB'
119-
MEDIUMINT = 'MEDIUMINT'
120-
MEDIUMTEXT = 'MEDIUMTEXT'
121-
MULTILINESTRING = 'MULTILINESTRING'
122-
MULTIPOINT = 'MULTIPOINT'
123-
MULTIPOLYGON = 'MULTIPOLYGON'
124-
NUMERIC = 'NUMERIC'
125-
POINT = 'POINT'
126-
POLYGON = 'POLYGON'
127-
REAL = 'REAL'
128-
SERIAL = 'SERIAL'
129-
SET = 'SET'
130-
SMALLINT = 'SMALLINT'
131-
TEXT = 'TEXT'
132-
TIME = 'TIME'
133-
TIMESTAMP = 'TIMESTAMP'
134-
TINYBLOB = 'TINYBLOB'
135-
TINYINT = 'TINYINT'
136-
TINYTEXT = 'TINYTEXT'
137-
VARBINARY = 'VARBINARY'
138-
VARCHAR = 'VARCHAR'
139-
YEAR = 'YEAR'
93+
class GenModelMySQLColumnType(StrEnum):
94+
"""代码生成模型列类型(MySQL)"""
95+
96+
# Python 类型映射
97+
BIGINT = 'int'
98+
BigInteger = 'int' # BIGINT
99+
BINARY = 'bytes'
100+
BLOB = 'bytes'
101+
BOOLEAN = 'bool' # BOOL
102+
Boolean = 'bool' # BOOL
103+
CHAR = 'str'
104+
CLOB = 'str'
105+
DATE = 'date'
106+
Date = 'date' # DATE
107+
DATETIME = 'datetime'
108+
DateTime = 'datetime' # DATETIME
109+
DECIMAL = 'Decimal'
110+
DOUBLE = 'float'
111+
Double = 'float' # DOUBLE
112+
DOUBLE_PRECISION = 'float'
113+
Enum = 'Enum' # Enum()
114+
FLOAT = 'float'
115+
Float = 'float' # FLOAT
116+
INT = 'int' # INTEGER
117+
INTEGER = 'int'
118+
Integer = 'int' # INTEGER
119+
Interval = 'timedelta' # DATETIME
120+
JSON = 'dict'
121+
LargeBinary = 'bytes' # BLOB
122+
NCHAR = 'str'
123+
NUMERIC = 'Decimal'
124+
Numeric = 'Decimal' # NUMERIC
125+
NVARCHAR = 'str' # String
126+
PickleType = 'bytes' # BLOB
127+
REAL = 'float'
128+
SMALLINT = 'int'
129+
SmallInteger = 'int' # SMALLINT
130+
String = 'str' # String
131+
TEXT = 'str'
132+
Text = 'str' # TEXT
133+
TIME = 'time'
134+
Time = 'time' # TIME
135+
TIMESTAMP = 'datetime'
136+
Unicode = 'str' # String
137+
UnicodeText = 'str' # TEXT
138+
UUID = 'str | UUID'
139+
Uuid = 'str' # CHAR(32)
140+
VARBINARY = 'bytes'
141+
VARCHAR = 'str' # String
142+
143+
# sa.dialects.mysql 导入
144+
BIT = 'bool'
145+
ENUM = 'Enum'
146+
LONGBLOB = 'bytes'
147+
LONGTEXT = 'str'
148+
MEDIUMBLOB = 'bytes'
149+
MEDIUMINT = 'int'
150+
MEDIUMTEXT = 'str'
151+
SET = 'list[str]'
152+
TINYBLOB = 'bytes'
153+
TINYINT = 'int'
154+
TINYTEXT = 'str'
155+
YEAR = 'int'
156+
157+
158+
class GenModelPostgreSQLColumnType(StrEnum):
159+
"""代码生成模型列类型(PostgreSQL),仅作为数据保留,并未实施"""
160+
161+
# Python 类型映射
162+
BIGINT = 'int'
163+
BigInteger = 'int' # BIGINT
164+
BINARY = 'bytes'
165+
BLOB = 'bytes'
166+
BOOLEAN = 'bool'
167+
Boolean = 'bool' # BOOLEAN
168+
CHAR = 'str'
169+
CLOB = 'str'
170+
DATE = 'date'
171+
Date = 'date' # DATE
172+
DATETIME = 'datetime'
173+
DateTime = 'datetime' # TIMESTAMP WITHOUT TIME ZONE
174+
DECIMAL = 'Decimal'
175+
DOUBLE = 'float'
176+
Double = 'float' # DOUBLE PRECISION
177+
DOUBLE_PRECISION = 'float' # DOUBLE PRECISION
178+
Enum = 'Enum' # Enum(name='enum')
179+
FLOAT = 'float'
180+
Float = 'float' # FLOAT
181+
INT = 'int' # INTEGER
182+
INTEGER = 'int'
183+
Integer = 'int' # INTEGER
184+
Interval = 'timedelta' # INTERVAL
185+
JSON = 'dict'
186+
LargeBinary = 'bytes' # BYTEA
187+
NCHAR = 'str'
188+
NUMERIC = 'Decimal'
189+
Numeric = 'Decimal' # NUMERIC
190+
NVARCHAR = 'str' # String
191+
PickleType = 'bytes' # BYTEA
192+
REAL = 'float'
193+
SMALLINT = 'int'
194+
SmallInteger = 'int' # SMALLINT
195+
String = 'str' # String
196+
TEXT = 'str'
197+
Text = 'str' # TEXT
198+
TIME = 'time' # TIME WITHOUT TIME ZONE
199+
Time = 'time' # TIME WITHOUT TIME ZONE
200+
TIMESTAMP = 'datetime' # TIMESTAMP WITHOUT TIME ZONE
201+
Unicode = 'str' # String
202+
UnicodeText = 'str' # TEXT
203+
UUID = 'str | UUID'
204+
Uuid = 'str'
205+
VARBINARY = 'bytes'
206+
VARCHAR = 'str' # String
207+
208+
# sa.dialects.postgresql 导入
209+
ARRAY = 'list'
210+
BIT = 'bool'
211+
BYTEA = 'bytes'
212+
CIDR = 'str'
213+
CITEXT = 'str'
214+
DATEMULTIRANGE = 'list[date]'
215+
DATERANGE = 'tuple[date, date]'
216+
DOMAIN = 'str'
217+
ENUM = 'Enum'
218+
HSTORE = 'dict'
219+
INET = 'str'
220+
INT4MULTIRANGE = 'list[int]'
221+
INT4RANGE = 'tuple[int, int]'
222+
INT8MULTIRANGE = 'list[int]'
223+
INT8RANGE = 'tuple[int, int]'
224+
INTERVAL = 'timedelta'
225+
JSONB = 'dict'
226+
JSONPATH = 'str'
227+
MACADDR = 'str'
228+
MACADDR8 = 'str'
229+
MONEY = 'Decimal'
230+
NUMMULTIRANGE = 'list[Decimal]'
231+
NUMRANGE = 'tuple[Decimal, Decimal]'
232+
OID = 'int'
233+
REGCLASS = 'str'
234+
REGCONFIG = 'str'
235+
TSMULTIRANGE = 'list[datetime]'
236+
TSQUERY = 'str'
237+
TSRANGE = 'tuple[datetime, datetime]'
238+
TSTZMULTIRANGE = 'list[datetime]'
239+
TSTZRANGE = 'tuple[datetime, datetime]'
240+
TSVECTOR = 'str'

0 commit comments

Comments
 (0)