Skip to content

Commit 8f927f8

Browse files
feat: add auto update script for all datasets in marked channels #178 (#181)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0f30831 commit 8f927f8

File tree

6 files changed

+281
-1
lines changed

6 files changed

+281
-1
lines changed

Makefile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ statgpt_cli: install_dev
4545
statgpt_admin:
4646
poetry run python -m statgpt.admin.app $(ARGS)
4747

48+
statgpt_fix_statuses:
49+
poetry run python -m statgpt.admin.fix_statuses
50+
51+
statgpt_auto_update:
52+
poetry run python -m statgpt.admin.auto_update
53+
4854
statgpt_app:
4955
poetry run python -m statgpt.app.app $(ARGS)
5056

statgpt/admin/admin.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,17 @@ case "${ADMIN_MODE:-}" in
2626
python -m statgpt.admin.fix_statuses
2727
;;
2828

29+
AUTO_UPDATE)
30+
python -m statgpt.admin.auto_update
31+
;;
32+
2933
*)
3034
echo "Unknown ADMIN_MODE = '${ADMIN_MODE:-}'. Possible values:"
3135
echo " APP - start the admin application"
3236
echo " ALEMBIC_UPGRADE - run alembic migrations to upgrade the database"
3337
echo " FIX_STATUSES - fix inconsistent statuses in the database"
3438
echo " INIT - run alembic migrations and fix inconsistent statuses"
39+
echo " AUTO_UPDATE - run batch auto-update for all eligible channels"
3540
exit 1
3641
;;
3742
esac

