Skip to content

Commit c159a1b

Browse files
committed
refactor/prework to refactor direct database access to follow dir pattern
1 parent 6ad1932 commit c159a1b

File tree

3 files changed

+92
-3
lines changed

3 files changed

+92
-3
lines changed

src/ai/backend/manager/data/auth/types.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1+
from __future__ import annotations
2+
13
import uuid
24
from dataclasses import dataclass
35
from datetime import datetime
46
from typing import Optional
57

8+
from ai.backend.manager.data.keypair.types import KeyPairData
9+
from ai.backend.manager.data.resource.types import (
10+
KeyPairResourcePolicyData,
11+
UserResourcePolicyData,
12+
)
13+
from ai.backend.manager.data.user.types import UserData as FullUserData
614
from ai.backend.manager.models.user import UserRole, UserStatus
715

816

@@ -47,3 +55,28 @@ class UserData:
4755
class GroupMembershipData:
4856
group_id: uuid.UUID
4957
user_id: uuid.UUID
58+
59+
60+
@dataclass
61+
class CredentialData:
62+
"""
63+
Aggregated credential data for authentication.
64+
65+
Combines user, keypair, and their resource policies into a single data structure.
66+
Used by authentication middleware to populate request context.
67+
"""
68+
69+
user: FullUserData
70+
user_resource_policy: UserResourcePolicyData
71+
keypair: KeyPairData
72+
keypair_resource_policy: KeyPairResourcePolicyData
73+
74+
@property
75+
def is_admin(self) -> bool:
76+
"""Check if the keypair has admin privileges."""
77+
return self.keypair.is_admin
78+
79+
@property
80+
def is_superadmin(self) -> bool:
81+
"""Check if the user has superadmin role."""
82+
return self.user.role == UserRole.SUPERADMIN

src/ai/backend/manager/repositories/auth/db_source/db_source.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy
1515
from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy
1616
from ai.backend.common.resilience.resilience import Resilience
17-
from ai.backend.manager.data.auth.types import GroupMembershipData, UserData
17+
from ai.backend.manager.data.auth.types import CredentialData, GroupMembershipData, UserData
1818
from ai.backend.manager.errors.auth import GroupMembershipNotFoundError, UserCreationError
1919
from ai.backend.manager.models.group import association_groups_users, groups
2020
from ai.backend.manager.models.hasher.types import PasswordInfo
21-
from ai.backend.manager.models.keypair import keypairs
21+
from ai.backend.manager.models.keypair import KeyPairRow, keypairs
2222
from ai.backend.manager.models.user import (
2323
UserRow,
2424
UserStatus,
@@ -285,3 +285,43 @@ async def fetch_current_time(self) -> datetime:
285285
"""Fetch current time from database."""
286286
async with self._db.begin_readonly() as db_conn:
287287
return await db_conn.scalar(sa.select(sa.func.now()))
288+
289+
@auth_db_source_resilience.apply()
290+
async def fetch_credential_by_access_key(self, access_key: str) -> Optional[CredentialData]:
291+
"""
292+
Fetch user credential data by access key.
293+
294+
Queries keypair with user and resource policies using ORM relationships.
295+
Returns None if access key not found or inactive.
296+
297+
Args:
298+
access_key: The access key to look up.
299+
300+
Returns:
301+
CredentialData containing user, keypair, and resource policies,
302+
or None if not found.
303+
"""
304+
async with self._db.begin_session() as db_session:
305+
query = (
306+
sa.select(KeyPairRow)
307+
.where((KeyPairRow.access_key == access_key) & (KeyPairRow.is_active.is_(True)))
308+
.options(
309+
joinedload(KeyPairRow.resource_policy_row),
310+
joinedload(KeyPairRow.user_row).joinedload(UserRow.resource_policy_row),
311+
)
312+
)
313+
keypair_row = await db_session.scalar(query)
314+
315+
if keypair_row is None:
316+
return None
317+
318+
user_row = keypair_row.user_row
319+
if user_row is None:
320+
return None
321+
322+
return CredentialData(
323+
user=user_row.to_data(),
324+
user_resource_policy=user_row.resource_policy_row.to_dataclass(),
325+
keypair=keypair_row.to_data(),
326+
keypair_resource_policy=keypair_row.resource_policy_row.to_dataclass(),
327+
)

src/ai/backend/manager/repositories/auth/repository.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ai.backend.common.metrics.metric import DomainType, LayerType
66
from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy
77
from ai.backend.common.resilience.resilience import Resilience
8-
from ai.backend.manager.data.auth.types import GroupMembershipData, UserData
8+
from ai.backend.manager.data.auth.types import CredentialData, GroupMembershipData, UserData
99
from ai.backend.manager.models.hasher.types import PasswordInfo
1010
from ai.backend.manager.models.user import UserRow
1111
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
@@ -100,3 +100,19 @@ async def get_user_row_by_uuid(self, user_uuid: UUID) -> UserRow:
100100
@auth_repository_resilience.apply()
101101
async def get_current_time(self) -> datetime:
102102
return await self._db_source.fetch_current_time()
103+
104+
@auth_repository_resilience.apply()
105+
async def get_credential_by_access_key(self, access_key: str) -> Optional[CredentialData]:
106+
"""
107+
Get user credential data by access key.
108+
109+
Used by authentication middleware to populate request context.
110+
111+
Args:
112+
access_key: The access key to look up.
113+
114+
Returns:
115+
CredentialData containing user, keypair, and resource policies,
116+
or None if not found or inactive.
117+
"""
118+
return await self._db_source.fetch_credential_by_access_key(access_key)

0 commit comments

Comments
 (0)