Skip to content

Commit 1f8e2d0

Browse files
tsbhanguclaude
andauthored
feat(fai): implement incremental document and query index sync (#4710)
Co-authored-by: Claude <[email protected]>
1 parent 21fd170 commit 1f8e2d0

File tree

2 files changed

+138
-10
lines changed

2 files changed

+138
-10
lines changed

servers/fai/src/fai/routes/document.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@
4040
get_query_index_name,
4141
)
4242
from fai.utils.turbopuffer.sync import (
43+
delete_documents_from_query_index,
44+
delete_documents_from_tpuf,
4345
sync_document_db_to_tpuf,
46+
sync_documents_to_query_index,
47+
sync_documents_to_tpuf,
4448
sync_index_to_target,
4549
)
4650

@@ -81,8 +85,10 @@ async def create_document(
8185
created_document_ids.append(document_id)
8286
await db.commit()
8387
await db.refresh(new_db_document)
84-
await sync_document_db_to_tpuf(domain, db)
85-
await sync_index_to_target(domain, get_document_index_name(), get_query_index_name())
88+
await sync_documents_to_tpuf(domain, created_document_ids, db)
89+
await sync_documents_to_query_index(
90+
domain, created_document_ids, get_document_index_name(), get_query_index_name()
91+
)
8692
LOGGER.info(f"Indexed document {new_db_document.id} for domain: {domain}")
8793
return JSONResponse(
8894
jsonable_encoder([CreateDocumentResponse(document_id=document_id) for document_id in created_document_ids])
@@ -132,8 +138,10 @@ async def batch_create_document(
132138
LOGGER.info(f"Created document {document_id} for domain: {domain}")
133139

134140
await db.commit()
135-
await sync_document_db_to_tpuf(domain, db)
136-
await sync_index_to_target(domain, get_document_index_name(), get_query_index_name())
141+
await sync_documents_to_tpuf(domain, created_document_ids, db)
142+
await sync_documents_to_query_index(
143+
domain, created_document_ids, get_document_index_name(), get_query_index_name()
144+
)
137145

138146
return JSONResponse(
139147
jsonable_encoder([CreateDocumentResponse(document_id=document_id) for document_id in created_document_ids])
@@ -181,8 +189,10 @@ async def update_document(
181189
db.add(db_document)
182190
await db.commit()
183191
await db.refresh(db_document)
184-
await sync_document_db_to_tpuf(domain, db)
185-
await sync_index_to_target(domain, get_document_index_name(), get_query_index_name())
192+
await sync_documents_to_tpuf(domain, [document_id], db)
193+
await sync_documents_to_query_index(
194+
domain, [document_id], get_document_index_name(), get_query_index_name()
195+
)
186196
LOGGER.info(f"Updated document {document_id} for domain: {domain}")
187197
return JSONResponse(jsonable_encoder(UpdateDocumentResponse(document=db_document.to_api())))
188198
return JSONResponse(status_code=404, content=jsonable_encoder({"message": "Document not found"}))
@@ -212,8 +222,10 @@ async def delete_document_by_id(
212222
if db_document:
213223
await db.delete(db_document)
214224
await db.commit()
215-
await sync_document_db_to_tpuf(domain, db)
216-
await sync_index_to_target(domain, get_document_index_name(), get_query_index_name())
225+
await delete_documents_from_tpuf(domain, [body.document_id])
226+
await delete_documents_from_query_index(
227+
domain, [body.document_id], get_document_index_name(), get_query_index_name()
228+
)
217229
LOGGER.info(f"Deleted document {body.document_id} for domain: {domain}")
218230
return JSONResponse(jsonable_encoder(DeleteDocumentResponse(success=True)))
219231
return JSONResponse(jsonable_encoder(DeleteDocumentResponse(success=False)))
@@ -237,6 +249,7 @@ async def batch_delete_document(
237249
) -> JSONResponse:
238250
try:
239251
deleted_count = 0
252+
deleted_document_ids = []
240253

241254
for document in body:
242255
db_document = await db.execute(
@@ -245,13 +258,16 @@ async def batch_delete_document(
245258
db_document = db_document.scalar_one_or_none()
246259
if db_document:
247260
await db.delete(db_document)
261+
deleted_document_ids.append(document.document_id)
248262
deleted_count += 1
249263
LOGGER.info(f"Deleted document {document.document_id} for domain: {domain}")
250264

251265
if deleted_count > 0:
252266
await db.commit()
253-
await sync_document_db_to_tpuf(domain, db)
254-
await sync_index_to_target(domain, get_document_index_name(), get_query_index_name())
267+
await delete_documents_from_tpuf(domain, deleted_document_ids)
268+
await delete_documents_from_query_index(
269+
domain, deleted_document_ids, get_document_index_name(), get_query_index_name()
270+
)
255271
return JSONResponse(jsonable_encoder(DeleteDocumentResponse(success=True)))
256272
else:
257273
return JSONResponse(jsonable_encoder(DeleteDocumentResponse(success=False)))

servers/fai/src/fai/utils/turbopuffer/sync.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,56 @@ def prefixed_id(namespace: str, original_id: str, max_len: int = 64) -> str:
3939
return f"{short_ns}:{hashed}"
4040

4141

42+
async def sync_documents_to_tpuf(domain: str, document_ids: list[str], db: AsyncSession) -> None:
    """Upsert the given documents from the database into the domain's document index.

    Loads the matching ``DocumentDb`` rows for *domain*, converts each to a
    Turbopuffer record (passing the OpenAI client through — presumably for
    embedding; confirm against ``to_tpuf_record``), and writes them to the
    domain's document-index namespace in a single batched upsert.

    Args:
        domain: Tenant domain whose namespace is written to.
        document_ids: IDs of the documents to sync; no-op when empty.
        db: Async database session used to load the documents.
    """
    if not document_ids:
        LOGGER.info("No document IDs provided for sync, skipping")
        return

    result = await db.execute(
        select(DocumentDb).where(DocumentDb.domain == domain, DocumentDb.id.in_(document_ids))
    )
    documents = result.scalars().all()

    if not documents:
        LOGGER.warning(f"No documents found for IDs {document_ids} in domain {domain}")
        return

    # Namespace id depends only on domain/index name, so compute it up front.
    target_namespace_id = get_tpuf_namespace(domain, get_document_index_name())

    async with AsyncOpenAI(api_key=VARIABLES.OPENAI_API_KEY) as openai_client:
        async with AsyncTurbopuffer(
            region=CONFIG.TURBOPUFFER_DEFAULT_REGION,
            api_key=VARIABLES.TURBOPUFFER_API_KEY,
        ) as tpuf_client:
            target_ns = tpuf_client.namespace(target_namespace_id)

            # Records are built sequentially; each conversion may await the
            # OpenAI client internally.
            records = [await document.to_tpuf_record(openai_client) for document in documents]

            await target_ns.write(
                upsert_rows=[jsonable_encoder(record) for record in records],
                distance_metric="cosine_distance",
                schema=get_data_index_tpuf_schema(),
            )
            LOGGER.info(f"Upserted {len(documents)} documents to {target_namespace_id}")
72+
73+
74+
async def delete_documents_from_tpuf(domain: str, document_ids: list[str]) -> None:
    """Delete the given documents from the domain's Turbopuffer document index.

    Args:
        domain: Tenant domain whose document namespace is targeted.
        document_ids: IDs of the documents to remove; no-op when empty.
    """
    if not document_ids:
        LOGGER.info("No document IDs provided for deletion, skipping")
        return

    async with AsyncTurbopuffer(
        region=CONFIG.TURBOPUFFER_DEFAULT_REGION,
        api_key=VARIABLES.TURBOPUFFER_API_KEY,
    ) as tpuf_client:
        target_namespace_id = get_tpuf_namespace(domain, get_document_index_name())
        target_ns = tpuf_client.namespace(target_namespace_id)

        # Batch all deletions into a single write using an "In" filter instead
        # of issuing one write per document id — one round-trip instead of N.
        await target_ns.write(delete_by_filter=["id", "In", document_ids])

        LOGGER.info(f"Deleted {len(document_ids)} documents from {target_namespace_id}")
90+
91+
4292
async def sync_document_db_to_tpuf(domain: str, db: AsyncSession) -> None:
4393
documents = await db.execute(select(DocumentDb).where(DocumentDb.domain == domain))
4494
documents = documents.scalars().all()
@@ -118,6 +168,68 @@ async def sync_slack_context_db_to_tpuf(domain: str, db: AsyncSession) -> None:
118168
LOGGER.info(f"Wrote {len(slack_contexts)} slack contexts to {target_namespace_id}")
119169

120170

171+
async def sync_documents_to_query_index(
    domain: str, document_ids: list[str], source_index_name: str, target_index_name: str
) -> None:
    """Copy the given documents from the source index into the query index.

    For each document id: remove any existing copy from the target namespace,
    fetch the row from the source namespace, re-key it with a
    namespace-prefixed id, tag it with its source index, and upsert the
    collected rows in one batched write. Ids with no matching source row are
    skipped silently.

    Args:
        domain: Tenant domain used to resolve both namespaces.
        document_ids: IDs of documents to mirror; no-op when empty.
        source_index_name: Index the rows are read from.
        target_index_name: Query index the prefixed rows are written to.
    """
    if not document_ids:
        LOGGER.info("No document IDs provided for query index sync, skipping")
        return

    source_namespace_id = get_tpuf_namespace(domain, source_index_name)
    target_namespace_id = get_tpuf_namespace(domain, target_index_name)

    async with AsyncTurbopuffer(
        region=CONFIG.TURBOPUFFER_DEFAULT_REGION,
        api_key=VARIABLES.TURBOPUFFER_API_KEY,
    ) as tpuf_client:
        source_ns = tpuf_client.namespace(source_namespace_id)
        target_ns = tpuf_client.namespace(target_namespace_id)

        # Remove existing copies first. NOTE(review): the upsert below would
        # overwrite matching ids anyway — presumably this also clears entries
        # whose source row has vanished; confirm intent.
        for doc_id in document_ids:
            await target_ns.write(
                delete_by_filter=["id", "Eq", prefixed_id(source_namespace_id, doc_id)]
            )

        prefixed_rows = []
        for doc_id in document_ids:
            found = await source_ns.query(filters=("id", "Eq", doc_id), top_k=1, include_attributes=True)
            if not found.rows:
                continue
            copied = Row.from_dict(found.rows[0].model_dump())
            copied.id = prefixed_id(source_namespace_id, doc_id)
            copied.source = source_index_name
            prefixed_rows.append(copied)

        if prefixed_rows:
            await target_ns.write(
                upsert_rows=prefixed_rows, distance_metric="cosine_distance", schema=get_query_index_tpuf_schema()
            )
        LOGGER.info(f"Synced {len(prefixed_rows)} documents to query index {target_namespace_id}")
208+
209+
210+
async def delete_documents_from_query_index(
    domain: str, document_ids: list[str], source_index_name: str, target_index_name: str
) -> None:
    """Delete the given documents' mirrored rows from the query index.

    Rows in the query index are keyed by a namespace-prefixed id derived from
    the source namespace, so the same prefixing is applied here before
    deleting.

    Args:
        domain: Tenant domain used to resolve both namespaces.
        document_ids: IDs of documents whose mirrored rows are removed;
            no-op when empty.
        source_index_name: Index the rows were originally mirrored from
            (determines the id prefix).
        target_index_name: Query index the rows are removed from.
    """
    if not document_ids:
        LOGGER.info("No document IDs provided for query index deletion, skipping")
        return

    source_namespace_id = get_tpuf_namespace(domain, source_index_name)
    target_namespace_id = get_tpuf_namespace(domain, target_index_name)

    async with AsyncTurbopuffer(
        region=CONFIG.TURBOPUFFER_DEFAULT_REGION,
        api_key=VARIABLES.TURBOPUFFER_API_KEY,
    ) as tpuf_client:
        target_ns = tpuf_client.namespace(target_namespace_id)

        # Batch all deletions into a single write with an "In" filter over the
        # prefixed ids — one round-trip instead of one write per document.
        prefixed_ids = [prefixed_id(source_namespace_id, doc_id) for doc_id in document_ids]
        await target_ns.write(delete_by_filter=["id", "In", prefixed_ids])

        LOGGER.info(f"Deleted {len(document_ids)} documents from query index {target_namespace_id}")
231+
232+
121233
async def sync_index_to_target(domain: str, source_index_name: str, target_index_name: str) -> None:
122234
source_namespace_id = get_tpuf_namespace(domain, source_index_name)
123235
target_namespace_id = get_tpuf_namespace(domain, target_index_name)

0 commit comments

Comments
 (0)