Skip to content
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ jobs:
run: |
pytest tests/databricks_ai_bridge/test_lakebase.py
pytest integrations/langchain/tests/unit_tests/test_checkpoint.py
pytest integrations/langchain/tests/unit_tests/test_store.py

langchain_cross_version_test:
runs-on: ubuntu-latest
Expand Down
2 changes: 2 additions & 0 deletions integrations/langchain/src/databricks_langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from databricks_langchain.checkpoint import CheckpointSaver
from databricks_langchain.embeddings import DatabricksEmbeddings
from databricks_langchain.genie import GenieAgent
from databricks_langchain.store import DatabricksStore
from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool
from databricks_langchain.vectorstores import DatabricksVectorSearch

Expand All @@ -29,6 +30,7 @@
"ChatDatabricks",
"CheckpointSaver",
"DatabricksEmbeddings",
"DatabricksStore",
"DatabricksVectorSearch",
"GenieAgent",
"VectorSearchRetrieverTool",
Expand Down
82 changes: 82 additions & 0 deletions integrations/langchain/src/databricks_langchain/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

from typing import Any, Optional

from databricks.sdk import WorkspaceClient

try:
from databricks_ai_bridge.lakebase import LakebasePool
from langgraph.store.postgres import PostgresStore

_store_imports_available = True
except ImportError:
LakebasePool = object
PostgresStore = object
_store_imports_available = False


class DatabricksStore:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a base Store class we should be subclassing?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""
Wrapper around LangGraph's PostgresStore that uses a Lakebase
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we document this instead for its purpose e.g. explain that this class provides APIs for working with long-term memory, extending LangGraph's Store interface?

