Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ statgpt_cli: install_dev
statgpt_admin:
poetry run python -m statgpt.admin.app $(ARGS)

statgpt_fix_statuses:
poetry run python -m statgpt.admin.fix_statuses

statgpt_auto_update:
poetry run python -m statgpt.admin.auto_update

statgpt_app:
poetry run python -m statgpt.app.app $(ARGS)

Expand Down
5 changes: 5 additions & 0 deletions statgpt/admin/admin.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,17 @@ case "${ADMIN_MODE:-}" in
python -m statgpt.admin.fix_statuses
;;

AUTO_UPDATE)
python -m statgpt.admin.auto_update
;;

*)
echo "Unknown ADMIN_MODE = '${ADMIN_MODE:-}'. Possible values:"
echo " APP - start the admin application"
echo " ALEMBIC_UPGRADE - run alembic migrations to upgrade the database"
echo " FIX_STATUSES - fix inconsistent statuses in the database"
echo " INIT - run alembic migrations and fix inconsistent statuses"
echo " AUTO_UPDATE - run batch auto-update for all eligible channels"
exit 1
;;
esac
154 changes: 154 additions & 0 deletions statgpt/admin/auto_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
Batch auto-update script for all datasets in channels with `allow_auto_update` enabled.
"""

import asyncio
import logging
import sys

import statgpt.common.schemas as schemas
from statgpt.admin.auth.auth_context import SystemUserAuthContext
from statgpt.admin.services.channel import (
AdminPortalChannelService,
deduplicate_dimensions_in_background_task,
)
from statgpt.admin.services.dataset import AdminPortalDataSetService, auto_update_in_background_task
from statgpt.common.auth.auth_context import AuthContext
from statgpt.common.models import get_session_contex_manager, optional_msi_token_manager_context

_log = logging.getLogger(__name__)
_SEPARATOR = "-" * 50


async def _discover_and_create_jobs() -> list[schemas.AutoUpdateJob]:
    """Locate channels with auto-update enabled and queue jobs for their datasets.

    Returns:
        Schemas of the created jobs; empty list when no channel is eligible.
    """
    _log.info(_SEPARATOR)
    async with get_session_contex_manager() as session:
        channels = await AdminPortalChannelService(session).get_channels_schemas(
            limit=None, offset=0
        )
        # A channel is eligible when its data-query tool exists and opts in.
        eligible: list[int] = []
        for channel in channels:
            data_query = channel.details.data_query
            if data_query is not None and data_query.details.allow_auto_update:
                eligible.append(channel.id)
        _log.info(f"Found {len(eligible)} channel(s) with auto-update enabled")

        if not eligible:
            return []

        dataset_service = AdminPortalDataSetService(session)
        return await dataset_service.create_auto_update_jobs(eligible)


async def _process_jobs(jobs: list[schemas.AutoUpdateJob], auth_context: AuthContext) -> None:
    """Run all auto-update jobs concurrently and log any that raised.

    Args:
        jobs: Jobs previously created for eligible channels.
        auth_context: Auth context the background tasks execute under.

    NOTE: The number of concurrent executions is limited by the semaphore
    in the ``@background_task`` decorator applied to ``auto_update_in_background_task``.
    """
    _log.info(_SEPARATOR)
    _log.info(f"Created {len(jobs)} auto-update job(s), starting processing...")

    results = await asyncio.gather(
        *(
            auto_update_in_background_task(auto_update_job_id=job.id, auth_context=auth_context)
            for job in jobs
        ),
        return_exceptions=True,
    )
    for job, result in zip(jobs, results):
        # With `return_exceptions=True`, gather may also place BaseException
        # subclasses (e.g. a task's CancelledError) into the results, so check
        # BaseException rather than Exception to avoid silently dropping them.
        if isinstance(result, BaseException):
            _log.error(f"Auto-update job {job.id} failed with exception:", exc_info=result)


async def _get_reindex_channel_ids(job_ids: list[int]) -> set[int]:
    """Return IDs of channels for which at least one reindex was triggered."""
    async with get_session_contex_manager() as session:
        service = AdminPortalDataSetService(session)
        return await service.get_reindex_channel_ids(job_ids)


async def _log_results(job_ids: list[int]) -> bool:
    """Log a per-channel summary of job outcomes.

    Returns:
        `True` when every job succeeded, `False` otherwise.
    """
    _log.info(_SEPARATOR)
    async with get_session_contex_manager() as session:
        results = await AdminPortalDataSetService(session).get_auto_update_results(job_ids)

    total = 0
    failed = 0
    for channel_result in results:
        _log.info(
            f"channel '{channel_result.deployment_id}' "
            f"(id={channel_result.channel_id}): {channel_result.summary}"
        )
        for reason in channel_result.failed_reasons:
            _log.error(f"  {reason}")
        total += channel_result.total
        failed += channel_result.failed

    succeeded = total - failed
    _log.info(
        f"Auto-update complete: {succeeded} succeeded, {failed} failed out of {total} total"
    )
    return failed == 0


async def _deduplicate_channels(channel_ids: set[int], auth_context: AuthContext) -> None:
    """Run deduplication for channels that had a reindex, logging failures.

    Args:
        channel_ids: Channels for which at least one reindex was triggered.
        auth_context: Auth context the background tasks execute under.

    NOTE: The number of concurrent executions is limited by the semaphore
    in the ``@background_task`` decorator applied to ``deduplicate_dimensions_in_background_task``.
    """
    _log.info(_SEPARATOR)
    sorted_ids = sorted(channel_ids)
    _log.info(f"Running deduplication for {len(sorted_ids)} channel(s) with reindex: {sorted_ids}")
    results = await asyncio.gather(
        *(
            deduplicate_dimensions_in_background_task(
                channel_id=channel_id, auth_context=auth_context
            )
            for channel_id in sorted_ids
        ),
        return_exceptions=True,
    )
    for channel_id, result in zip(sorted_ids, results):
        # With `return_exceptions=True`, gather may also place BaseException
        # subclasses (e.g. a task's CancelledError) into the results, so check
        # BaseException rather than Exception to avoid silently dropping them.
        if isinstance(result, BaseException):
            _log.error(
                f"Deduplication for channel {channel_id} failed with exception:", exc_info=result
            )
    _log.info("Deduplication complete")


async def run_auto_update() -> bool:
    """Run batch auto-update for all eligible channels.

    Orchestrates the full pipeline: job discovery/creation, concurrent
    processing, deduplication of reindexed channels, and result logging.

    Returns:
        `True` if all jobs succeeded, `False` otherwise.
    """
    auth_context = SystemUserAuthContext()

    jobs = await _discover_and_create_jobs()
    if not jobs:
        # Nothing to do counts as success.
        return True

    await _process_jobs(jobs, auth_context)

    job_ids = [job.id for job in jobs]
    reindexed_channels = await _get_reindex_channel_ids(job_ids)
    if reindexed_channels:
        await _deduplicate_channels(reindexed_channels, auth_context)

    return await _log_results(job_ids)


async def main() -> None:
    """Script entry point: run the batch auto-update and exit non-zero on failure."""
    try:
        _log.info("Starting batch auto-update script...")
        async with optional_msi_token_manager_context():
            all_succeeded = await run_auto_update()

        _log.info(_SEPARATOR)
        if all_succeeded:
            _log.info("Batch auto-update script completed successfully")
        else:
            # sys.exit raises SystemExit (a BaseException), so it is not
            # swallowed by the `except Exception` handler below.
            _log.error("Batch auto-update finished with failures")
            sys.exit(1)
    except Exception:
        _log.exception("Error in batch auto-update script:")
        sys.exit(1)


# Script entry point: drive the async main() to completion.
if __name__ == "__main__":
    asyncio.run(main())
111 changes: 111 additions & 0 deletions statgpt/admin/services/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import os.path
import uuid
import zipfile
from collections import Counter, defaultdict
from collections.abc import Generator, Iterable
from typing import Any, NamedTuple

import yaml
from fastapi import BackgroundTasks, HTTPException, status
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import func, select, text, update

import statgpt.common.models as models
Expand All @@ -25,6 +27,7 @@
from statgpt.common.schemas import (
AuditActionType,
AuditEntityType,
AutoUpdateResult,
ChannelIndexStatusScope,
HybridSearchConfig,
)
Expand Down Expand Up @@ -59,6 +62,15 @@ class _DataHashes(NamedTuple):
special_dimensions_hash: str | None


class AutoUpdateChannelResult(NamedTuple):
    """Aggregated auto-update outcome for a single channel."""

    channel_id: int  # primary key of the channel
    deployment_id: str  # channel identifier used in log messages
    total: int  # total number of auto-update jobs for this channel
    failed: int  # number of those jobs that ended in FAILED status
    summary: str  # human-readable per-status breakdown
    failed_reasons: list[str]  # one "job <id>: <reason>" entry per failed job


class AdminPortalDataSetService(DataSetService):

def __init__(self, session: AsyncSession) -> None:
Expand Down Expand Up @@ -1989,6 +2001,105 @@ async def get_auto_update_jobs_count(self, channel_dataset_id: int) -> int:
)
return await self._session.scalar(query) or 0

async def create_auto_update_jobs(self, channel_ids: list[int]) -> list[schemas.AutoUpdateJob]:
    """Create AutoUpdateJob records for all datasets in the given channels.

    Args:
        channel_ids: IDs of the channels whose datasets should get a job.

    Returns:
        Schemas of the created jobs, one per channel dataset.
    """
    query = (
        select(models.ChannelDataset)
        .where(models.ChannelDataset.channel_id.in_(channel_ids))
        .options(selectinload(models.ChannelDataset.channel))
    )
    execution = await self._session.execute(query)
    channel_datasets = list(execution.scalars().all())

    created: list[models.AutoUpdateJob] = []
    for channel_dataset in channel_datasets:
        record = models.AutoUpdateJob(
            channel_dataset_id=channel_dataset.id, status=StatusEnum.QUEUED
        )
        self._session.add(record)
        created.append(record)
    await self._session.commit()

    # Log how many jobs were created for each channel.
    channels_by_id = {cd.channel.id: cd.channel for cd in channel_datasets}
    jobs_per_channel = Counter(cd.channel_id for cd in channel_datasets)
    for channel_id, job_count in jobs_per_channel.items():
        channel = channels_by_id[channel_id]
        _log.info(
            f"Created {job_count} auto-update job(s) for channel "
            f"'{channel.deployment_id}' (id={channel_id})"
        )

    return [
        schemas.AutoUpdateJob.model_validate(record, from_attributes=True) for record in created
    ]

async def get_reindex_channel_ids(self, job_ids: list[int]) -> set[int]:
    """Return channel IDs that had at least one REINDEX_TRIGGERED job result."""
    query = (
        select(models.AutoUpdateJob)
        .where(models.AutoUpdateJob.id.in_(job_ids))
        .where(models.AutoUpdateJob.result == AutoUpdateResult.REINDEX_TRIGGERED)
        .options(selectinload(models.AutoUpdateJob.channel_dataset))
    )
    execution = await self._session.execute(query)
    return {job.channel_dataset.channel_id for job in execution.scalars().all()}

async def get_auto_update_results(self, job_ids: list[int]) -> list[AutoUpdateChannelResult]:
    """Collect per-channel auto-update results for the given jobs.

    Args:
        job_ids: IDs of the jobs to aggregate.

    Returns:
        One `AutoUpdateChannelResult` per channel that owns any of the jobs.
    """
    query = (
        select(models.AutoUpdateJob)
        .where(models.AutoUpdateJob.id.in_(job_ids))
        .options(
            selectinload(models.AutoUpdateJob.channel_dataset).selectinload(
                models.ChannelDataset.channel
            ),
            selectinload(models.AutoUpdateJob.created_version),
        )
    )
    execution = await self._session.execute(query)

    # Group jobs by the channel they belong to.
    jobs_by_channel: defaultdict[int, list[models.AutoUpdateJob]] = defaultdict(list)
    for job in execution.scalars().all():
        jobs_by_channel[job.channel_dataset.channel.id].append(job)

    summaries: list[AutoUpdateChannelResult] = []
    for channel_id, channel_jobs in jobs_by_channel.items():
        channel = channel_jobs[0].channel_dataset.channel
        failed_jobs = [job for job in channel_jobs if job.status == StatusEnum.FAILED]
        summaries.append(
            AutoUpdateChannelResult(
                channel_id=channel_id,
                deployment_id=channel.deployment_id,
                total=len(channel_jobs),
                failed=len(failed_jobs),
                summary=self._format_result_summary(channel_jobs),
                failed_reasons=[
                    f"job {job.id}: {job.reason_for_failure}" for job in failed_jobs
                ],
            )
        )
    return summaries

@staticmethod
def _format_result_summary(jobs: list[models.AutoUpdateJob]) -> str:
    """Build a human-readable summary of auto-update results.

    Each job is labeled by its result value (falling back to its status when
    no result is set); labels with a created version get a parenthesized
    breakdown of the versions' preprocessing statuses.
    """
    status_counts: Counter[str] = Counter()
    version_breakdown: dict[str, Counter[str]] = defaultdict(Counter)
    for job in jobs:
        label = job.result.value if job.result else job.status.value
        status_counts[label] += 1
        if job.created_version is not None:
            version_breakdown[label][job.created_version.preprocessing_status.value] += 1

    pieces: list[str] = []
    for label, count in status_counts.most_common():
        piece = f"{count} {label}"
        if label in version_breakdown:
            detail = ", ".join(f"{c} {s}" for s, c in version_breakdown[label].items())
            piece = f"{piece} ({detail})"
        pieces.append(piece)
    return ", ".join(pieces)

async def _set_auto_update_job_status(
self,
job: models.AutoUpdateJob,
Expand Down
4 changes: 4 additions & 0 deletions statgpt/common/schemas/data_query_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ class DataQueryDetails(BaseToolDetails):
prompts: DataQueryPrompts = Field(default_factory=DataQueryPrompts) # type: ignore
messages: DataQueryMessages = Field(default_factory=DataQueryMessages) # type: ignore
attachments: DataQueryAttachments = Field(default_factory=DataQueryAttachments) # type: ignore
allow_auto_update: bool = Field(
default=False,
description="Whether datasets in this channel should be auto-updated by the batch auto-update script.",
)
tool_response_max_cells: PositiveInt = Field(
default=300,
description=(
Expand Down
2 changes: 1 addition & 1 deletion statgpt/common/services/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def get_channels_db(self, limit: int | None, offset: int) -> list[models.C
q_result = await self._session.execute(query)
return [item for item in q_result.scalars().all()]

async def get_channels_schemas(self, limit: int, offset: int) -> list[schemas.Channel]:
async def get_channels_schemas(self, limit: int | None, offset: int) -> list[schemas.Channel]:
    """Fetch channels from the database and serialize each to a schema.

    Args:
        limit: Maximum number of channels to return; `None` presumably
            disables the limit — confirm in `get_channels_db`.
        offset: Number of channels to skip.
    """
    channels = await self.get_channels_db(limit, offset)
    return [ChannelSerializer.db_to_schema(item) for item in channels]

Expand Down
Loading