5959 get_common_fk_constraints ,
6060 get_compiled_expression ,
6161 get_constraint_sort_key ,
62+ get_stdlib_module_names ,
6263 qualified_table_name ,
6364 render_callable ,
6465 uses_default_name ,
@@ -119,9 +120,7 @@ def generate(self) -> str:
119120@dataclass (eq = False )
120121class TablesGenerator (CodeGenerator ):
121122 valid_options : ClassVar [set [str ]] = {"noindexes" , "noconstraints" , "nocomments" }
122- builtin_module_names : ClassVar [set [str ]] = set (sys .builtin_module_names ) | {
123- "dataclasses"
124- }
123+ stdlib_module_names : ClassVar [set [str ]] = get_stdlib_module_names ()
125124
126125 def __init__ (
127126 self ,
@@ -276,7 +275,7 @@ def add_import(self, obj: Any) -> None:
276275
277276 if type_ .__name__ in dialect_pkg .__all__ :
278277 pkgname = dialect_pkgname
279- elif type_ . __name__ in dir (sqlalchemy ):
278+ elif type_ is getattr (sqlalchemy , type_ . __name__ , None ):
280279 pkgname = "sqlalchemy"
281280 else :
282281 pkgname = type_ .__module__
@@ -300,21 +299,26 @@ def group_imports(self) -> list[list[str]]:
300299 stdlib_imports : list [str ] = []
301300 thirdparty_imports : list [str ] = []
302301
303- for package in sorted (self .imports ):
304- imports = ", " .join (sorted (self .imports [package ]))
302+ def get_collection (package : str ) -> list [str ]:
305303 collection = thirdparty_imports
306304 if package == "__future__" :
307305 collection = future_imports
308- elif package in self .builtin_module_names :
306+ elif package in self .stdlib_module_names :
309307 collection = stdlib_imports
310308 elif package in sys .modules :
311309 if "site-packages" not in (sys .modules [package ].__file__ or "" ):
312310 collection = stdlib_imports
311+ return collection
313312
313+ for package in sorted (self .imports ):
314+ imports = ", " .join (sorted (self .imports [package ]))
315+
316+ collection = get_collection (package )
314317 collection .append (f"from { package } import { imports } " )
315318
316319 for module in sorted (self .module_imports ):
317- thirdparty_imports .append (f"import { module } " )
320+ collection = get_collection (module )
321+ collection .append (f"import { module } " )
318322
319323 return [
320324 group
@@ -1212,10 +1216,7 @@ def render_table_args(self, table: Table) -> str:
12121216 else :
12131217 return ""
12141218
1215- def render_column_attribute (self , column_attr : ColumnAttribute ) -> str :
1216- column = column_attr .column
1217- rendered_column = self .render_column (column , column_attr .name != column .name )
1218-
1219+ def render_column_python_type (self , column : Column [Any ]) -> str :
12191220 def get_type_qualifiers () -> tuple [str , TypeEngine [Any ], str ]:
12201221 column_type = column .type
12211222 pre : list [str ] = []
@@ -1254,7 +1255,14 @@ def render_python_type(column_type: TypeEngine[Any]) -> str:
12541255
12551256 pre , col_type , post = get_type_qualifiers ()
12561257 column_python_type = f"{ pre } { render_python_type (col_type )} { post } "
1257- return f"{ column_attr .name } : Mapped[{ column_python_type } ] = { rendered_column } "
1258+ return column_python_type
1259+
1260+ def render_column_attribute (self , column_attr : ColumnAttribute ) -> str :
1261+ column = column_attr .column
1262+ rendered_column = self .render_column (column , column_attr .name != column .name )
1263+ rendered_column_python_type = self .render_column_python_type (column )
1264+
1265+ return f"{ column_attr .name } : Mapped[{ rendered_column_python_type } ] = { rendered_column } "
12581266
12591267 def render_relationship (self , relationship : RelationshipAttribute ) -> str :
12601268 def render_column_attrs (column_attrs : list [ColumnAttribute ]) -> str :
@@ -1444,15 +1452,6 @@ def collect_imports_for_model(self, model: Model) -> None:
14441452 if model .relationships :
14451453 self .add_literal_import ("sqlmodel" , "Relationship" )
14461454
1447- def collect_imports_for_column (self , column : Column [Any ]) -> None :
1448- super ().collect_imports_for_column (column )
1449- try :
1450- python_type = column .type .python_type
1451- except NotImplementedError :
1452- self .add_literal_import ("typing" , "Any" )
1453- else :
1454- self .add_import (python_type )
1455-
14561455 def render_module_variables (self , models : list [Model ]) -> str :
14571456 declarations : list [str ] = []
14581457 if any (not isinstance (model , ModelClass ) for model in models ):
@@ -1485,25 +1484,17 @@ def render_class_variables(self, model: ModelClass) -> str:
14851484
14861485 def render_column_attribute (self , column_attr : ColumnAttribute ) -> str :
14871486 column = column_attr .column
1488- try :
1489- python_type = column .type .python_type
1490- except NotImplementedError :
1491- python_type_name = "Any"
1492- else :
1493- python_type_name = python_type .__name__
1487+ rendered_column = self .render_column (column , True )
1488+ rendered_column_python_type = self .render_column_python_type (column )
14941489
14951490 kwargs : dict [str , Any ] = {}
1496- if (
1497- column .autoincrement and column .name in column .table .primary_key
1498- ) or column .nullable :
1499- self .add_literal_import ("typing" , "Optional" )
1491+ if column .nullable :
15001492 kwargs ["default" ] = None
1501- python_type_name = f"Optional[{ python_type_name } ]"
1502-
1503- rendered_column = self .render_column (column , True )
15041493 kwargs ["sa_column" ] = f"{ rendered_column } "
1494+
15051495 rendered_field = render_callable ("Field" , kwargs = kwargs )
1506- return f"{ column_attr .name } : { python_type_name } = { rendered_field } "
1496+
1497+ return f"{ column_attr .name } : { rendered_column_python_type } = { rendered_field } "
15071498
15081499 def render_relationship (self , relationship : RelationshipAttribute ) -> str :
15091500 rendered = super ().render_relationship (relationship ).partition (" = " )[2 ]
0 commit comments