connection pool and borrows a connection per call.
"""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: instead of directly subclassing, I wrapped PostgresStore instead to simplify the agent authoring CUJ using the store/handle connections for the user

Previously:

with self.pool.connection() as conn:
        store = PostgresStore(conn=conn)
        results = store.search(namespace, query=query, limit=5)

if we subclassed it would look like:

with self.pool.connection() as conn:
        store = DatabricksStore(conn=conn)
        results = store.search(namespace, query=query, limit=5)

current

results = self.store.search(namespace, query=query, limit=5)

let me know if we prefer to subclass instead

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach looks good to me, hopefully advanced connection management isn't needed most of the time + it's good to hide the LakebasePool details

def __init__(
    self,
    *,
    instance_name: str,
    workspace_client: Optional[WorkspaceClient] = None,
    **pool_kwargs: Any,
) -> None:
    """Create a store bound to a Lakebase database instance.

    The underlying connection pool is built lazily on first use (see
    ``_ensure_initialized``), so constructing this object during
    deployment cannot fail on an unreachable database.

    Args:
        instance_name: Name of the Lakebase database instance.
        workspace_client: Optional ``WorkspaceClient``; when omitted,
            ``LakebasePool`` falls back to its default client.
        **pool_kwargs: Extra keyword arguments forwarded to ``LakebasePool``.

    Raises:
        ImportError: If the optional ``memory`` extras are not installed.
    """
    if not _store_imports_available:
        raise ImportError(
            "DatabricksStore requires databricks-langchain[memory]. "
            "Install with: pip install 'databricks-langchain[memory]'"
        )

    # Only record the construction parameters here; the pool itself is
    # created on first use so deployment-time instantiation cannot fail.
    self._instance_name = instance_name
    self._workspace_client = workspace_client
    self._pool_kwargs = pool_kwargs
    # Lazily-created LakebasePool and its underlying psycopg pool.
    self._lakebase: Optional[LakebasePool] = None
    self._pool = None
    # NOTE(review): never read in this module — confirm it is needed.
    self._setup_called = False

def _ensure_initialized(self) -> None:
    """Build the LakebasePool on first use, once deployment is ready."""
    if self._lakebase is not None:
        # Already initialized; nothing to do.
        return
    self._lakebase = LakebasePool(
        instance_name=self._instance_name,
        workspace_client=self._workspace_client,
        **self._pool_kwargs,
    )
    self._pool = self._lakebase.pool

def _with_store(self, fn, *args, **kwargs):
    """Run ``fn`` against a short-lived PostgresStore.

    Borrows a connection from the pool, wraps it in a ``PostgresStore``,
    invokes ``fn(store, *args, **kwargs)``, and releases the connection
    back to the pool when the call finishes.
    """
    self._ensure_initialized()
    with self._pool.connection() as connection:
        return fn(PostgresStore(conn=connection), *args, **kwargs)

def setup(self) -> None:
    """Create the database tables backing the store."""

    def _run(store):
        return store.setup()

    return self._with_store(_run)

def put(self, namespace: tuple[str, ...], key: str, value: Any) -> None:
    """Store ``value`` under ``key`` within ``namespace``."""

    def _run(store):
        return store.put(namespace, key, value)

    return self._with_store(_run)

def search(
    self,
    namespace: tuple[str, ...],
    *,
    query: Optional[str] = None,
    limit: int = 20,
) -> list[Any]:
    """Return up to ``limit`` items from ``namespace``, optionally filtered by ``query``."""

    def _run(store):
        return store.search(namespace, query=query, limit=limit)

    return self._with_store(_run)
75 changes: 75 additions & 0 deletions integrations/langchain/tests/unit_tests/test_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

from unittest.mock import MagicMock

import pytest

pytest.importorskip("psycopg")
pytest.importorskip("psycopg_pool")
pytest.importorskip("langgraph.checkpoint.postgres")

from databricks_ai_bridge import lakebase

from databricks_langchain import DatabricksStore


class TestConnectionPool:
    """Stand-in for ``psycopg_pool.ConnectionPool`` used by the store tests.

    Records the ``conninfo`` it was constructed with and hands out a fixed
    connection object through a context manager, mirroring the small part
    of the real pool API that ``DatabricksStore`` exercises.
    """

    # Prevent pytest from collecting this helper as a test class: its
    # "Test" prefix plus a non-trivial __init__ otherwise triggers a
    # PytestCollectionWarning.
    __test__ = False

    def __init__(self, connection_value="conn"):
        # Object yielded by connection(); tests typically pass a MagicMock.
        self.connection_value = connection_value
        # Last conninfo string passed to __call__ (i.e. pool construction).
        self.conninfo = ""

    def __call__(
        self,
        *,
        conninfo,
        connection_class=None,
        **kwargs,
    ):
        """Mimic the ConnectionPool constructor; return self as the pool."""
        self.conninfo = conninfo
        return self

    def connection(self):
        """Return a context manager yielding the canned connection value."""

        class _Ctx:
            def __init__(self, outer):
                self.outer = outer

            def __enter__(self):
                return self.outer.connection_value

            def __exit__(self, exc_type, exc, tb):
                pass

        return _Ctx(self)


def test_databricks_store_configures_lakebase(monkeypatch):
    """DatabricksStore lazily builds a LakebasePool with the expected conninfo."""
    borrowed_conn = MagicMock()
    fake_pool = TestConnectionPool(connection_value=borrowed_conn)
    monkeypatch.setattr(lakebase, "ConnectionPool", fake_pool)

    from langgraph.store.postgres import PostgresStore

    # Avoid touching a real database when setup() runs.
    monkeypatch.setattr(PostgresStore, "setup", MagicMock())

    workspace = MagicMock()
    workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token")
    workspace.database.get_database_instance.return_value.read_write_dns = "db-host"
    # No service principal available -> credential lookup falls back to the user.
    workspace.current_service_principal.me.side_effect = RuntimeError("no sp")
    workspace.current_user.me.return_value = MagicMock(user_name="[email protected]")

    store = DatabricksStore(
        instance_name="lakebase-instance",
        workspace_client=workspace,
    )

    # First operation triggers lazy pool construction.
    store.setup()

    expected = (
        "dbname=databricks_postgres [email protected] host=db-host port=5432 sslmode=require"
    )
    assert fake_pool.conninfo == expected
    assert isinstance(store, DatabricksStore)
    assert store._lakebase.pool == fake_pool

    with store._lakebase.connection() as conn:
        assert conn == borrowed_conn