Skip to content

Commit 2eeac0a

Browse files
BE: Organization filters (#706)
* TLK_572 - global organization filter for all requests * TLK_572 - global organization filter for all requests * TLK_972 - fixes after discussion * TLK_572 - sync * TLK_572 - sync + validation * TLK_572 - sync + validation * TLK_572 - sync + validation + tests
1 parent de58a58 commit 2eeac0a

File tree

18 files changed

+325
-43
lines changed

18 files changed

+325
-43
lines changed

src/backend/alembic/versions/2024_08_19_c301506b3676_.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,52 +5,55 @@
55
Create Date: 2024-08-19 12:36:34.118536
66
77
"""
8+
89
from typing import Sequence, Union
910

1011
import sqlalchemy as sa
1112
from alembic import op
1213

1314
# revision identifiers, used by Alembic.
14-
revision: str = 'c301506b3676'
15-
down_revision: Union[str, None] = 'a76ebb869eb8'
15+
revision: str = "c301506b3676"
16+
down_revision: Union[str, None] = "a76ebb869eb8"
1617
branch_labels: Union[str, Sequence[str], None] = None
1718
depends_on: Union[str, Sequence[str], None] = None
1819

1920

2021
def upgrade() -> None:
2122
# ### commands auto generated by Alembic - please adjust! ###
22-
op.create_table('groups',
23-
sa.Column('display_name', sa.String(), nullable=False),
24-
sa.Column('id', sa.String(), nullable=False),
25-
sa.Column('created_at', sa.DateTime(), nullable=True),
26-
sa.Column('updated_at', sa.DateTime(), nullable=True),
27-
sa.PrimaryKeyConstraint('id'),
28-
sa.UniqueConstraint('display_name', name='unique_display_name')
23+
op.create_table(
24+
"groups",
25+
sa.Column("display_name", sa.String(), nullable=False),
26+
sa.Column("id", sa.String(), nullable=False),
27+
sa.Column("created_at", sa.DateTime(), nullable=True),
28+
sa.Column("updated_at", sa.DateTime(), nullable=True),
29+
sa.PrimaryKeyConstraint("id"),
30+
sa.UniqueConstraint("display_name", name="unique_display_name"),
2931
)
30-
op.create_table('user_group',
31-
sa.Column('user_id', sa.String(), nullable=False),
32-
sa.Column('group_id', sa.String(), nullable=False),
33-
sa.Column('display', sa.String(), nullable=False),
34-
sa.Column('id', sa.String(), nullable=False),
35-
sa.Column('created_at', sa.DateTime(), nullable=True),
36-
sa.Column('updated_at', sa.DateTime(), nullable=True),
37-
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
38-
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
39-
sa.PrimaryKeyConstraint('user_id', 'group_id', 'id')
32+
op.create_table(
33+
"user_group",
34+
sa.Column("user_id", sa.String(), nullable=False),
35+
sa.Column("group_id", sa.String(), nullable=False),
36+
sa.Column("display", sa.String(), nullable=False),
37+
sa.Column("id", sa.String(), nullable=False),
38+
sa.Column("created_at", sa.DateTime(), nullable=True),
39+
sa.Column("updated_at", sa.DateTime(), nullable=True),
40+
sa.ForeignKeyConstraint(["group_id"], ["groups.id"], ondelete="CASCADE"),
41+
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
42+
sa.PrimaryKeyConstraint("user_id", "group_id", "id"),
4043
)
41-
op.add_column('users', sa.Column('user_name', sa.String(), nullable=True))
42-
op.add_column('users', sa.Column('external_id', sa.String(), nullable=True))
43-
op.add_column('users', sa.Column('active', sa.Boolean(), nullable=True))
44-
op.create_unique_constraint('unique_user_name', 'users', ['user_name'])
44+
op.add_column("users", sa.Column("user_name", sa.String(), nullable=True))
45+
op.add_column("users", sa.Column("external_id", sa.String(), nullable=True))
46+
op.add_column("users", sa.Column("active", sa.Boolean(), nullable=True))
47+
op.create_unique_constraint("unique_user_name", "users", ["user_name"])
4548
# ### end Alembic commands ###
4649

4750

4851
def downgrade() -> None:
4952
# ### commands auto generated by Alembic - please adjust! ###
50-
op.drop_constraint('unique_user_name', 'users', type_='unique')
51-
op.drop_column('users', 'active')
52-
op.drop_column('users', 'external_id')
53-
op.drop_column('users', 'user_name')
54-
op.drop_table('user_group')
55-
op.drop_table('groups')
53+
op.drop_constraint("unique_user_name", "users", type_="unique")
54+
op.drop_column("users", "active")
55+
op.drop_column("users", "external_id")
56+
op.drop_column("users", "user_name")
57+
op.drop_table("user_group")
58+
op.drop_table("groups")
5659
# ### end Alembic commands ###

src/backend/config/routers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010
from backend.services.request_validators import (
1111
validate_chat_request,
12+
validate_organization_header,
1213
validate_user_header,
1314
)
1415

@@ -35,104 +36,126 @@ class RouterName(StrEnum):
3536
RouterName.AUTH: {
3637
"default": [
3738
Depends(get_session),
39+
Depends(validate_organization_header),
3840
],
3941
"auth": [
4042
Depends(get_session),
43+
Depends(validate_organization_header),
4144
],
4245
},
4346
RouterName.CHAT: {
4447
"default": [
4548
Depends(get_session),
4649
Depends(validate_user_header),
4750
Depends(validate_chat_request),
51+
Depends(validate_organization_header),
4852
],
4953
"auth": [
5054
Depends(get_session),
5155
Depends(validate_chat_request),
5256
Depends(validate_authorization),
57+
Depends(validate_organization_header),
5358
],
5459
},
5560
RouterName.CONVERSATION: {
5661
"default": [
5762
Depends(get_session),
5863
Depends(validate_user_header),
64+
Depends(validate_organization_header),
5965
],
6066
"auth": [
6167
Depends(get_session),
6268
Depends(validate_authorization),
69+
Depends(validate_organization_header),
6370
],
6471
},
6572
RouterName.DEPLOYMENT: {
6673
"default": [
6774
Depends(get_session),
75+
Depends(validate_organization_header),
6876
],
6977
"auth": [
7078
Depends(get_session),
7179
Depends(validate_authorization),
80+
Depends(validate_organization_header),
7281
],
7382
},
7483
RouterName.EXPERIMENTAL_FEATURES: {
7584
"default": [
7685
Depends(get_session),
86+
Depends(validate_organization_header),
7787
],
7888
"auth": [
7989
Depends(get_session),
8090
Depends(validate_authorization),
91+
Depends(validate_organization_header),
8192
],
8293
},
8394
RouterName.TOOL: {
8495
"default": [
8596
Depends(get_session),
97+
Depends(validate_organization_header),
8698
],
8799
"auth": [
88100
Depends(get_session),
89101
Depends(validate_authorization),
102+
Depends(validate_organization_header),
90103
],
91104
},
92105
RouterName.USER: {
93106
"default": [
94107
Depends(get_session),
108+
Depends(validate_organization_header),
95109
],
96110
"auth": [
97111
# TODO: Remove auth only for create user endpoint
98112
Depends(get_session),
113+
Depends(validate_organization_header),
99114
],
100115
},
101116
RouterName.AGENT: {
102117
"default": [
103118
Depends(get_session),
119+
Depends(validate_organization_header),
104120
],
105121
"auth": [
106122
Depends(get_session),
107123
Depends(validate_authorization),
124+
Depends(validate_organization_header),
108125
],
109126
},
110127
RouterName.DEFAULT_AGENT: {
111128
"default": [
112129
Depends(get_session),
130+
Depends(validate_organization_header),
113131
],
114132
"auth": [
115133
Depends(get_session),
116134
Depends(validate_authorization),
135+
Depends(validate_organization_header),
117136
],
118137
},
119138
RouterName.SNAPSHOT: {
120139
"default": [
121140
Depends(get_session),
122141
Depends(validate_user_header),
142+
Depends(validate_organization_header),
123143
],
124144
"auth": [
125145
Depends(get_session),
126146
Depends(validate_authorization),
147+
Depends(validate_organization_header),
127148
],
128149
},
129150
RouterName.MODEL: {
130151
"default": [
131152
Depends(get_session),
153+
Depends(validate_organization_header),
132154
],
133155
"auth": [
134156
Depends(get_session),
135157
Depends(validate_authorization),
158+
Depends(validate_organization_header),
136159
],
137160
},
138161
RouterName.SCIM: {

src/backend/crud/organization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sqlalchemy.orm import Session
22

3-
from backend.database_models import Agent
3+
from backend.database_models.agent import Agent
44
from backend.database_models.organization import Organization
55
from backend.database_models.user import User, UserOrganizationAssociation
66
from backend.schemas.organization import UpdateOrganization

src/backend/database_models/base.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,45 @@
1+
from enum import StrEnum
12
from uuid import uuid4
23

34
from sqlalchemy import DateTime, String, func
4-
from sqlalchemy.orm import DeclarativeBase, mapped_column
5+
from sqlalchemy.orm import DeclarativeBase, Query, mapped_column
6+
7+
8+
class FilterFields(StrEnum):
9+
ORGANIZATION_ID = "organization_id"
10+
11+
12+
class CustomFilterQuery(Query):
13+
"""
14+
Custom query class that filters by field if the entity has field
15+
and the filter value is set.
16+
"""
17+
18+
ALLOWED_FILTER_FIELDS = [FilterFields.ORGANIZATION_ID]
19+
20+
def __new__(cls, *args, **kwargs):
21+
from backend.services.context import GLOBAL_REQUEST_CONTEXT
22+
23+
request_ctx = GLOBAL_REQUEST_CONTEXT.get()
24+
if request_ctx and request_ctx.use_global_filtering:
25+
query = None
26+
for field in cls.ALLOWED_FILTER_FIELDS:
27+
if (
28+
args
29+
and hasattr(args[0][0], field)
30+
and hasattr(request_ctx, field)
31+
and getattr(request_ctx, field)
32+
):
33+
if query:
34+
query = query.filter_by(**{field: getattr(request_ctx, field)})
35+
else:
36+
query = Query(*args, **kwargs).filter_by(
37+
**{field: getattr(request_ctx, field)}
38+
)
39+
if query:
40+
return query
41+
42+
return object.__new__(cls)
543

644

745
class Base(DeclarativeBase):

src/backend/database_models/database.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sqlalchemy.orm import Session
77

88
from backend.config.settings import Settings
9+
from backend.database_models.base import CustomFilterQuery
910

1011
load_dotenv()
1112

@@ -16,7 +17,7 @@
1617

1718

1819
def get_session() -> Generator[Session, Any, None]:
19-
with Session(engine) as session:
20+
with Session(engine, query_cls=CustomFilterQuery) as session:
2021
yield session
2122

2223

src/backend/routers/agent.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ async def list_agents(
193193
# TODO: get organization_id from user
194194
user_id = ctx.get_user_id()
195195
logger = ctx.get_logger()
196+
# request organization_id is used for filtering agents instead of header Organization-Id if enabled
197+
if organization_id:
198+
ctx.without_global_filtering()
196199

197200
try:
198201
agents = agent_crud.get_agents(

src/backend/routers/organization.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
UpdateOrganization,
1212
)
1313
from backend.schemas.context import Context
14+
from backend.schemas.user import User
1415
from backend.services.context import get_context
1516
from backend.services.request_validators import validate_organization_request
1617

@@ -88,11 +89,11 @@ def get_organization(
8889
session (DBSessionDep): Database session.
8990
9091
Returns:
91-
ManagedTool: Tool with the given ID.
92+
ManagedTool: Organization with the given ID.
9293
"""
9394
organization = organization_crud.get_organization(session, organization_id)
9495
if not organization:
95-
raise HTTPException(status_code=404, detail="Model not found")
96+
raise HTTPException(status_code=404, detail="Organization not found")
9697
return organization
9798

9899

@@ -114,7 +115,7 @@ def delete_organization(
114115
"""
115116
organization = organization_crud.get_organization(session, organization_id)
116117
if not organization:
117-
raise HTTPException(status_code=404, detail="Tool not found")
118+
raise HTTPException(status_code=404, detail="Organization not found")
118119
organization_crud.delete_organization(session, organization_id)
119120

120121
return DeleteOrganization()
@@ -138,3 +139,24 @@ def list_organizations(
138139
"""
139140
all_organizations = organization_crud.get_organizations(session)
140141
return all_organizations
142+
143+
144+
@router.get("/{organization_id}/users", response_model=list[User])
145+
def get_organization_users(
146+
organization_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
147+
) -> list[User]:
148+
"""
149+
Get organization users by ID.
150+
151+
Args:
152+
organization_id (str): Organization ID.
153+
session (DBSessionDep): Database session.
154+
155+
Returns:
156+
list[User]: List of users in the organization
157+
"""
158+
organization = organization_crud.get_organization(session, organization_id)
159+
if not organization:
160+
raise HTTPException(status_code=404, detail="Organization not found")
161+
162+
return organization.users

0 commit comments

Comments
 (0)