Skip to content

Commit b89b8fe

Browse files
ast-grep --pattern 'session.query($NAME).where($COND).first()' -r 'session.scalars(select($NAME).where($COND).limit(1)).first()' -l py --update-all
1 parent 70e4d6b commit b89b8fe

File tree

8 files changed

+50
-61
lines changed

8 files changed

+50
-61
lines changed

api/controllers/console/workspace/__init__.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import ParamSpec, TypeVar
44

55
from flask_login import current_user
6+
from sqlalchemy import select
67
from sqlalchemy.orm import Session
78
from werkzeug.exceptions import Forbidden
89

@@ -24,13 +25,9 @@ def decorated(*args: P.args, **kwargs: P.kwargs):
2425
tenant_id = user.current_tenant_id
2526

2627
with Session(db.engine) as session:
27-
permission = (
28-
session.query(TenantPluginPermission)
29-
.where(
30-
TenantPluginPermission.tenant_id == tenant_id,
31-
)
32-
.first()
33-
)
28+
permission = session.scalars(
29+
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
30+
).first()
3431

3532
if not permission:
3633
# no permission set, allow access for everyone

api/controllers/mcp/mcp.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from flask import Response
44
from flask_restx import Resource, reqparse
55
from pydantic import ValidationError
6+
from sqlalchemy import select
67
from sqlalchemy.orm import Session
78

89
from controllers.console.app.mcp_server import AppMCPServerStatus
@@ -89,11 +90,13 @@ def post(self, server_code: str):
8990

9091
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
9192
"""Get and validate MCP server and app in one query session"""
92-
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
93+
mcp_server = session.scalars(
94+
select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1)
95+
).first()
9396
if not mcp_server:
9497
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
9598

96-
app = session.query(App).where(App.id == mcp_server.app_id).first()
99+
app = session.scalars(select(App).where(App.id == mcp_server.app_id).limit(1)).first()
97100
if not app:
98101
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
99102

api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,9 @@ def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Op
473473
:return:
474474
"""
475475
with Session(db.engine, expire_on_commit=False) as session:
476-
agent_thought: Optional[MessageAgentThought] = (
477-
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
478-
)
476+
agent_thought: Optional[MessageAgentThought] = session.scalars(
477+
select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
478+
).first()
479479

480480
if agent_thought:
481481
return AgentThoughtStreamResponse(

api/core/rag/datasource/retrieval_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def external_retrieve(
147147
@classmethod
148148
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
149149
with Session(db.engine) as session:
150-
return session.query(Dataset).where(Dataset.id == dataset_id).first()
150+
return session.scalars(select(Dataset).where(Dataset.id == dataset_id).limit(1)).first()
151151

152152
@classmethod
153153
def keyword_search(

api/core/tools/tool_file_manager.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from uuid import uuid4
1111

1212
import httpx
13+
from sqlalchemy import select
1314
from sqlalchemy.orm import Session
1415

1516
from configs import dify_config
@@ -158,13 +159,7 @@ def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
158159
:return: the binary of the file, mime type
159160
"""
160161
with Session(self._engine, expire_on_commit=False) as session:
161-
tool_file: ToolFile | None = (
162-
session.query(ToolFile)
163-
.where(
164-
ToolFile.id == id,
165-
)
166-
.first()
167-
)
162+
tool_file: ToolFile | None = session.scalars(select(ToolFile).where(ToolFile.id == id).limit(1)).first()
168163

169164
if not tool_file:
170165
return None
@@ -182,13 +177,9 @@ def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str]
182177
:return: the binary of the file, mime type
183178
"""
184179
with Session(self._engine, expire_on_commit=False) as session:
185-
message_file: MessageFile | None = (
186-
session.query(MessageFile)
187-
.where(
188-
MessageFile.id == id,
189-
)
190-
.first()
191-
)
180+
message_file: MessageFile | None = session.scalars(
181+
select(MessageFile).where(MessageFile.id == id).limit(1)
182+
).first()
192183

193184
# Check if message_file is not None
194185
if message_file is not None:
@@ -202,13 +193,9 @@ def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str]
202193
else:
203194
tool_file_id = None
204195

205-
tool_file: ToolFile | None = (
206-
session.query(ToolFile)
207-
.where(
208-
ToolFile.id == tool_file_id,
209-
)
210-
.first()
211-
)
196+
tool_file: ToolFile | None = session.scalars(
197+
select(ToolFile).where(ToolFile.id == tool_file_id).limit(1)
198+
).first()
212199

213200
if not tool_file:
214201
return None
@@ -226,13 +213,9 @@ def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Optiona
226213
:return: the binary of the file, mime type
227214
"""
228215
with Session(self._engine, expire_on_commit=False) as session:
229-
tool_file: ToolFile | None = (
230-
session.query(ToolFile)
231-
.where(
232-
ToolFile.id == tool_file_id,
233-
)
234-
.first()
235-
)
216+
tool_file: ToolFile | None = session.scalars(
217+
select(ToolFile).where(ToolFile.id == tool_file_id).limit(1)
218+
).first()
236219

