Skip to content

Commit 27d28de

Browse files
authored
Add wait_for_database helper function to poll for Knowledge Base database creation (#1)
* Initial plan * Add wait_for_database helper for knowledge base polling * Add unit tests and README documentation for wait_for_database * Fix linting and type checking for wait_for_database implementation
1 parent a5abfd7 commit 27d28de

File tree

5 files changed

+418
-2
lines changed

5 files changed

+418
-2
lines changed

README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,49 @@ we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/)
9494
to add `DIGITALOCEAN_ACCESS_TOKEN="My Access Token"`, `GRADIENT_MODEL_ACCESS_KEY="My Model Access Key"` to your `.env` file
9595
so that your keys are not stored in source control.
9696

97+
## Knowledge Base Database Polling
98+
99+
When creating a Knowledge Base, the database deployment can take several minutes. The `wait_for_database()` helper function simplifies polling for the database status:
100+
101+
```python
102+
from gradient import Gradient
103+
from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError
104+
from gradient._exceptions import APITimeoutError
105+
106+
client = Gradient()
107+
108+
# Create a knowledge base
109+
kb_response = client.knowledge_bases.create(
110+
name="My Knowledge Base",
111+
region="nyc1",
112+
embedding_model_uuid="your-embedding-model-uuid",
113+
)
114+
115+
kb_uuid = kb_response.knowledge_base.uuid
116+
117+
try:
118+
# Wait for the database to be ready (default: 10 minute timeout, 5 second poll interval)
119+
result = client.knowledge_bases.wait_for_database(kb_uuid)
120+
print(f"Database status: {result.database_status}") # "ONLINE"
121+
122+
# Custom timeout and poll interval
123+
result = client.knowledge_bases.wait_for_database(
124+
kb_uuid,
125+
timeout=900.0, # 15 minutes
126+
poll_interval=10.0 # Check every 10 seconds
127+
)
128+
129+
except KnowledgeBaseDatabaseError as e:
130+
# Database entered a failed state (DECOMMISSIONED or UNHEALTHY)
131+
print(f"Database failed: {e}")
132+
133+
except APITimeoutError:
134+
# Database did not become ready within the timeout period
135+
print("Timeout: Database did not become ready in time")
136+
```
137+
138+
The helper handles all state transitions and will raise appropriate exceptions for failed states or timeouts.
139+
97140
## Async usage
98141

99142
Simply import `AsyncGradient` instead of `Gradient` and use `await` with each API call:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,5 +246,5 @@ known-first-party = ["gradient", "tests"]
246246
[tool.ruff.lint.per-file-ignores]
247247
"bin/**.py" = ["T201", "T203"]
248248
"scripts/**.py" = ["T201", "T203"]
249-
"tests/**.py" = ["T201", "T203"]
249+
"tests/**.py" = ["T201", "T203", "ARG001"]
250250
"examples/**.py" = ["T201", "T203"]

