Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ statgpt_cli: install_dev
statgpt_admin:
poetry run python -m statgpt.admin.app $(ARGS)

# Fix inconsistent statuses in the database.
statgpt_fix_statuses:
	poetry run python -m statgpt.admin.fix_statuses

# Run batch auto-update for all channels with auto-update enabled.
statgpt_auto_update:
	poetry run python -m statgpt.admin.auto_update

# Start the StatGPT application; pass extra arguments via ARGS.
statgpt_app:
	poetry run python -m statgpt.app.app $(ARGS)

Expand Down
5 changes: 5 additions & 0 deletions statgpt/admin/admin.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,17 @@ case "${ADMIN_MODE:-}" in
python -m statgpt.admin.fix_statuses
;;

AUTO_UPDATE)
python -m statgpt.admin.auto_update
;;

*)
echo "Unknown ADMIN_MODE = '${ADMIN_MODE:-}'. Possible values:"
echo " APP - start the admin application"
echo " ALEMBIC_UPGRADE - run alembic migrations to upgrade the database"
echo " FIX_STATUSES - fix inconsistent statuses in the database"
echo " INIT - run alembic migrations and fix inconsistent statuses"
echo " AUTO_UPDATE - run batch auto-update for all eligible channels"
exit 1
;;
esac
134 changes: 134 additions & 0 deletions statgpt/admin/auto_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Batch auto-update script for all datasets in channels with allow_auto_update enabled.
"""

import asyncio
import logging
import sys

import statgpt.common.schemas as schemas
from statgpt.admin.auth.auth_context import SystemUserAuthContext
from statgpt.admin.services.channel import (
AdminPortalChannelService,
deduplicate_dimensions_in_background_task,
)
from statgpt.admin.services.dataset import AdminPortalDataSetService, auto_update_in_background_task
from statgpt.common.auth.auth_context import AuthContext
from statgpt.common.models import get_session_contex_manager, optional_msi_token_manager_context

_log = logging.getLogger(__name__)
_SEPARATOR = "-" * 50


async def _discover_and_create_jobs() -> list[schemas.AutoUpdateJob]:
    """Locate channels with auto-update enabled and enqueue jobs for their datasets.

    Returns an empty list when no eligible channel exists.
    """
    _log.info(_SEPARATOR)
    async with get_session_contex_manager() as session:
        eligible = await AdminPortalChannelService(session).get_auto_update_channels()
        _log.info(f"Found {len(eligible)} channel(s) with auto-update enabled")

        if not eligible:
            return []

        dataset_service = AdminPortalDataSetService(session)
        return await dataset_service.create_auto_update_jobs([ch.id for ch in eligible])


async def _process_jobs(jobs: list[schemas.AutoUpdateJob], auth_context: AuthContext) -> None:
    """Run all auto-update jobs concurrently.

    ``return_exceptions=True`` keeps one failing task from cancelling its
    siblings, but the original code discarded the gather results, so any
    exception a task raised vanished silently. Surface them in the log instead.

    :param jobs: jobs created by ``_discover_and_create_jobs``.
    :param auth_context: auth context the background tasks run under.
    """
    _log.info(_SEPARATOR)
    _log.info(f"Created {len(jobs)} auto-update job(s), starting processing...")

    results = await asyncio.gather(
        *(
            auto_update_in_background_task(auto_update_job_id=job.id, auth_context=auth_context)
            for job in jobs
        ),
        return_exceptions=True,
    )
    # gather() preserves input order, so results zip 1:1 with jobs.
    for job, result in zip(jobs, results):
        if isinstance(result, BaseException):
            _log.error(f"Auto-update job {job.id} raised an unexpected error: {result!r}")


async def _get_reindex_channel_ids(job_ids: list[int]) -> set[int]:
    """Get channel IDs that had at least one reindex triggered."""
    async with get_session_contex_manager() as session:
        dataset_service = AdminPortalDataSetService(session)
        return await dataset_service.get_reindex_channel_ids(job_ids)


async def _log_results(job_ids: list[int]) -> bool:
    """Log a per-channel summary of the given jobs.

    Returns ``True`` only when every job succeeded.
    """
    _log.info(_SEPARATOR)
    async with get_session_contex_manager() as session:
        results = await AdminPortalDataSetService(session).get_auto_update_results(job_ids)

    for channel_result in results:
        _log.info(
            f"channel '{channel_result.deployment_id}' "
            f"(id={channel_result.channel_id}): {channel_result.summary}"
        )
        for reason in channel_result.failed_reasons:
            _log.error(f" {reason}")

    total_jobs = sum(item.total for item in results)
    failed_jobs = sum(item.failed for item in results)
    succeeded_jobs = total_jobs - failed_jobs
    _log.info(
        f"Auto-update complete: {succeeded_jobs} succeeded, {failed_jobs} failed "
        f"out of {total_jobs} total"
    )
    return failed_jobs == 0


async def _deduplicate_channels(channel_ids: set[int], auth_context: AuthContext) -> None:
    """Run deduplication for channels that had a reindex.

    The original code discarded the results of ``asyncio.gather(...,
    return_exceptions=True)``, so a failing deduplication task disappeared
    without a trace; exceptions are now logged per channel.

    :param channel_ids: channels with at least one REINDEX_TRIGGERED job.
    :param auth_context: auth context the background tasks run under.
    """
    _log.info(_SEPARATOR)
    _log.info(
        f"Running deduplication for {len(channel_ids)} channel(s) "
        f"with reindex: {sorted(channel_ids)}"
    )
    # Fix the iteration order so gather results can be matched back to ids.
    ordered_ids = sorted(channel_ids)
    results = await asyncio.gather(
        *(
            deduplicate_dimensions_in_background_task(
                channel_id=channel_id, auth_context=auth_context
            )
            for channel_id in ordered_ids
        ),
        return_exceptions=True,
    )
    for channel_id, result in zip(ordered_ids, results):
        if isinstance(result, BaseException):
            _log.error(f"Deduplication for channel {channel_id} raised: {result!r}")
    _log.info("Deduplication complete")


async def run_auto_update() -> bool:
    """Run batch auto-update for all eligible channels.

    Returns True if all jobs succeeded, False otherwise.
    """
    auth_context = SystemUserAuthContext()

    created_jobs = await _discover_and_create_jobs()
    if not created_jobs:
        # Nothing to update counts as success.
        return True

    await _process_jobs(created_jobs, auth_context)

    job_ids = [job.id for job in created_jobs]
    channels_needing_dedup = await _get_reindex_channel_ids(job_ids)
    if channels_needing_dedup:
        await _deduplicate_channels(channels_needing_dedup, auth_context)

    return await _log_results(job_ids)


async def main() -> None:
    """Entry point: run the batch and exit with a non-zero code on any failure."""
    try:
        _log.info("Starting batch auto-update script...")
        async with optional_msi_token_manager_context():
            all_succeeded = await run_auto_update()

        _log.info(_SEPARATOR)
        if all_succeeded:
            _log.info("Batch auto-update script completed successfully")
        else:
            _log.error("Batch auto-update finished with failures")
            sys.exit(1)
    except Exception:
        # SystemExit is not an Exception subclass, so the exits above pass through.
        _log.exception("Error in batch auto-update script:")
        sys.exit(1)


# Allow running as a script/module: ``python -m statgpt.admin.auto_update``.
if __name__ == "__main__":
    asyncio.run(main())
116 changes: 116 additions & 0 deletions statgpt/admin/services/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import os.path
import uuid
import zipfile
from collections import Counter, defaultdict
from collections.abc import Generator, Iterable
from typing import Any, NamedTuple

import yaml
from fastapi import BackgroundTasks, HTTPException, status
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import func, select, text, update

import statgpt.common.models as models
Expand All @@ -25,6 +27,7 @@
from statgpt.common.schemas import (
AuditActionType,
AuditEntityType,
AutoUpdateResult,
ChannelIndexStatusScope,
HybridSearchConfig,
)
Expand Down Expand Up @@ -59,6 +62,15 @@ class _DataHashes(NamedTuple):
special_dimensions_hash: str | None


class AutoUpdateChannelResult(NamedTuple):
    """Aggregated auto-update outcome for one channel, used for result logging."""

    channel_id: int  # primary key of the channel
    deployment_id: str  # human-readable channel identifier used in log lines
    total: int  # number of auto-update jobs created for this channel
    failed: int  # how many of those jobs ended in FAILED status
    summary: str  # human-readable breakdown of job outcomes
    failed_reasons: list[str]  # one "job <id>: <reason>" entry per failed job


class AdminPortalDataSetService(DataSetService):

def __init__(self, session: AsyncSession) -> None:
Expand Down Expand Up @@ -1989,6 +2001,110 @@ async def get_auto_update_jobs_count(self, channel_dataset_id: int) -> int:
)
return await self._session.scalar(query) or 0

async def create_auto_update_jobs(self, channel_ids: list[int]) -> list[schemas.AutoUpdateJob]:
"""Create AutoUpdateJob records for all datasets in the given channels.

