|
3 | 3 | import os.path |
4 | 4 | import uuid |
5 | 5 | import zipfile |
| 6 | +from collections import Counter, defaultdict |
6 | 7 | from collections.abc import Generator, Iterable |
7 | 8 | from typing import Any, NamedTuple |
8 | 9 |
|
9 | 10 | import yaml |
10 | 11 | from fastapi import BackgroundTasks, HTTPException, status |
11 | 12 | from pydantic import ValidationError |
12 | 13 | from sqlalchemy.ext.asyncio import AsyncSession |
| 14 | +from sqlalchemy.orm import selectinload |
13 | 15 | from sqlalchemy.sql.expression import func, select, text, update |
14 | 16 |
|
15 | 17 | import statgpt.common.models as models |
|
25 | 27 | from statgpt.common.schemas import ( |
26 | 28 | AuditActionType, |
27 | 29 | AuditEntityType, |
| 30 | + AutoUpdateResult, |
28 | 31 | ChannelIndexStatusScope, |
29 | 32 | HybridSearchConfig, |
30 | 33 | ) |
@@ -59,6 +62,15 @@ class _DataHashes(NamedTuple): |
59 | 62 | special_dimensions_hash: str | None |
60 | 63 |
|
61 | 64 |
|
class AutoUpdateChannelResult(NamedTuple):
    """Aggregated auto-update outcome for a single channel.

    Built from the channel's AutoUpdateJob rows; used for per-channel
    reporting after a batch of auto-update jobs completes.
    """

    channel_id: int  # primary key of the channel
    deployment_id: str  # channel's deployment identifier (shown in logs/reports)
    total: int  # total number of auto-update jobs for this channel
    failed: int  # number of jobs whose status is FAILED
    summary: str  # human-readable breakdown of job results (see _format_result_summary)
    failed_reasons: list[str]  # "job <id>: <reason_for_failure>" entries for failed jobs