src/gradient/resources/knowledge_bases/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .knowledge_bases import (
2020
KnowledgeBasesResource,
2121
AsyncKnowledgeBasesResource,
22+
KnowledgeBaseDatabaseError,
2223
KnowledgeBasesResourceWithRawResponse,
2324
AsyncKnowledgeBasesResourceWithRawResponse,
2425
KnowledgeBasesResourceWithStreamingResponse,
@@ -40,6 +41,7 @@
4041
"AsyncIndexingJobsResourceWithStreamingResponse",
4142
"KnowledgeBasesResource",
4243
"AsyncKnowledgeBasesResource",
44+
"KnowledgeBaseDatabaseError",
4345
"KnowledgeBasesResourceWithRawResponse",
4446
"AsyncKnowledgeBasesResourceWithRawResponse",
4547
"KnowledgeBasesResourceWithStreamingResponse",

src/gradient/resources/knowledge_bases/knowledge_bases.py

Lines changed: 180 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import time
6+
import asyncio
57
from typing import Iterable
68

79
import httpx
@@ -25,6 +27,7 @@
2527
DataSourcesResourceWithStreamingResponse,
2628
AsyncDataSourcesResourceWithStreamingResponse,
2729
)
30+
from ..._exceptions import APITimeoutError
2831
from .indexing_jobs import (
2932
IndexingJobsResource,
3033
AsyncIndexingJobsResource,
@@ -40,7 +43,13 @@
4043
from ...types.knowledge_base_update_response import KnowledgeBaseUpdateResponse
4144
from ...types.knowledge_base_retrieve_response import KnowledgeBaseRetrieveResponse
4245

43-
__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource"]
46+
__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource", "KnowledgeBaseDatabaseError"]
47+
48+
49+
class KnowledgeBaseDatabaseError(Exception):
50+
"""Raised when a knowledge base database enters a failed state."""
51+
52+
pass
4453

4554

4655
class KnowledgeBasesResource(SyncAPIResource):
@@ -330,6 +339,85 @@ def delete(
330339
cast_to=KnowledgeBaseDeleteResponse,
331340
)
332341

342+
def wait_for_database(
343+
self,
344+
uuid: str,
345+
*,
346+
timeout: float = 600.0,
347+
poll_interval: float = 5.0,
348+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
349+
# The extra values given here take precedence over values defined on the client or passed to this method.
350+
extra_headers: Headers | None = None,
351+
extra_query: Query | None = None,
352+
extra_body: Body | None = None,
353+
) -> KnowledgeBaseRetrieveResponse:
354+
"""
355+
Poll the knowledge base until the database status is ONLINE or a failed state is reached.
356+
357+
This helper function repeatedly calls retrieve() to check the database_status field.
358+
It will wait for the database to become ONLINE, or raise an exception if it enters
359+
a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded.
360+
361+
Args:
362+
uuid: The knowledge base UUID to poll
363+
364+
timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes)
365+
366+
poll_interval: Time to wait between polls in seconds (default: 5 seconds)
367+
368+
extra_headers: Send extra headers
369+
370+
extra_query: Add additional query parameters to the request
371+
372+
extra_body: Add additional JSON properties to the request
373+
374+
Returns:
375+
The final KnowledgeBaseRetrieveResponse when the database status is ONLINE
376+
377+
Raises:
378+
KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY)
379+
380+
APITimeoutError: If the timeout is exceeded before the database becomes ONLINE
381+
"""
382+
if not uuid:
383+
raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")
384+
385+
start_time = time.time()
386+
failed_states = {"DECOMMISSIONED", "UNHEALTHY"}
387+
388+
while True:
389+
elapsed = time.time() - start_time
390+
if elapsed >= timeout:
391+
raise APITimeoutError(
392+
request=httpx.Request(
393+
method="GET",
394+
url=f"https://api.digitalocean.com/v2/gen-ai/knowledge_bases/{uuid}",
395+
)
396+
)
397+
398+
response = self.retrieve(
399+
uuid,
400+
extra_headers=extra_headers,
401+
extra_query=extra_query,
402+
extra_body=extra_body,
403+
)
404+
405+
status = response.database_status
406+
407+
if status == "ONLINE":
408+
return response
409+
410+
if status in failed_states:
411+
raise KnowledgeBaseDatabaseError(
412+
f"Knowledge base database entered failed state: {status}"
413+
)
414+
415+
# Sleep before next poll, but don't exceed timeout
416+
remaining_time = timeout - elapsed
417+
sleep_time = min(poll_interval, remaining_time)
418+
if sleep_time > 0:
419+
time.sleep(sleep_time)
420+
333421

334422
class AsyncKnowledgeBasesResource(AsyncAPIResource):
335423
@cached_property
@@ -618,6 +706,85 @@ async def delete(
618706
cast_to=KnowledgeBaseDeleteResponse,
619707
)
620708

709+
async def wait_for_database(
710+
self,
711+
uuid: str,
712+
*,
713+
timeout: float = 600.0,
714+
poll_interval: float = 5.0,
715+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
716+
# The extra values given here take precedence over values defined on the client or passed to this method.
717+
extra_headers: Headers | None = None,
718+
extra_query: Query | None = None,
719+
extra_body: Body | None = None,
720+
) -> KnowledgeBaseRetrieveResponse:
721+
"""
722+
Poll the knowledge base until the database status is ONLINE or a failed state is reached.
723+
724+
This helper function repeatedly calls retrieve() to check the database_status field.
725+
It will wait for the database to become ONLINE, or raise an exception if it enters
726+
a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded.
727+
728+
Args:
729+
uuid: The knowledge base UUID to poll
730+
731+
timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes)
732+
733+
poll_interval: Time to wait between polls in seconds (default: 5 seconds)
734+
735+
extra_headers: Send extra headers
736+
737+
extra_query: Add additional query parameters to the request
738+
739+
extra_body: Add additional JSON properties to the request
740+
741+
Returns:
742+
The final KnowledgeBaseRetrieveResponse when the database status is ONLINE
743+
744+
Raises:
745+
KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY)
746+
747+
APITimeoutError: If the timeout is exceeded before the database becomes ONLINE
748+
"""
749+
if not uuid:
750+
raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")
751+
752+
start_time = time.time()
753+
failed_states = {"DECOMMISSIONED", "UNHEALTHY"}
754+
755+
while True:
756+
elapsed = time.time() - start_time
757+
if elapsed >= timeout:
758+
raise APITimeoutError(
759+
request=httpx.Request(
760+
method="GET",
761+
url=f"https://api.digitalocean.com/v2/gen-ai/knowledge_bases/{uuid}",
762+
)
763+
)
764+
765+
response = await self.retrieve(
766+
uuid,
767+
extra_headers=extra_headers,
768+
extra_query=extra_query,
769+
extra_body=extra_body,
770+
)
771+
772+
status = response.database_status
773+
774+
if status == "ONLINE":
775+
return response
776+
777+
if status in failed_states:
778+
raise KnowledgeBaseDatabaseError(
779+
f"Knowledge base database entered failed state: {status}"
780+
)
781+
782+
# Sleep before next poll, but don't exceed timeout
783+
remaining_time = timeout - elapsed
784+
sleep_time = min(poll_interval, remaining_time)
785+
if sleep_time > 0:
786+
await asyncio.sleep(sleep_time)
787+
621788

622789
class KnowledgeBasesResourceWithRawResponse:
623790
def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
@@ -638,6 +805,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
638805
self.delete = to_raw_response_wrapper(
639806
knowledge_bases.delete,
640807
)
808+
self.wait_for_database = to_raw_response_wrapper(
809+
knowledge_bases.wait_for_database,
810+
)
641811

642812
@cached_property
643813
def data_sources(self) -> DataSourcesResourceWithRawResponse:
@@ -667,6 +837,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None:
667837
self.delete = async_to_raw_response_wrapper(
668838
knowledge_bases.delete,
669839
)
840+
self.wait_for_database = async_to_raw_response_wrapper(
841+
knowledge_bases.wait_for_database,
842+
)
670843

671844
@cached_property
672845
def data_sources(self) -> AsyncDataSourcesResourceWithRawResponse:
@@ -696,6 +869,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
696869
self.delete = to_streamed_response_wrapper(
697870
knowledge_bases.delete,
698871
)
872+
self.wait_for_database = to_streamed_response_wrapper(
873+
knowledge_bases.wait_for_database,
874+
)
699875

700876
@cached_property
701877
def data_sources(self) -> DataSourcesResourceWithStreamingResponse:
@@ -725,6 +901,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None:
725901
self.delete = async_to_streamed_response_wrapper(
726902
knowledge_bases.delete,
727903
)
904+
self.wait_for_database = async_to_streamed_response_wrapper(
905+
knowledge_bases.wait_for_database,
906+
)
728907

729908
@cached_property
730909
def data_sources(self) -> AsyncDataSourcesResourceWithStreamingResponse:

0 commit comments

Comments
 (0)