Returns a list of created job schemas.
"""
result = await self._session.execute(
select(models.ChannelDataset)
.where(models.ChannelDataset.channel_id.in_(channel_ids))
.options(selectinload(models.ChannelDataset.channel))
)
channel_datasets = list(result.scalars().all())

jobs: list[models.AutoUpdateJob] = []
for cd in channel_datasets:
job = models.AutoUpdateJob(
channel_dataset_id=cd.id,
status=StatusEnum.QUEUED,
)
self._session.add(job)
await self._session.flush()
jobs.append(job)

# Log per-channel summary
channels_by_id = {cd.channel.id: cd.channel for cd in channel_datasets}
counts = Counter(cd.channel_id for cd in channel_datasets)
for ch_id, count in counts.items():
ch = channels_by_id[ch_id]
_log.info(
f"Created {count} auto-update job(s) for channel "
f"'{ch.deployment_id}' (id={ch_id})"
)

await self._session.commit()
return [schemas.AutoUpdateJob.model_validate(job, from_attributes=True) for job in jobs]

async def get_reindex_channel_ids(self, job_ids: list[int]) -> set[int]:
"""Get channel IDs that had at least one REINDEX_TRIGGERED result."""
result = await self._session.execute(
select(models.AutoUpdateJob)
.where(models.AutoUpdateJob.id.in_(job_ids))
.where(models.AutoUpdateJob.result == AutoUpdateResult.REINDEX_TRIGGERED)
.options(selectinload(models.AutoUpdateJob.channel_dataset))
)
jobs = list(result.scalars().all())
return {j.channel_dataset.channel_id for j in jobs}

async def get_auto_update_results(self, job_ids: list[int]) -> list[AutoUpdateChannelResult]:
"""Collect per-channel auto-update results."""
result = await self._session.execute(
select(models.AutoUpdateJob)
.where(models.AutoUpdateJob.id.in_(job_ids))
.options(
selectinload(models.AutoUpdateJob.channel_dataset).selectinload(
models.ChannelDataset.channel
),
selectinload(models.AutoUpdateJob.created_version),
)
)
jobs = list(result.scalars().all())

channel_jobs: defaultdict[int, list[models.AutoUpdateJob]] = defaultdict(list)
for job in jobs:
channel_jobs[job.channel_dataset.channel.id].append(job)

results: list[AutoUpdateChannelResult] = []
for channel_id, ch_jobs in channel_jobs.items():
ch = ch_jobs[0].channel_dataset.channel
results.append(
AutoUpdateChannelResult(
channel_id=channel_id,
deployment_id=ch.deployment_id,
total=len(ch_jobs),
failed=sum(1 for j in ch_jobs if j.status == StatusEnum.FAILED),
summary=self._format_result_summary(ch_jobs),
failed_reasons=[
f"job {j.id}: {j.reason_for_failure}"
for j in ch_jobs
if j.status == StatusEnum.FAILED
],
)
)
return results

@staticmethod
def _format_result_summary(jobs: list[models.AutoUpdateJob]) -> str:
"""Build a human-readable summary of auto-update results."""
reindex_statuses: dict[str, Counter[str]] = defaultdict(Counter)
job_statuses: list[str] = []
for job in jobs:
job_status = job.result.value if job.result else job.status.value
job_statuses.append(job_status)
if job.created_version is not None:
reindex_statuses[job_status][job.created_version.preprocessing_status.value] += 1
result_counts = Counter(job_statuses)

parts: list[str] = []
for job_status, count in result_counts.items():
part = f"{count} {job_status}"
if job_status in reindex_statuses:
breakdown = ", ".join(f"{c} {s}" for s, c in reindex_statuses[job_status].items())
part += f" ({breakdown})"
parts.append(part)
return ", ".join(parts)

async def _set_auto_update_job_status(
self,
job: models.AutoUpdateJob,
Expand Down
4 changes: 4 additions & 0 deletions statgpt/common/schemas/data_query_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ class DataQueryDetails(BaseToolDetails):
prompts: DataQueryPrompts = Field(default_factory=DataQueryPrompts) # type: ignore
messages: DataQueryMessages = Field(default_factory=DataQueryMessages) # type: ignore
attachments: DataQueryAttachments = Field(default_factory=DataQueryAttachments) # type: ignore
allow_auto_update: bool = Field(
default=False,
description="Whether datasets in this channel should be auto-updated by the batch auto-update script.",
)
tool_response_max_cells: PositiveInt = Field(
default=300,
description=(
Expand Down
17 changes: 17 additions & 0 deletions statgpt/common/services/channel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import func

import statgpt.common.models as models
Expand Down Expand Up @@ -59,3 +60,19 @@ def is_channel_hybrid(channel: models.Channel) -> bool:
return False
indexer_version = channel_config.data_query.details.indexer_version
return indexer_version == schemas.IndexerVersion.hybrid

@staticmethod
def _is_auto_update_enabled(channel: models.Channel) -> bool:
"""Returns `True` if the channel has auto-update enabled in data_query config."""
config = schemas.ChannelConfig.model_validate(channel.details)
if config.data_query is None:
return False
return config.data_query.details.allow_auto_update

async def get_auto_update_channels(self) -> list[models.Channel]:
"""Get all channels with auto-update enabled, with mapped_datasets eager-loaded."""
result = await self._session.execute(
select(models.Channel).options(selectinload(models.Channel.mapped_datasets))
)
channels = list(result.scalars().all())
return [ch for ch in channels if self._is_auto_update_enabled(ch)]
Loading