Skip to content

Commit 3ee5ec2

Browse files
authored
Add the join condition onclause support (#54)
* Add the join condition onclause support * Update join_type field define * Improve performance * Update testcases * Add testcases * Update load strategy support * Expose JoinConfig * Add join config testcases
1 parent 4f679d0 commit 3ee5ec2

File tree

13 files changed

+1399
-396
lines changed

13 files changed

+1399
-396
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ order-by-type = true
5959
quote-style = "single"
6060
docstring-code-format = true
6161

62+
[tool.pytest.ini_options]
63+
asyncio_default_fixture_loop_scope = "session"
64+
6265
[build-system]
6366
requires = ["hatchling"]
6467
build-backend = "hatchling.build"

sqlalchemy_crud_plus/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
33
from .crud import CRUDPlus as CRUDPlus
4+
from .types import JoinConfig as JoinConfig
45

56
__version__ = '1.10.0'

sqlalchemy_crud_plus/crud.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,11 @@ async def count(
181181
if kwargs:
182182
filters.extend(parse_filters(self.model, **kwargs))
183183

184-
stmt = select(func.count()).select_from(self.model)
184+
if isinstance(self.primary_key, list):
185+
stmt = select(func.count()).select_from(self.model)
186+
else:
187+
stmt = select(func.count(self.primary_key)).select_from(self.model)
188+
185189
if filters:
186190
stmt = stmt.where(*filters)
187191

@@ -535,14 +539,14 @@ async def bulk_update_models(
535539
) -> int:
536540
"""
537541
Bulk update multiple instances with different data for each record.
538-
Each update item should have 'pk' key and other fields to update.
539542
540543
:param session: The SQLAlchemy async session
541544
:param objs: To save a list of Pydantic schemas or dict for data
542545
:param pk_mode: Primary key mode, when enabled, the data must contain the primary key data
543546
:param flush: If `True`, flush all object changes to the database
544547
:param commit: If `True`, commits the transaction immediately
545-
:return: Total number of updated records
548+
:param kwargs: Filter expressions using field__operator=value syntax
549+
:return:
546550
"""
547551
if not pk_mode:
548552
filters = parse_filters(self.model, **kwargs)

sqlalchemy_crud_plus/types.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
# -*- coding: utf-8 -*-
33
from __future__ import annotations
44

5-
from typing import Literal, TypeVar
5+
from typing import Any, Literal, TypeVar
66

7-
from pydantic import BaseModel
7+
from pydantic import BaseModel, ConfigDict, Field
88
from sqlalchemy.orm import DeclarativeBase
9+
from sqlalchemy.orm.util import AliasedClass
910
from sqlalchemy.sql.base import ExecutableOption
1011

1112
Model = TypeVar('Model', bound=DeclarativeBase)
@@ -51,7 +52,16 @@
5152
'full',
5253
]
5354

54-
JoinConditions = list[str] | dict[str, JoinType]
55+
56+
class JoinConfig(BaseModel):
57+
model_config = ConfigDict(arbitrary_types_allowed=True)
58+
59+
model: type[Model] | AliasedClass
60+
join_on: Any
61+
join_type: JoinType = Field(default='inner')
62+
63+
64+
JoinConditions = list[str | JoinConfig] | dict[str, JoinType]
5565

5666
LoadOptions = list[ExecutableOption]
5767

sqlalchemy_crud_plus/utils.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717
load_only,
1818
noload,
1919
raiseload,
20-
selectin_polymorphic,
2120
selectinload,
2221
subqueryload,
2322
undefer,
2423
undefer_group,
25-
with_expression,
2624
)
2725
from sqlalchemy.orm.util import AliasedClass
26+
from sqlalchemy.sql.base import ExecutableOption
2827
from sqlalchemy.sql.operators import ColumnOperators
2928
from sqlalchemy.sql.schema import Column
3029

@@ -35,7 +34,7 @@
3534
ModelColumnError,
3635
SelectOperatorError,
3736
)
38-
from sqlalchemy_crud_plus.types import JoinConditions, LoadStrategies, Model
37+
from sqlalchemy_crud_plus.types import JoinConditions, JoinConfig, LoadStrategies, Model
3938

