|
11 | 11 | ) |
12 | 12 |
|
13 | 13 | import sqlalchemy |
| 14 | +from pydantic import BaseModel |
14 | 15 | from sqlalchemy import ( |
15 | 16 | Column, |
16 | 17 | ColumnElement, |
|
22 | 23 | TypeCoerce, |
23 | 24 | WithinGroup, |
24 | 25 | ) |
25 | | -from sqlalchemy.orm import InstrumentedAttribute |
26 | | -from sqlalchemy.orm.attributes import QueryableAttribute |
| 26 | +from sqlalchemy.orm import InstrumentedAttribute, QueryableAttribute |
27 | 27 | from sqlalchemy.sql._typing import ( |
28 | 28 | _ColumnExpressionArgument, |
29 | 29 | _ColumnExpressionOrLiteralArgument, |
|
39 | 39 | UnaryExpression, |
40 | 40 | ) |
41 | 41 | from sqlalchemy.sql.type_api import TypeEngine |
| 42 | +from sqlmodel.main import SQLModel |
42 | 43 | from typing_extensions import Literal |
43 | 44 |
|
44 | 45 | from ._expression_select_cls import Select as Select |
@@ -210,7 +211,31 @@ def within_group( |
210 | 211 | return sqlalchemy.within_group(element, *order_by) |
211 | 212 |
|
212 | 213 |
|
213 | | -def col(column_expression: _T) -> QueryableAttribute[_T]: |
| 214 | +def col(column_expression: _T) -> InstrumentedAttribute[_T]: |
214 | 215 | if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): |
215 | 216 | raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") |
216 | | - return column_expression # type: ignore |
| 217 | + return column_expression |
| 218 | + |
| 219 | + |
| 220 | +def relations(relations_expression) -> QueryableAttribute: |
| 221 | + if not isinstance(relations_expression, (QueryableAttribute)): |
| 222 | + raise RuntimeError(f"Not a SQLAlchemy relations: {relations_expression}") |
| 223 | + return relations_expression |
| 224 | + |
| 225 | + |
| 226 | +def columns_by_schema( |
| 227 | + schema: type[BaseModel], |
| 228 | + model: type[SQLModel], |
| 229 | +) -> list[InstrumentedAttribute]: |
| 230 | + schema_fields = { |
| 231 | + *schema.model_fields.keys(), |
| 232 | + } |
| 233 | + model_fields = { |
| 234 | + *model.__pydantic_fields__.keys(), |
| 235 | + *model.__sqlmodel_relationships__.keys(), |
| 236 | + *model.__sqlalchemy_association_proxies__.keys(), |
| 237 | + *model.__sqlalchemy_constructs__.keys(), |
| 238 | + } |
| 239 | + return [ |
| 240 | + col(getattr(model, field)) for field in schema_fields if field in model_fields |
| 241 | + ] |
0 commit comments