-
Notifications
You must be signed in to change notification settings - Fork 38
DatabricksStore: PostgresStore Wrapper SDK #227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
9951257
fff1c1e
bb45b1c
bff5c76
272c62d
888a924
2dfabef
5215733
0ea050c
848a7cf
ec3444b
4d1c74c
b036124
7872846
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
from __future__ import annotations

import asyncio
from typing import Any, Iterable, Optional

from databricks.sdk import WorkspaceClient
| try: | ||
| from databricks_ai_bridge.lakebase import LakebasePool | ||
| from langgraph.store.base import BaseStore, Item, Op, Result | ||
| from langgraph.store.postgres import PostgresStore | ||
|
|
||
| _store_imports_available = True | ||
| except ImportError: | ||
| LakebasePool = object | ||
| PostgresStore = object | ||
| BaseStore = object | ||
| Item = object | ||
| Op = object | ||
| Result = object | ||
| _store_imports_available = False | ||
|
|
||
|
|
||
| class DatabricksStore(BaseStore): | ||
| """Provides APIs for working with long-term memory on Databricks using Lakebase. | ||
| Extends LangGraph BaseStore interface using Databricks Lakebase for connection pooling. | ||
|
|
||
| Operations borrow a connection from the pool, create a short-lived PostgresStore, | ||
| execute the operation, and return the connection to the pool. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| instance_name: str, | ||
| workspace_client: Optional[WorkspaceClient] = None, | ||
| **pool_kwargs: Any, | ||
| ) -> None: | ||
| if not _store_imports_available: | ||
| raise ImportError( | ||
| "DatabricksStore requires databricks-langchain[memory]. " | ||
| "Install with: pip install 'databricks-langchain[memory]'" | ||
| ) | ||
|
|
||
| # Store initialization parameters for lazy initialization, otherwise | ||
| # if we directly initialize pool during deployment it will fail | ||
jennsun marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self._instance_name = instance_name | ||
| self._workspace_client = workspace_client | ||
| self._pool_kwargs = pool_kwargs | ||
| self._lakebase: Optional[LakebasePool] = None | ||
| self._pool = None | ||
|
|
||
| def _ensure_initialized(self) -> None: | ||
| """Lazy initialization of LakebasePool on first use after deployment is ready.""" | ||
| if self._lakebase is None: | ||
| self._lakebase = LakebasePool( | ||
| instance_name=self._instance_name, | ||
| workspace_client=self._workspace_client, | ||
| **self._pool_kwargs, | ||
| ) | ||
| self._pool = self._lakebase.pool | ||
|
|
||
| def _with_store(self, fn, *args, **kwargs): | ||
| """ | ||
| Borrow a connection, create a short-lived PostgresStore, call fn(store), | ||
| then return the connection to the pool. | ||
| """ | ||
| self._ensure_initialized() | ||
| with self._pool.connection() as conn: | ||
| store = PostgresStore(conn=conn) | ||
| return fn(store, *args, **kwargs) | ||
|
|
||
| def setup(self) -> None: | ||
| """Instantiate the store, setting up necessary persistent storage.""" | ||
| return self._with_store(lambda s: s.setup()) | ||
|
|
||
| def batch(self, ops: Iterable[Op]) -> list[Result]: | ||
| """Execute a batch of operations synchronously. | ||
|
|
||
| This is the core method required by BaseStore. All other operations | ||
| (get, put, search, delete, list_namespaces) are inherited from BaseStore | ||
| and internally call this batch() method. | ||
| """ | ||
| return self._with_store(lambda s: s.batch(ops)) | ||
|
|
||
| async def abatch(self, ops: Iterable[Op]) -> list[Result]: | ||
| """Execute a batch of operations asynchronously. | ||
|
|
||
| This is the second abstract method required by BaseStore. | ||
| Currently delegates to sync batch() - for true async support, | ||
| would need async-compatible connection pooling. | ||
| """ | ||
| return self.batch(ops) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we are going to provide standard functions like
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also a comment on the example: can we copy the langgraph example in using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes to #1 for refactoring those functions related to the key-value storage -- for #2, I believe we're already doing this in the long-term example here - let me know if this is what you were thinking: https://github.com/databricks-eng/universe/pull/1443034/files?w=1#diff-fbb7ade56e8f48e42bc41e5a4672a52e301579fd4f7dc2f7cd0f1e5a30a0babbR463
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds good for #1 -- can we add some unit tests for the behavior we expect? |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from unittest.mock import MagicMock | ||
|
|
||
| import pytest | ||
|
|
||
| pytest.importorskip("psycopg") | ||
| pytest.importorskip("psycopg_pool") | ||
| pytest.importorskip("langgraph.checkpoint.postgres") | ||
|
|
||
| from databricks_ai_bridge import lakebase | ||
|
|
||
| from databricks_langchain import DatabricksStore | ||
|
|
||
|
|
||
| class TestConnectionPool: | ||
| def __init__(self, connection_value="conn"): | ||
| self.connection_value = connection_value | ||
| self.conninfo = "" | ||
|
|
||
| def __call__( | ||
| self, | ||
| *, | ||
| conninfo, | ||
| connection_class=None, | ||
| **kwargs, | ||
| ): | ||
| self.conninfo = conninfo | ||
| return self | ||
|
|
||
| def connection(self): | ||
| class _Ctx: | ||
| def __init__(self, outer): | ||
| self.outer = outer | ||
|
|
||
| def __enter__(self): | ||
| return self.outer.connection_value | ||
|
|
||
| def __exit__(self, exc_type, exc, tb): | ||
| pass | ||
|
|
||
| return _Ctx(self) | ||
|
|
||
|
|
||
| def test_databricks_store_configures_lakebase(monkeypatch): | ||
| mock_conn = MagicMock() | ||
| test_pool = TestConnectionPool(connection_value=mock_conn) | ||
| monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) | ||
|
|
||
| from langgraph.store.postgres import PostgresStore | ||
|
|
||
| monkeypatch.setattr(PostgresStore, "setup", MagicMock()) | ||
|
|
||
| workspace = MagicMock() | ||
| workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") | ||
| workspace.database.get_database_instance.return_value.read_write_dns = "db-host" | ||
| workspace.current_service_principal.me.side_effect = RuntimeError("no sp") | ||
| workspace.current_user.me.return_value = MagicMock(user_name="[email protected]") | ||
|
|
||
| store = DatabricksStore( | ||
| instance_name="lakebase-instance", | ||
| workspace_client=workspace, | ||
| ) | ||
|
|
||
| store.setup() | ||
|
|
||
| assert ( | ||
| test_pool.conninfo | ||
| == "dbname=databricks_postgres [email protected] host=db-host port=5432 sslmode=require" | ||
| ) | ||
| assert isinstance(store, DatabricksStore) | ||
| assert store._lakebase.pool == test_pool | ||
|
|
||
| with store._lakebase.connection() as conn: | ||
| assert conn == mock_conn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note: instead of directly subclassing, I wrapped PostgresStore instead to simplify the agent authoring CUJ using the store/handle connections for the user
Previously:
if we subclassed it would look like:
current
let me know if we prefer to subclass instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This approach looks good to me, hopefully advanced connection management isn't needed most of the time + it's good to hide the LakebasePool details