4039
_SUPPORTED_FILTERS = {
4140
# Comparison: https://docs.sqlalchemy.org/en/20/core/operators.html#comparison-operators
@@ -181,7 +180,7 @@ def _create_arithmetic_filters(column: Column, op: str, value: dict[str, Any]) -
181180
return arithmetic_filters
182181

183182

184-
def _create_and_filters(column: Column, op: str, value: Any) -> list[ColumnElement | None]:
183+
def _create_and_filters(column: Column, op: str, value: Any) -> list[ColumnElement[Any] | None]:
185184
"""
186185
Create AND filter expressions.
187186
@@ -197,7 +196,7 @@ def _create_and_filters(column: Column, op: str, value: Any) -> list[ColumnEleme
197196
return and_filters
198197

199198

200-
def parse_filters(model: type[Model] | AliasedClass, **kwargs) -> list[ColumnElement]:
199+
def parse_filters(model: type[Model] | AliasedClass, **kwargs) -> list[ColumnElement[Any]]:
201200
"""
202201
Parse filter expressions from keyword arguments.
203202
@@ -218,6 +217,9 @@ def parse_filters(model: type[Model] | AliasedClass, **kwargs) -> list[ColumnEle
218217
if field_name == '__or' and op == '':
219218
__or__filters = []
220219

220+
if not isinstance(value, dict):
221+
raise SelectOperatorError('__or__ filter value must be a dictionary')
222+
221223
for _key, _value in value.items():
222224
if '__' not in _key:
223225
_column = get_column(model, _key)
@@ -303,14 +305,16 @@ def apply_sorting(
303305
return stmt
304306

305307

306-
def build_load_strategies(model: type[Model], load_strategies: LoadStrategies | None) -> list:
308+
def build_load_strategies(model: type[Model], load_strategies: LoadStrategies | None) -> list[ExecutableOption]:
307309
"""
308310
Build relationship loading strategy options.
309311
310312
:param model: SQLAlchemy model class
311313
:param load_strategies: Loading strategies configuration
312314
:return:
313315
"""
316+
if load_strategies is None:
317+
return []
314318

315319
strategies_map = {
316320
'contains_eager': contains_eager,
@@ -325,10 +329,10 @@ def build_load_strategies(model: type[Model], load_strategies: LoadStrategies |
325329
# Load
326330
'defer': defer,
327331
'load_only': load_only,
328-
'selectin_polymorphic': selectin_polymorphic,
332+
# 'selectin_polymorphic': selectin_polymorphic,
329333
'undefer': undefer,
330334
'undefer_group': undefer_group,
331-
'with_expression': with_expression,
335+
# 'with_expression': with_expression,
332336
}
333337

334338
options = []
@@ -359,7 +363,7 @@ def build_load_strategies(model: type[Model], load_strategies: LoadStrategies |
359363
return options
360364

361365

362-
def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: JoinConditions | None):
366+
def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: JoinConditions | None) -> Select:
363367
"""
364368
Apply JOIN conditions to the query statement.
365369
@@ -368,13 +372,24 @@ def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: Joi
368372
:param join_conditions: JOIN conditions configuration
369373
:return:
370374
"""
375+
if join_conditions is None:
376+
return stmt
377+
371378
if isinstance(join_conditions, list):
372-
for column in join_conditions:
373-
try:
374-
attr = getattr(model, column)
375-
stmt = stmt.join(attr)
376-
except AttributeError:
377-
raise ModelColumnError(f'Invalid model column: {column}')
379+
for v in join_conditions:
380+
if isinstance(v, str):
381+
try:
382+
attr = getattr(model, v)
383+
stmt = stmt.join(attr)
384+
except AttributeError:
385+
raise ModelColumnError(f'Invalid model column: {v}')
386+
elif isinstance(v, JoinConfig):
387+
if v.join_type == 'inner':
388+
stmt = stmt.join(v.model, v.join_on)
389+
elif v.join_type == 'left':
390+
stmt = stmt.join(v.model, v.join_on, isouter=True)
391+
elif v.join_type == 'full':
392+
stmt = stmt.join(v.model, v.join_on, full=True)
378393

379394
elif isinstance(join_conditions, dict):
380395
for column, join_type in join_conditions.items():
@@ -383,10 +398,10 @@ def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: Joi
383398
raise JoinConditionError(f'Invalid join type: {join_type}, only supports {allowed_join_types}')
384399
try:
385400
attr = getattr(model, column)
386-
if join_type == 'left':
387-
stmt = stmt.join(attr, isouter=True)
388-
elif join_type == 'inner':
401+
if join_type == 'inner':
389402
stmt = stmt.join(attr)
403+
elif join_type == 'left':
404+
stmt = stmt.join(attr, isouter=True)
390405
elif join_type == 'full':
391406
stmt = stmt.join(attr, full=True)
392407
else:

tests/test_create.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,32 @@ async def test_bulk_create_models_composite_keys(async_db_session: AsyncSession,
132132
assert results[0].id == 1000
133133
assert results[0].name == 'bulk_pks_1'
134134
assert results[0].sex == 'male'
135+
136+
137+
@pytest.mark.asyncio
138+
async def test_bulk_create_models_with_flush(async_db_session: AsyncSession, crud_ins: CRUDPlus[Ins]):
139+
data = [
140+
{'name': 'bulk_flush_1', 'del_flag': False, 'created_time': datetime.now()},
141+
{'name': 'bulk_flush_2', 'del_flag': False, 'created_time': datetime.now()},
142+
]
143+
144+
async with async_db_session.begin():
145+
results = await crud_ins.bulk_create_models(async_db_session, data, flush=True)
146+
147+
assert len(results) == 2
148+
assert results[0].name == 'bulk_flush_1'
149+
assert results[1].name == 'bulk_flush_2'
150+
151+
152+
@pytest.mark.asyncio
153+
async def test_bulk_create_models_with_commit(async_db_session: AsyncSession, crud_ins: CRUDPlus[Ins]):
154+
data = [
155+
{'name': 'bulk_commit_1', 'del_flag': False, 'created_time': datetime.now()},
156+
{'name': 'bulk_commit_2', 'del_flag': False, 'created_time': datetime.now()},
157+
]
158+
159+
results = await crud_ins.bulk_create_models(async_db_session, data, commit=True)
160+
161+
assert len(results) == 2
162+
assert results[0].name == 'bulk_commit_1'
163+
assert results[1].name == 'bulk_commit_2'

0 commit comments

Comments
 (0)