statgpt/admin/auto_update.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""
2+
Batch auto-update script for all datasets in channels with `allow_auto_update` enabled.
3+
"""
4+
5+
import asyncio
6+
import logging
7+
import sys
8+
9+
import statgpt.common.schemas as schemas
10+
from statgpt.admin.auth.auth_context import SystemUserAuthContext
11+
from statgpt.admin.services.channel import (
12+
AdminPortalChannelService,
13+
deduplicate_dimensions_in_background_task,
14+
)
15+
from statgpt.admin.services.dataset import AdminPortalDataSetService, auto_update_in_background_task
16+
from statgpt.common.auth.auth_context import AuthContext
17+
from statgpt.common.models import get_session_contex_manager, optional_msi_token_manager_context
18+
19+
_log = logging.getLogger(__name__)
20+
_SEPARATOR = "-" * 50
21+
22+
23+
async def _discover_and_create_jobs() -> list[schemas.AutoUpdateJob]:
    """Find auto-update channels and create jobs for their datasets."""
    _log.info(_SEPARATOR)
    async with get_session_contex_manager() as session:
        channel_service = AdminPortalChannelService(session)
        all_channels = await channel_service.get_channels_schemas(limit=None, offset=0)

        # A channel is eligible when its data-query tool exists and opts in.
        eligible_ids: list[int] = []
        for channel in all_channels:
            data_query = channel.details.data_query
            if data_query is not None and data_query.details.allow_auto_update:
                eligible_ids.append(channel.id)

        _log.info(f"Found {len(eligible_ids)} channel(s) with auto-update enabled")

        if not eligible_ids:
            return []

        return await AdminPortalDataSetService(session).create_auto_update_jobs(eligible_ids)
40+
41+
42+
async def _process_jobs(jobs: list[schemas.AutoUpdateJob], auth_context: AuthContext) -> None:
    """Execute all auto-update jobs concurrently and log any that raised.

    NOTE: The number of concurrent executions is limited by the semaphore
    in the ``@background_task`` decorator applied to ``auto_update_in_background_task``.
    """
    _log.info(_SEPARATOR)
    _log.info(f"Created {len(jobs)} auto-update job(s), starting processing...")

    tasks = [
        auto_update_in_background_task(auto_update_job_id=job.id, auth_context=auth_context)
        for job in jobs
    ]
    outcomes = await asyncio.gather(*tasks, return_exceptions=True)

    for job, outcome in zip(jobs, outcomes):
        if isinstance(outcome, Exception):
            _log.error(f"Auto-update job {job.id} failed with exception:", exc_info=outcome)
61+
62+
63+
async def _get_reindex_channel_ids(job_ids: list[int]) -> set[int]:
    """Get channel IDs that had at least one reindex triggered."""
    async with get_session_contex_manager() as session:
        dataset_service = AdminPortalDataSetService(session)
        return await dataset_service.get_reindex_channel_ids(job_ids)
67+
68+
69+
async def _log_results(job_ids: list[int]) -> bool:
    """Log per-channel summary and return `True` if all jobs succeeded."""
    _log.info(_SEPARATOR)
    async with get_session_contex_manager() as session:
        results = await AdminPortalDataSetService(session).get_auto_update_results(job_ids)

    total = 0
    failed = 0
    for channel_result in results:
        total += channel_result.total
        failed += channel_result.failed
        _log.info(
            f"channel '{channel_result.deployment_id}' "
            f"(id={channel_result.channel_id}): {channel_result.summary}"
        )
        for reason in channel_result.failed_reasons:
            _log.error(f"  {reason}")

    _log.info(
        f"Auto-update complete: {total - failed} succeeded, {failed} failed "
        f"out of {total} total"
    )
    return failed == 0
87+
88+
89+
async def _deduplicate_channels(channel_ids: set[int], auth_context: AuthContext) -> None:
    """Deduplicate dimensions for every channel that had a reindex.

    NOTE: The number of concurrent executions is limited by the semaphore
    in the ``@background_task`` decorator applied to ``deduplicate_dimensions_in_background_task``.
    """
    _log.info(_SEPARATOR)
    ordered_ids = sorted(channel_ids)
    _log.info(f"Running deduplication for {len(ordered_ids)} channel(s) with reindex: {ordered_ids}")

    tasks = [
        deduplicate_dimensions_in_background_task(channel_id=cid, auth_context=auth_context)
        for cid in ordered_ids
    ]
    outcomes = await asyncio.gather(*tasks, return_exceptions=True)

    for cid, outcome in zip(ordered_ids, outcomes):
        if isinstance(outcome, Exception):
            _log.error(
                f"Deduplication for channel {cid} failed with exception:", exc_info=outcome
            )
    _log.info("Deduplication complete")
113+
114+
115+
async def run_auto_update() -> bool:
    """Run batch auto-update for all eligible channels.

    Returns:
        `True` if all jobs succeeded, `False` otherwise.
    """
    auth_context = SystemUserAuthContext()

    jobs = await _discover_and_create_jobs()
    if not jobs:
        # Nothing queued — trivially successful.
        return True

    await _process_jobs(jobs, auth_context)

    job_ids = [job.id for job in jobs]
    channels_to_dedup = await _get_reindex_channel_ids(job_ids)
    if channels_to_dedup:
        await _deduplicate_channels(channels_to_dedup, auth_context)

    return await _log_results(job_ids)
135+
136+
137+
async def main() -> None:
    """Script entry point: run the batch auto-update and exit non-zero on failure."""
    try:
        _log.info("Starting batch auto-update script...")
        async with optional_msi_token_manager_context():
            all_succeeded = await run_auto_update()

        _log.info(_SEPARATOR)
        if not all_succeeded:
            _log.error("Batch auto-update finished with failures")
            # SystemExit is not an Exception subclass, so it escapes the handler below.
            sys.exit(1)
        _log.info("Batch auto-update script completed successfully")
    except Exception:
        _log.exception("Error in batch auto-update script:")
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())

statgpt/admin/services/dataset.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import os.path
44
import uuid
55
import zipfile
6+
from collections import Counter, defaultdict
67
from collections.abc import Generator, Iterable
78
from typing import Any, NamedTuple
89

910
import yaml
1011
from fastapi import BackgroundTasks, HTTPException, status
1112
from pydantic import ValidationError
1213
from sqlalchemy.ext.asyncio import AsyncSession
14+
from sqlalchemy.orm import selectinload
1315
from sqlalchemy.sql.expression import func, select, text, update
1416

1517
import statgpt.common.models as models
@@ -25,6 +27,7 @@
2527
from statgpt.common.schemas import (
2628
AuditActionType,
2729
AuditEntityType,
30+
AutoUpdateResult,
2831
ChannelIndexStatusScope,
2932
HybridSearchConfig,
3033
)
@@ -59,6 +62,15 @@ class _DataHashes(NamedTuple):
5962
special_dimensions_hash: str | None
6063

6164

65+
class AutoUpdateChannelResult(NamedTuple):
    """Aggregated auto-update outcome for a single channel."""

    channel_id: int  # primary key of the channel
    deployment_id: str  # human-readable channel identifier used in logs
    total: int  # total number of auto-update jobs run for the channel
    failed: int  # number of jobs that ended in FAILED status
    summary: str  # human-readable breakdown of job results
    failed_reasons: list[str]  # one "job <id>: <reason>" entry per failed job
72+
73+
6274
class AdminPortalDataSetService(DataSetService):
6375

6476
def __init__(self, session: AsyncSession) -> None:
@@ -1989,6 +2001,105 @@ async def get_auto_update_jobs_count(self, channel_dataset_id: int) -> int:
19892001
)
19902002
return await self._session.scalar(query) or 0
19912003

2004+
async def create_auto_update_jobs(self, channel_ids: list[int]) -> list[schemas.AutoUpdateJob]:
    """Create AutoUpdateJob records for all datasets in the given channels.

    Args:
        channel_ids: IDs of the channels whose datasets should be queued.

    Returns:
        A list of created job schemas (status QUEUED).
    """
    result = await self._session.execute(
        select(models.ChannelDataset)
        .where(models.ChannelDataset.channel_id.in_(channel_ids))
        .options(selectinload(models.ChannelDataset.channel))
    )
    channel_datasets = list(result.scalars().all())

    jobs: list[models.AutoUpdateJob] = []
    for cd in channel_datasets:
        job = models.AutoUpdateJob(channel_dataset_id=cd.id, status=StatusEnum.QUEUED)
        self._session.add(job)
        jobs.append(job)

    # Snapshot everything needed for logging and the return value BEFORE
    # committing: if the session has `expire_on_commit` enabled (the SQLAlchemy
    # default), touching ORM attributes after commit triggers an implicit
    # refresh, which fails under the async session.
    deployment_ids = {cd.channel_id: cd.channel.deployment_id for cd in channel_datasets}
    counts = Counter(cd.channel_id for cd in channel_datasets)

    # Flush so the database assigns primary keys, then snapshot the schemas.
    await self._session.flush()
    job_schemas = [schemas.AutoUpdateJob.model_validate(job, from_attributes=True) for job in jobs]

    await self._session.commit()

    # Log per-channel summary
    for ch_id, count in counts.items():
        _log.info(
            f"Created {count} auto-update job(s) for channel '{deployment_ids[ch_id]}' (id={ch_id})"
        )

    return job_schemas
2033+
2034+
async def get_reindex_channel_ids(self, job_ids: list[int]) -> set[int]:
    """Get channel IDs that had at least one REINDEX_TRIGGERED result."""
    query = (
        select(models.AutoUpdateJob)
        .where(models.AutoUpdateJob.id.in_(job_ids))
        .where(models.AutoUpdateJob.result == AutoUpdateResult.REINDEX_TRIGGERED)
        .options(selectinload(models.AutoUpdateJob.channel_dataset))
    )
    result = await self._session.execute(query)

    channel_ids: set[int] = set()
    for job in result.scalars().all():
        channel_ids.add(job.channel_dataset.channel_id)
    return channel_ids
2044+
2045+
async def get_auto_update_results(self, job_ids: list[int]) -> list[AutoUpdateChannelResult]:
    """Collect per-channel auto-update results."""
    query = (
        select(models.AutoUpdateJob)
        .where(models.AutoUpdateJob.id.in_(job_ids))
        .options(
            selectinload(models.AutoUpdateJob.channel_dataset).selectinload(
                models.ChannelDataset.channel
            ),
            selectinload(models.AutoUpdateJob.created_version),
        )
    )
    jobs = list((await self._session.execute(query)).scalars().all())

    # Group jobs by the channel they belong to.
    jobs_by_channel: defaultdict[int, list[models.AutoUpdateJob]] = defaultdict(list)
    for job in jobs:
        jobs_by_channel[job.channel_dataset.channel.id].append(job)

    summaries: list[AutoUpdateChannelResult] = []
    for channel_id, group in jobs_by_channel.items():
        channel = group[0].channel_dataset.channel
        failed_jobs = [j for j in group if j.status == StatusEnum.FAILED]
        summaries.append(
            AutoUpdateChannelResult(
                channel_id=channel_id,
                deployment_id=channel.deployment_id,
                total=len(group),
                failed=len(failed_jobs),
                summary=self._format_result_summary(group),
                failed_reasons=[f"job {j.id}: {j.reason_for_failure}" for j in failed_jobs],
            )
        )
    return summaries
2081+
2082+
@staticmethod
def _format_result_summary(jobs: list[models.AutoUpdateJob]) -> str:
    """Build a human-readable summary of auto-update results."""
    # Per-status counts; a job's label is its result if set, else its status.
    status_counts: Counter[str] = Counter()
    # For jobs that produced a new version, break down by preprocessing status.
    version_breakdown: dict[str, Counter[str]] = defaultdict(Counter)
    for job in jobs:
        label = job.result.value if job.result else job.status.value
        status_counts[label] += 1
        if job.created_version is not None:
            version_breakdown[label][job.created_version.preprocessing_status.value] += 1

    pieces: list[str] = []
    for label, count in status_counts.most_common():
        piece = f"{count} {label}"
        if label in version_breakdown:
            inner = ", ".join(f"{c} {s}" for s, c in version_breakdown[label].items())
            piece += f" ({inner})"
        pieces.append(piece)
    return ", ".join(pieces)
2102+
19922103
async def _set_auto_update_job_status(
19932104
self,
19942105
job: models.AutoUpdateJob,

statgpt/common/schemas/data_query_tool.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,10 @@ class DataQueryDetails(BaseToolDetails):
268268
prompts: DataQueryPrompts = Field(default_factory=DataQueryPrompts) # type: ignore
269269
messages: DataQueryMessages = Field(default_factory=DataQueryMessages) # type: ignore
270270
attachments: DataQueryAttachments = Field(default_factory=DataQueryAttachments) # type: ignore
271+
allow_auto_update: bool = Field(
272+
default=False,
273+
description="Whether datasets in this channel should be auto-updated by the batch auto-update script.",
274+
)
271275
tool_response_max_cells: PositiveInt = Field(
272276
default=300,
273277
description=(

statgpt/common/services/channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ async def get_channels_db(self, limit: int | None, offset: int) -> list[models.C
2727
q_result = await self._session.execute(query)
2828
return [item for item in q_result.scalars().all()]
2929

30-
async def get_channels_schemas(self, limit: int | None, offset: int) -> list[schemas.Channel]:
    """Load channels from the DB and serialize them to schemas (`limit=None` fetches all)."""
    db_channels = await self.get_channels_db(limit, offset)
    return [ChannelSerializer.db_to_schema(db_channel) for db_channel in db_channels]
3333

0 commit comments

Comments (0)