237220
if not tool_file:
238221
return None, None

api/services/plugin/plugin_auto_upgrade_service.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sqlalchemy import select
12
from sqlalchemy.orm import Session
23

34
from extensions.ext_database import db
@@ -8,11 +9,11 @@ class PluginAutoUpgradeService:
89
@staticmethod
910
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
1011
with Session(db.engine) as session:
11-
return (
12-
session.query(TenantPluginAutoUpgradeStrategy)
12+
return session.scalars(
13+
select(TenantPluginAutoUpgradeStrategy)
1314
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
14-
.first()
15-
)
15+
.limit(1)
16+
).first()
1617

1718
@staticmethod
1819
def change_strategy(
@@ -24,11 +25,11 @@ def change_strategy(
2425
include_plugins: list[str],
2526
) -> bool:
2627
with Session(db.engine) as session:
27-
exist_strategy = (
28-
session.query(TenantPluginAutoUpgradeStrategy)
28+
exist_strategy = session.scalars(
29+
select(TenantPluginAutoUpgradeStrategy)
2930
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
30-
.first()
31-
)
31+
.limit(1)
32+
).first()
3233
if not exist_strategy:
3334
strategy = TenantPluginAutoUpgradeStrategy(
3435
tenant_id=tenant_id,
@@ -52,11 +53,11 @@ def change_strategy(
5253
@staticmethod
5354
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
5455
with Session(db.engine) as session:
55-
exist_strategy = (
56-
session.query(TenantPluginAutoUpgradeStrategy)
56+
exist_strategy = session.scalars(
57+
select(TenantPluginAutoUpgradeStrategy)
5758
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
58-
.first()
59-
)
59+
.limit(1)
60+
).first()
6061
if not exist_strategy:
6162
# create for this tenant
6263
PluginAutoUpgradeService.change_strategy(

api/services/plugin/plugin_permission_service.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sqlalchemy import select
12
from sqlalchemy.orm import Session
23

34
from extensions.ext_database import db
@@ -8,7 +9,9 @@ class PluginPermissionService:
89
@staticmethod
910
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
1011
with Session(db.engine) as session:
11-
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
12+
return session.scalars(
13+
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
14+
).first()
1215

1316
@staticmethod
1417
def change_permission(
@@ -17,9 +20,9 @@ def change_permission(
1720
debug_permission: TenantPluginPermission.DebugPermission,
1821
):
1922
with Session(db.engine) as session:
20-
permission = (
21-
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
22-
)
23+
permission = session.scalars(
24+
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
25+
).first()
2326
if not permission:
2427
permission = TenantPluginPermission(
2528
tenant_id=tenant_id, install_permission=install_permission, debug_permission=debug_permission

api/tests/unit_tests/models/test_types_enum_text.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
import sqlalchemy as sa
77
from sqlalchemy import exc as sa_exc
8-
from sqlalchemy import insert
8+
from sqlalchemy import insert, select
99
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
1010
from sqlalchemy.sql.sqltypes import VARCHAR
1111

@@ -114,12 +114,14 @@ def test_insert_and_select(self):
114114
session.commit()
115115

116116
with Session(engine) as session:
117-
user = session.query(_User).where(_User.id == admin_user_id).first()
117+
user = session.scalars(select(_User).where(_User.id == admin_user_id).limit(1)).first()
118+
assert user is not None
118119
assert user.user_type == _UserType.admin
119120
assert user.user_type_nullable is None
120121

121122
with Session(engine) as session:
122-
user = session.query(_User).where(_User.id == normal_user_id).first()
123+
user = session.scalars(select(_User).where(_User.id == normal_user_id).limit(1)).first()
124+
assert user is not None
123125
assert user.user_type == _UserType.normal
124126
assert user.user_type_nullable == _UserType.normal
125127

@@ -188,4 +190,4 @@ def test_select_invalid_values(self):
188190

189191
with pytest.raises(ValueError) as exc:
190192
with Session(engine) as session:
191-
_user = session.query(_User).where(_User.id == 1).first()
193+
_user = session.scalars(select(_User).where(_User.id == 1).limit(1)).first()

0 commit comments

Comments
 (0)