Skip to content

Commit 6433ac8

Browse files
author
longbingljw
authored
feat: json metadata filter adapt (#65)
* config adapt revert * ci test * fix mysql migration test * fix * fix * lint fix * fix ob config * fix * fix * fix * test over * test * fix * fix * fix style * test over * retain gin for pg * gin for pg * uuid default in versions * ci test * ci test * fix * fix * fix * fix * pg jsonb * fix
1 parent 84935b9 commit 6433ac8

File tree

57 files changed

+386
-487
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+386
-487
lines changed

.github/workflows/api-tests.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
name: Run Pytest
22

33
on:
4-
workflow_call:
4+
push:
5+
branches:
6+
- mysql-adapt
57

68
concurrency:
79
group: api-tests-${{ github.head_ref || github.run_id }}

.github/workflows/db-migration-test.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
name: DB Migration Test
22

33
on:
4-
workflow_call:
4+
push:
5+
branches:
6+
- mysql-adapt
57

68
concurrency:
79
group: db-migration-test-${{ github.ref }}

.github/workflows/style.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
name: Style check
22

33
on:
4-
workflow_call:
4+
push:
5+
branches:
6+
- mysql-adapt
57

68
concurrency:
79
group: style-${{ github.head_ref || github.run_id }}

api/core/rag/retrieval/dataset_retrieval.py

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
from typing import Any, Union, cast
88

99
from flask import Flask, current_app
10-
from sqlalchemy import Float, and_, or_, select, text
11-
from sqlalchemy import cast as sqlalchemy_cast
10+
from sqlalchemy import and_, or_, select
1211

1312
from core.app.app_config.entities import (
1413
DatasetEntity,
@@ -1023,60 +1022,55 @@ def _process_metadata_filter_func(
10231022
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
10241023
):
10251024
if value is None and condition not in ("empty", "not empty"):
1026-
return
1025+
return filters
1026+
1027+
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
10271028

1028-
key = f"{metadata_name}_{sequence}"
1029-
key_value = f"{metadata_name}_{sequence}_value"
10301029
match condition:
10311030
case "contains":
1032-
filters.append(
1033-
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
1034-
**{key: metadata_name, key_value: f"%{value}%"}
1035-
)
1036-
)
1031+
filters.append(json_field.like(f"%{value}%"))
1032+
10371033
case "not contains":
1038-
filters.append(
1039-
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
1040-
**{key: metadata_name, key_value: f"%{value}%"}
1041-
)
1042-
)
1034+
filters.append(json_field.notlike(f"%{value}%"))
1035+
10431036
case "start with":
1044-
filters.append(
1045-
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
1046-
**{key: metadata_name, key_value: f"{value}%"}
1047-
)
1048-
)
1037+
filters.append(json_field.like(f"{value}%"))
10491038

10501039
case "end with":
1051-
filters.append(
1052-
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
1053-
**{key: metadata_name, key_value: f"%{value}"}
1054-
)
1055-
)
1040+
filters.append(json_field.like(f"%{value}"))
1041+
10561042
case "is" | "=":
10571043
if isinstance(value, str):
1058-
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
1059-
else:
1060-
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
1044+
filters.append(json_field == value)
1045+
elif isinstance(value, (int, float)):
1046+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() == value)
1047+
10611048
case "is not" | "≠":
10621049
if isinstance(value, str):
1063-
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
1064-
else:
1065-
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
1050+
filters.append(json_field != value)
1051+
elif isinstance(value, (int, float)):
1052+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() != value)
1053+
10661054
case "empty":
10671055
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
1056+
10681057
case "not empty":
10691058
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
1059+
10701060
case "before" | "<":
1071-
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
1061+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() < value)
1062+
10721063
case "after" | ">":
1073-
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
1064+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() > value)
1065+
10741066
case "≤" | "<=":
1075-
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
1067+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() <= value)
1068+
10761069
case "≥" | ">=":
1077-
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
1070+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
10781071
case _:
10791072
pass
1073+
10801074
return filters
10811075