62 | 74 | class AdminPortalDataSetService(DataSetService): |
63 | 75 |
|
64 | 76 | def __init__(self, session: AsyncSession) -> None: |
@@ -1989,6 +2001,105 @@ async def get_auto_update_jobs_count(self, channel_dataset_id: int) -> int: |
1989 | 2001 | ) |
1990 | 2002 | return await self._session.scalar(query) or 0 |
1991 | 2003 |
|
| 2004 | + async def create_auto_update_jobs(self, channel_ids: list[int]) -> list[schemas.AutoUpdateJob]: |
| 2005 | + """Create AutoUpdateJob records for all datasets in the given channels. |
| 2006 | +
|
| 2007 | + Returns a list of created job schemas. |
| 2008 | + """ |
| 2009 | + result = await self._session.execute( |
| 2010 | + select(models.ChannelDataset) |
| 2011 | + .where(models.ChannelDataset.channel_id.in_(channel_ids)) |
| 2012 | + .options(selectinload(models.ChannelDataset.channel)) |
| 2013 | + ) |
| 2014 | + channel_datasets = list(result.scalars().all()) |
| 2015 | + |
| 2016 | + jobs: list[models.AutoUpdateJob] = [] |
| 2017 | + for cd in channel_datasets: |
| 2018 | + job = models.AutoUpdateJob(channel_dataset_id=cd.id, status=StatusEnum.QUEUED) |
| 2019 | + self._session.add(job) |
| 2020 | + jobs.append(job) |
| 2021 | + await self._session.commit() |
| 2022 | + |
| 2023 | + # Log per-channel summary |
| 2024 | + channels_by_id = {cd.channel.id: cd.channel for cd in channel_datasets} |
| 2025 | + counts = Counter(cd.channel_id for cd in channel_datasets) |
| 2026 | + for ch_id, count in counts.items(): |
| 2027 | + ch = channels_by_id[ch_id] |
| 2028 | + _log.info( |
| 2029 | + f"Created {count} auto-update job(s) for channel '{ch.deployment_id}' (id={ch_id})" |
| 2030 | + ) |
| 2031 | + |
| 2032 | + return [schemas.AutoUpdateJob.model_validate(job, from_attributes=True) for job in jobs] |
| 2033 | + |
| 2034 | + async def get_reindex_channel_ids(self, job_ids: list[int]) -> set[int]: |
| 2035 | + """Get channel IDs that had at least one REINDEX_TRIGGERED result.""" |
| 2036 | + result = await self._session.execute( |
| 2037 | + select(models.AutoUpdateJob) |
| 2038 | + .where(models.AutoUpdateJob.id.in_(job_ids)) |
| 2039 | + .where(models.AutoUpdateJob.result == AutoUpdateResult.REINDEX_TRIGGERED) |
| 2040 | + .options(selectinload(models.AutoUpdateJob.channel_dataset)) |
| 2041 | + ) |
| 2042 | + jobs = list(result.scalars().all()) |
| 2043 | + return {j.channel_dataset.channel_id for j in jobs} |
| 2044 | + |
| 2045 | + async def get_auto_update_results(self, job_ids: list[int]) -> list[AutoUpdateChannelResult]: |
| 2046 | + """Collect per-channel auto-update results.""" |
| 2047 | + result = await self._session.execute( |
| 2048 | + select(models.AutoUpdateJob) |
| 2049 | + .where(models.AutoUpdateJob.id.in_(job_ids)) |
| 2050 | + .options( |
| 2051 | + selectinload(models.AutoUpdateJob.channel_dataset).selectinload( |
| 2052 | + models.ChannelDataset.channel |
| 2053 | + ), |
| 2054 | + selectinload(models.AutoUpdateJob.created_version), |
| 2055 | + ) |
| 2056 | + ) |
| 2057 | + jobs = list(result.scalars().all()) |
| 2058 | + |
| 2059 | + channel_jobs: defaultdict[int, list[models.AutoUpdateJob]] = defaultdict(list) |
| 2060 | + for job in jobs: |
| 2061 | + channel_jobs[job.channel_dataset.channel.id].append(job) |
| 2062 | + |
| 2063 | + results: list[AutoUpdateChannelResult] = [] |
| 2064 | + for channel_id, ch_jobs in channel_jobs.items(): |
| 2065 | + ch = ch_jobs[0].channel_dataset.channel |
| 2066 | + results.append( |
| 2067 | + AutoUpdateChannelResult( |
| 2068 | + channel_id=channel_id, |
| 2069 | + deployment_id=ch.deployment_id, |
| 2070 | + total=len(ch_jobs), |
| 2071 | + failed=sum(1 for j in ch_jobs if j.status == StatusEnum.FAILED), |
| 2072 | + summary=self._format_result_summary(ch_jobs), |
| 2073 | + failed_reasons=[ |
| 2074 | + f"job {j.id}: {j.reason_for_failure}" |
| 2075 | + for j in ch_jobs |
| 2076 | + if j.status == StatusEnum.FAILED |
| 2077 | + ], |
| 2078 | + ) |
| 2079 | + ) |
| 2080 | + return results |
| 2081 | + |
| 2082 | + @staticmethod |
| 2083 | + def _format_result_summary(jobs: list[models.AutoUpdateJob]) -> str: |
| 2084 | + """Build a human-readable summary of auto-update results.""" |
| 2085 | + reindex_statuses: dict[str, Counter[str]] = defaultdict(Counter) |
| 2086 | + job_statuses: list[str] = [] |
| 2087 | + for job in jobs: |
| 2088 | + job_status = job.result.value if job.result else job.status.value |
| 2089 | + job_statuses.append(job_status) |
| 2090 | + if job.created_version is not None: |
| 2091 | + reindex_statuses[job_status][job.created_version.preprocessing_status.value] += 1 |
| 2092 | + result_counts = Counter(job_statuses) |
| 2093 | + |
| 2094 | + parts: list[str] = [] |
| 2095 | + for job_status, count in result_counts.most_common(): |
| 2096 | + part = f"{count} {job_status}" |
| 2097 | + if job_status in reindex_statuses: |
| 2098 | + breakdown = ", ".join(f"{c} {s}" for s, c in reindex_statuses[job_status].items()) |
| 2099 | + part += f" ({breakdown})" |
| 2100 | + parts.append(part) |
| 2101 | + return ", ".join(parts) |
| 2102 | + |
1992 | 2103 | async def _set_auto_update_job_status( |
1993 | 2104 | self, |
1994 | 2105 | job: models.AutoUpdateJob, |
|
0 commit comments