|
| 1 | +from collections.abc import Mapping |
1 | 2 | from datetime import UTC, datetime, timedelta |
2 | 3 | from os.path import realpath |
3 | 4 | from pathlib import Path |
4 | 5 |
|
5 | | -from beanie.operators import LT |
| 6 | +from beanie.operators import GTE, LT, Eq, In, Or |
6 | 7 | from bson import json_util |
7 | 8 |
|
8 | 9 | from tekst import db, search |
9 | 10 | from tekst.auth import AccessTokenDocument, create_initial_superuser |
10 | 11 | from tekst.config import TekstConfig, get_config |
11 | 12 | from tekst.db import migrations |
12 | 13 | from tekst.logs import log, log_op_end, log_op_start |
| 14 | +from tekst.models.common import PydanticObjectId |
13 | 15 | from tekst.models.message import UserMessageDocument |
14 | 16 | from tekst.models.platform import PlatformStateDocument |
| 17 | +from tekst.models.segment import ClientSegmentDocument, ClientSegmentHead |
| 18 | +from tekst.models.user import UserDocument |
15 | 19 | from tekst.resources import call_resource_precompute_hooks |
16 | 20 | from tekst.state import get_state, update_state |
17 | 21 |
|
@@ -119,3 +123,61 @@ async def cleanup_task(cfg: TekstConfig = get_config()) -> dict[str, float]: |
119 | 123 | return { |
120 | 124 | "took": round(log_op_end(op_id), 2), |
121 | 125 | } |
| 126 | + |
| 127 | + |
| 128 | +async def _get_segment_restriction_queries( |
| 129 | + user: UserDocument | None = None, |
| 130 | +) -> tuple[Mapping]: |
| 131 | + if user is None: |
| 132 | + return (In(ClientSegmentDocument.restriction, ["none", None]),) |
| 133 | + if user.is_superuser: |
| 134 | + return tuple() |
| 135 | + return (In(ClientSegmentDocument.restriction, ["none", "user"]),) |
| 136 | + |
| 137 | + |
| 138 | +async def get_segment( |
| 139 | + *, |
| 140 | + segment_id: PydanticObjectId | None = None, |
| 141 | + user: UserDocument | None = None, |
| 142 | +) -> ClientSegmentDocument | None: |
| 143 | + return await ClientSegmentDocument.find_one( |
| 144 | + Eq(ClientSegmentDocument.id, segment_id), |
| 145 | + *(await _get_segment_restriction_queries(user)), |
| 146 | + ) |
| 147 | + |
| 148 | + |
| 149 | +async def get_segments( |
| 150 | + *, |
| 151 | + system: bool | None = None, |
| 152 | + user: UserDocument | None = None, |
| 153 | + head_projection: bool = False, |
| 154 | +) -> list[ClientSegmentDocument]: |
| 155 | + system_segments_queries = ( |
| 156 | + tuple() |
| 157 | + if system is None |
| 158 | + else ( |
| 159 | + GTE(ClientSegmentDocument.key, "system"), |
| 160 | + LT(ClientSegmentDocument.key, "systen"), |
| 161 | + ) |
| 162 | + if system |
| 163 | + else ( |
| 164 | + Or( |
| 165 | + LT(ClientSegmentDocument.key, "system"), |
| 166 | + GTE(ClientSegmentDocument.key, "systen"), |
| 167 | + ), |
| 168 | + ) |
| 169 | + ) |
| 170 | + if not head_projection: |
| 171 | + return await ClientSegmentDocument.find( |
| 172 | + *system_segments_queries, |
| 173 | + *(await _get_segment_restriction_queries(user)), |
| 174 | + ).to_list() |
| 175 | + else: |
| 176 | + return ( |
| 177 | + await ClientSegmentDocument.find( |
| 178 | + *system_segments_queries, |
| 179 | + *(await _get_segment_restriction_queries(user)), |
| 180 | + ) |
| 181 | + .project(ClientSegmentHead) |
| 182 | + .to_list() |
| 183 | + ) |
0 commit comments