Commit 01a17ba

Fix async db session

1 parent 2dae729 commit 01a17ba

File tree

1 file changed: +43, -34 lines

servers/fai/src/fai/routes/website.py

Lines changed: 43 additions & 34 deletions
@@ -21,6 +21,7 @@
 from sqlalchemy.ext.asyncio import AsyncSession

 from fai.app import fai_app
+from fai.db import async_session_maker
 from fai.dependencies import (
     ask_ai_enabled,
     get_db,
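
For context, the newly imported async_session_maker is presumably a standard SQLAlchemy async session factory. A minimal sketch of what fai.db might expose, assuming SQLAlchemy 2.0's async_sessionmaker and a hypothetical engine URL:

# Sketch of fai/db.py, not the actual module.
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

# Hypothetical connection URL; the real one would come from app config.
engine = create_async_engine("postgresql+asyncpg://localhost/fai")

# Factory that hands out independent AsyncSession instances.
# expire_on_commit=False keeps ORM objects readable after commit,
# a common choice for background jobs (an assumption here).
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)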
@@ -66,7 +67,6 @@ async def index_website(
     body: IndexWebsiteRequest = Body(...),
     db: AsyncSession = Depends(get_db),
     _: None = Depends(verify_token),
-    __: None = Depends(ask_ai_enabled),
 ) -> JSONResponse:
     """
     Start crawling and indexing a website.
@@ -86,7 +86,7 @@ async def index_website(

     if index_source:
         index_source.status = "indexing"
-        index_source.last_job_id = job_id
+        index_source.job_id = job_id
         index_source.config = body.model_dump()
         index_source.updated_at = datetime.now(UTC)
     else:
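
The rename from last_job_id to job_id implies a matching column rename on the IndexSourceDb model (presumably with a migration elsewhere in the change). A hypothetical sketch of the model fields this diff touches, with all types assumed:

# Hypothetical excerpt of the IndexSourceDb model after the rename;
# only fields visible in this diff, and every type is an assumption.
from datetime import datetime

from sqlalchemy import JSON
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class IndexSourceDb(Base):
    __tablename__ = "index_sources"  # assumed table name

    id: Mapped[str] = mapped_column(primary_key=True)
    source_identifier: Mapped[str]
    status: Mapped[str]
    job_id: Mapped[str | None]  # formerly last_job_id
    config: Mapped[dict] = mapped_column(JSON)
    metrics: Mapped[dict] = mapped_column(JSON, default=dict)
    last_indexed_at: Mapped[datetime | None]
    created_at: Mapped[datetime]
    updated_at: Mapped[datetime]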
@@ -98,15 +98,17 @@ async def index_website(
             source_identifier=body.base_url,
             config=body.model_dump(),
             status="indexing",
-            last_job_id=job_id,
+            job_id=job_id,
             created_at=datetime.now(UTC),
             updated_at=datetime.now(UTC),
         )
         db.add(index_source)

     await db.commit()

-    asyncio.create_task(job_manager.execute_job(job_id, _crawl_website_job, index_source.id, domain, body, db))
+    asyncio.create_task(
+        job_manager.execute_job(job_id, _crawl_website_job, job_id, index_source.id, domain, body, db)
+    )

     LOGGER.info(f"Started website crawl job {job_id} for domain: {domain}, base_url: {body.base_url}")
     return JSONResponse(jsonable_encoder(IndexWebsiteResponse(job_id=job_id, base_url=body.base_url)))
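
Note that job_id now appears twice in the execute_job call: once for the manager's own bookkeeping and once as the first argument forwarded to _crawl_website_job, whose signature gained a job_id parameter. A minimal sketch of the dispatch contract this implies (the JobManager internals here are assumptions, not fai's actual implementation):

# Assumed shape of the job manager's dispatch; not the real JobManager.
from typing import Any, Awaitable, Callable

class JobManager:
    def __init__(self) -> None:
        self.status: dict[str, str] = {}

    async def execute_job(
        self,
        job_id: str,
        fn: Callable[..., Awaitable[Any]],
        *args: Any,
    ) -> None:
        # The first job_id keys the manager's status tracking; everything
        # after the callable is passed through untouched, which is why the
        # route now forwards job_id a second time for the job body itself.
        self.status[job_id] = "in_progress"
        try:
            await fn(*args)
            self.status[job_id] = "completed"
        except Exception:
            self.status[job_id] = "failed"
            raise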
@@ -193,38 +195,42 @@ async def _crawl_website_job(
         await sync_website_db_to_tpuf(domain, db)
         await sync_index_to_target(domain, get_website_index_name(), get_query_index_name())

-        result = await db.execute(select(IndexSourceDb).where(IndexSourceDb.id == source_id))
-        index_source = result.scalar_one_or_none()
+        # Get a fresh database session after long-running Turbopuffer sync
+        async with async_session_maker() as fresh_db:
+            result = await fresh_db.execute(select(IndexSourceDb).where(IndexSourceDb.id == source_id))
+            index_source = result.scalar_one_or_none()

-        if index_source:
-            index_source.status = "active"
-            index_source.last_indexed_at = datetime.now(UTC)
-            index_source.updated_at = datetime.now(UTC)
+            if index_source:
+                index_source.status = "active"
+                index_source.last_indexed_at = datetime.now(UTC)
+                index_source.updated_at = datetime.now(UTC)

-            index_source.metrics = {
-                "pages_indexed": pages_indexed,
-                "pages_failed": pages_failed,
-            }
+                index_source.metrics = {
+                    "pages_indexed": pages_indexed,
+                    "pages_failed": pages_failed,
+                }

-        await db.commit()
+            await fresh_db.commit()

         LOGGER.info(f"Completed website crawl job {job_id} for domain: {domain}")
     except Exception:
         LOGGER.exception(f"Failed to complete website crawl job {job_id}")

-        result = await db.execute(select(IndexSourceDb).where(IndexSourceDb.id == source_id))
-        index_source = result.scalar_one_or_none()
+        # Get a fresh database session in case of failure after Turbopuffer sync
+        async with async_session_maker() as fresh_db:
+            result = await fresh_db.execute(select(IndexSourceDb).where(IndexSourceDb.id == source_id))
+            index_source = result.scalar_one_or_none()

-        if index_source:
-            index_source.status = "failed"
-            index_source.updated_at = datetime.now(UTC)
+            if index_source:
+                index_source.status = "failed"
+                index_source.updated_at = datetime.now(UTC)

-            index_source.metrics = {
-                "pages_indexed": pages_indexed,
-                "pages_failed": pages_failed,
-            }
+                index_source.metrics = {
+                    "pages_indexed": pages_indexed,
+                    "pages_failed": pages_failed,
+                }

-        await db.commit()
+            await fresh_db.commit()


 @fai_app.get(
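
This hunk is the substance of the fix: by the time the long-running Turbopuffer sync finishes, the request-scoped db session may already be closed or its connection stale, so the job writes its final status through a short-lived session of its own. The pattern in isolation, as a minimal sketch (IndexSourceDb and async_session_maker are the names from this file; the helper itself is hypothetical):

# Sketch of the fresh-session pattern used above; assumes the module's
# existing imports (async_session_maker, IndexSourceDb, select, datetime/UTC).
async def finalize_source(source_id: str, status: str, metrics: dict) -> None:
    # async_session_maker() yields an independent AsyncSession whose
    # lifetime is exactly this block, so it cannot have been invalidated
    # by whatever long-running work preceded it.
    async with async_session_maker() as fresh_db:
        result = await fresh_db.execute(select(IndexSourceDb).where(IndexSourceDb.id == source_id))
        index_source = result.scalar_one_or_none()
        if index_source:
            index_source.status = status
            index_source.updated_at = datetime.now(UTC)
            index_source.metrics = metrics
        # The context manager closes the session but does not commit,
        # so the explicit commit stays.
        await fresh_db.commit()

A side benefit of the async with block is that the session is guaranteed to be closed even if the lookup raises, which the shared db handle did not get in the old code.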
@@ -233,11 +239,9 @@ async def _crawl_website_job(
     openapi_extra={"x-fern-audiences": ["customers"], "security": [{"bearerAuth": []}]},
 )
 async def get_website_status(
-    domain: str,
     job_id: str = QueryParam(..., description="The job ID returned from the index endpoint"),
     db: AsyncSession = Depends(get_db),
     _: None = Depends(verify_token),
-    __: None = Depends(ask_ai_enabled),
 ) -> JSONResponse:
     """
     Get the status of a website crawling job.
@@ -249,7 +253,7 @@ async def get_website_status(
         return JSONResponse(status_code=404, content={"detail": "Job not found"})

     # Find the IndexSourceDb that corresponds to this job
-    result = await db.execute(select(IndexSourceDb).where(IndexSourceDb.last_job_id == job_id))
+    result = await db.execute(select(IndexSourceDb).where(IndexSourceDb.job_id == job_id))
     index_source = result.scalar_one_or_none()

     if not index_source:
@@ -261,8 +265,8 @@ async def get_website_status(
     pages_failed = metrics.get("pages_failed", 0)

     # Determine status: use job status if in progress, otherwise use source status
-    if job.status.value in ["pending", "in_progress"]:
-        status = job.status.value
+    if job.status in ["pending", "in_progress"]:
+        status = job.status
         error = None
     else:
         status = index_source.status
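
Dropping .value suggests job.status is now a plain string rather than an Enum member; if it were still a plain Enum, this membership test would silently become always-false. A quick illustration of the pitfall, with a hypothetical JobStatus enum:

# Hypothetical enums; fai's actual job status type is not shown in this diff.
from enum import Enum

class JobStatus(Enum):
    PENDING = "pending"
    IN_PROGRESS = "in_progress"

print(JobStatus.PENDING in ["pending", "in_progress"])        # False: Enum member != str
print(JobStatus.PENDING.value in ["pending", "in_progress"])  # True

class StrJobStatus(str, Enum):
    PENDING = "pending"
    IN_PROGRESS = "in_progress"

print(StrJobStatus.PENDING in ["pending", "in_progress"])     # True: str subclass compares equal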
@@ -294,7 +298,6 @@ async def get_website_by_id(
     website_id: str,
     db: AsyncSession = Depends(get_db),
     _: None = Depends(verify_token),
-    __: None = Depends(ask_ai_enabled),
 ) -> JSONResponse:
     """
     Get a single indexed website page by ID.
@@ -411,7 +414,7 @@ async def reindex_website(
     if index_source:
         # Update existing source
         index_source.status = "indexing"
-        index_source.last_job_id = job_id
+        index_source.job_id = job_id
         index_source.updated_at = datetime.now(UTC)
         # Reset metrics for reindexing
         index_source.metrics = {}
@@ -425,7 +428,7 @@ async def reindex_website(
             source_identifier=body.base_url,
             config={"base_url": body.base_url},
             status="indexing",
-            last_job_id=job_id,
+            job_id=job_id,
             created_at=datetime.now(UTC),
             updated_at=datetime.now(UTC),
         )
@@ -436,7 +439,13 @@ async def _crawl_website_job(
     # Start the crawling job
     asyncio.create_task(
         job_manager.execute_job(
-            job_id, _crawl_website_job, index_source.id, domain, IndexWebsiteRequest(base_url=body.base_url), db
+            job_id,
+            _crawl_website_job,
+            job_id,
+            index_source.id,
+            domain,
+            IndexWebsiteRequest(base_url=body.base_url),
+            db,
         )
     )