10821076
def _fetch_model_config(

api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from collections.abc import Mapping, Sequence
77
from typing import TYPE_CHECKING, Any, cast
88

9-
from sqlalchemy import Float, and_, func, or_, select, text
10-
from sqlalchemy import cast as sqlalchemy_cast
9+
from sqlalchemy import and_, func, literal, or_, select
1110
from sqlalchemy.orm import sessionmaker
1211

1312
from core.app.app_config.entities import DatasetRetrieveConfigEntity
1413
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
14+
from core.callback_handler.index_tool_callback_handler import DatasetDocument
1515
from core.entities.agent_entities import PlanningStrategy
1616
from core.entities.model_entities import ModelStatus
1717
from core.model_manager import ModelInstance, ModelManager
@@ -597,79 +597,79 @@ def _process_metadata_filter_func(
597597
if value is None and condition not in ("empty", "not empty"):
598598
return filters
599599

600-
key = f"{metadata_name}_{sequence}"
601-
key_value = f"{metadata_name}_{sequence}_value"
600+
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
601+
602602
match condition:
603603
case "contains":
604-
filters.append(
605-
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
606-
**{key: metadata_name, key_value: f"%{value}%"}
607-
)
608-
)
604+
filters.append(json_field.like(f"%{value}%"))
605+
609606
case "not contains":
610-
filters.append(
611-
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
612-
**{key: metadata_name, key_value: f"%{value}%"}
613-
)
614-
)
607+
filters.append(json_field.notlike(f"%{value}%"))
608+
615609
case "start with":
616-
filters.append(
617-
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
618-
**{key: metadata_name, key_value: f"{value}%"}
619-
)
620-
)
610+
filters.append(json_field.like(f"{value}%"))
611+
621612
case "end with":
622-
filters.append(
623-
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
624-
**{key: metadata_name, key_value: f"%{value}"}
625-
)
626-
)
613+
filters.append(json_field.like(f"%{value}"))
627614
case "in":
628615
if isinstance(value, str):
629-
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
630-
escaped_value_str = ",".join(escaped_values)
616+
value_list = [v.strip() for v in value.split(",") if v.strip()]
617+
elif isinstance(value, (list, tuple)):
618+
value_list = [str(v) for v in value if v is not None]
631619
else:
632-
escaped_value_str = str(value)
633-
filters.append(
634-
(text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params(
635-
**{key: metadata_name, key_value: escaped_value_str}
636-
)
637-
)
620+
value_list = [str(value)] if value is not None else []
621+
622+
if not value_list:
623+
filters.append(literal(False))
624+
else:
625+
filters.append(json_field.in_(value_list))
626+
638627
case "not in":
639628
if isinstance(value, str):
640-
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
641-
escaped_value_str = ",".join(escaped_values)
629+
value_list = [v.strip() for v in value.split(",") if v.strip()]
630+
elif isinstance(value, (list, tuple)):
631+
value_list = [str(v) for v in value if v is not None]
642632
else:
643-
escaped_value_str = str(value)
644-
filters.append(
645-
(text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params(
646-
**{key: metadata_name, key_value: escaped_value_str}
647-
)
648-
)
649-
case "=" | "is":
650-
if isinstance(value, str):
651-
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
633+
value_list = [str(value)] if value is not None else []
634+
635+
if not value_list:
636+
filters.append(literal(True))
652637
else:
653-
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) == value)
638+
filters.append(json_field.notin_(value_list))
639+
640+
case "is" | "=":
641+
if isinstance(value, str):
642+
filters.append(json_field == value)
643+
elif isinstance(value, (int, float)):
644+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() == value)
645+
654646
case "is not" | "≠":
655647
if isinstance(value, str):
656-
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
657-
else:
658-
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) != value)
648+
filters.append(json_field != value)
649+
elif isinstance(value, (int, float)):
650+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() != value)
651+
659652
case "empty":
660-
filters.append(Document.doc_metadata[metadata_name].is_(None))
653+
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
654+
661655
case "not empty":
662-
filters.append(Document.doc_metadata[metadata_name].isnot(None))
656+
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
657+
663658
case "before" | "<":
664-
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) < value)
659+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() < value)
660+
665661
case "after" | ">":
666-
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) > value)
662+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() > value)
663+
667664
case "≤" | "<=":
668-
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
665+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() <= value)
666+
669667
case "≥" | ">=":
670-
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
668+
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
669+
671670
case _:
672671
pass
672+
673673
return filters
674674

675675
@classmethod

api/libs/helper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,12 @@ def timezone(timezone_string):
178178

179179

