Skip to content

Commit a6db8af

Browse files
committed
feat: add new utility functions for SQLModel to handle relations and columns by schema; update col function to return InstrumentedAttribute for better type safety
1 parent e9c74ee commit a6db8af

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

sqlmodel/sql/expression.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212

1313
import sqlalchemy
14+
from pydantic import BaseModel
1415
from sqlalchemy import (
1516
Column,
1617
ColumnElement,
@@ -22,8 +23,7 @@
2223
TypeCoerce,
2324
WithinGroup,
2425
)
25-
from sqlalchemy.orm import InstrumentedAttribute
26-
from sqlalchemy.orm.attributes import QueryableAttribute
26+
from sqlalchemy.orm import InstrumentedAttribute, QueryableAttribute
2727
from sqlalchemy.sql._typing import (
2828
_ColumnExpressionArgument,
2929
_ColumnExpressionOrLiteralArgument,
@@ -39,6 +39,7 @@
3939
UnaryExpression,
4040
)
4141
from sqlalchemy.sql.type_api import TypeEngine
42+
from sqlmodel.main import SQLModel
4243
from typing_extensions import Literal
4344

4445
from ._expression_select_cls import Select as Select
@@ -210,7 +211,31 @@ def within_group(
210211
return sqlalchemy.within_group(element, *order_by)
211212

212213

213-
def col(column_expression: _T) -> QueryableAttribute[_T]:
214+
def col(column_expression: _T) -> InstrumentedAttribute[_T]:
214215
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
215216
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
216-
return column_expression # type: ignore
217+
return column_expression
218+
219+
220+
def relations(relations_expression) -> QueryableAttribute:
221+
if not isinstance(relations_expression, (QueryableAttribute)):
222+
raise RuntimeError(f"Not a SQLAlchemy relations: {relations_expression}")
223+
return relations_expression
224+
225+
226+
def columns_by_schema(
227+
schema: type[BaseModel],
228+
model: type[SQLModel],
229+
) -> list[InstrumentedAttribute]:
230+
schema_fields = {
231+
*schema.model_fields.keys(),
232+
}
233+
model_fields = {
234+
*model.__pydantic_fields__.keys(),
235+
*model.__sqlmodel_relationships__.keys(),
236+
*model.__sqlalchemy_association_proxies__.keys(),
237+
*model.__sqlalchemy_constructs__.keys(),
238+
}
239+
return [
240+
col(getattr(model, field)) for field in schema_fields if field in model_fields
241+
]

0 commit comments

Comments
 (0)