180180
def convert_datetime_to_date(field, target_timezone: str = ":tz"):
181-
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
181+
if dify_config.DB_TYPE == "postgresql":
182182
return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))"
183-
elif "mysql" in dify_config.SQLALCHEMY_DATABASE_URI_SCHEME:
183+
elif dify_config.DB_TYPE == "mysql":
184184
return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))"
185185
else:
186-
raise NotImplementedError(f"Unsupported database URI scheme: {dify_config.SQLALCHEMY_DATABASE_URI_SCHEME}")
186+
raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}")
187187

188188

189189
def generate_string(n):

api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"""
88
import sqlalchemy as sa
99
from alembic import op
10-
from uuid import uuid4
1110

1211
import models.types
1312

@@ -38,7 +37,7 @@ def upgrade():
3837
)
3938
else:
4039
op.create_table('tracing_app_configs',
41-
sa.Column('id', models.types.StringUUID(), default=lambda: str(uuid4()), nullable=False),
40+
sa.Column('id', models.types.StringUUID(), nullable=False),
4241
sa.Column('app_id', models.types.StringUUID(), nullable=False),
4342
sa.Column('tracing_provider', sa.String(length=255), nullable=True),
4443
sa.Column('tracing_config', sa.JSON(), nullable=True),

api/migrations/versions/053da0c1d756_add_api_tool_privacy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"""
88
import sqlalchemy as sa
99
from alembic import op
10-
from uuid import uuid4
1110
from sqlalchemy.dialects import postgresql
1211

1312
import models.types
@@ -40,7 +39,7 @@ def upgrade():
4039
)
4140
else:
4241
op.create_table('tool_conversation_variables',
43-
sa.Column('id', models.types.StringUUID(), default=lambda: str(uuid4()), nullable=False),
42+
sa.Column('id', models.types.StringUUID(), nullable=False),
4443
sa.Column('user_id', models.types.StringUUID(), nullable=False),
4544
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
4645
sa.Column('conversation_id', models.types.StringUUID(), nullable=False),

api/migrations/versions/16fa53d9faec_add_provider_model_support.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"""
88
import sqlalchemy as sa
99
from alembic import op
10-
from uuid import uuid4
1110
from sqlalchemy.dialects import postgresql
1211

1312
import models.types
@@ -43,7 +42,7 @@ def upgrade():
4342
)
4443
else:
4544
op.create_table('provider_models',
46-
sa.Column('id', models.types.StringUUID(), default=lambda: str(uuid4()), nullable=False),
45+
sa.Column('id', models.types.StringUUID(), nullable=False),
4746
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
4847
sa.Column('provider_name', sa.String(length=40), nullable=False),
4948
sa.Column('model_name', sa.String(length=40), nullable=False),
@@ -72,7 +71,7 @@ def upgrade():
7271
)
7372
else:
7473
op.create_table('tenant_default_models',
75-
sa.Column('id', models.types.StringUUID(), default=lambda: str(uuid4()), nullable=False),
74+
sa.Column('id', models.types.StringUUID(), nullable=False),
7675
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
7776
sa.Column('provider_name', sa.String(length=40), nullable=False),
7877
sa.Column('model_name', sa.String(length=40), nullable=False),
@@ -97,7 +96,7 @@ def upgrade():
9796
)
9897
else:
9998
op.create_table('tenant_preferred_model_providers',
100-
sa.Column('id', models.types.StringUUID(), default=lambda: str(uuid4()), nullable=False),
99+
sa.Column('id', models.types.StringUUID(), nullable=False),
101100
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
102101
sa.Column('provider_name', sa.String(length=40), nullable=False),
103102
sa.Column('preferred_provider_type', sa.String(length=40), nullable=False),

api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"""
88
import sqlalchemy as sa
99
from alembic import op
10-
from uuid import uuid4
1110

1211
import models as models
1312

@@ -41,7 +40,7 @@ def upgrade():
4140
)
4241
else:
4342
op.create_table('tidb_auth_bindings',
44-
sa.Column('id', models.types.StringUUID(), default=lambda: str(uuid4()), nullable=False),
43+
sa.Column('id', models.types.StringUUID(), nullable=False),
4544
sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
4645
sa.Column('cluster_id', sa.String(length=255), nullable=False),
4746
sa.Column('cluster_name', sa.String(length=255), nullable=False),

0 commit comments

Comments
 (0)