Skip to content
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ jobs:
run: |
pytest tests/databricks_ai_bridge/test_lakebase.py
pytest integrations/langchain/tests/unit_tests/test_checkpoint.py
pytest integrations/langchain/tests/unit_tests/test_store.py

langchain_cross_version_test:
runs-on: ubuntu-latest
Expand Down
2 changes: 2 additions & 0 deletions integrations/langchain/src/databricks_langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from databricks_langchain.checkpoint import CheckpointSaver
from databricks_langchain.embeddings import DatabricksEmbeddings
from databricks_langchain.genie import GenieAgent
from databricks_langchain.store import DatabricksStore
from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool
from databricks_langchain.vectorstores import DatabricksVectorSearch

Expand All @@ -29,6 +30,7 @@
"ChatDatabricks",
"CheckpointSaver",
"DatabricksEmbeddings",
"DatabricksStore",
"DatabricksVectorSearch",
"GenieAgent",
"VectorSearchRetrieverTool",
Expand Down
92 changes: 92 additions & 0 deletions integrations/langchain/src/databricks_langchain/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import annotations

from typing import Any, Iterable, Optional

from databricks.sdk import WorkspaceClient

# The "memory" extra (databricks-ai-bridge lakebase + langgraph postgres store)
# is optional. Import it if present; otherwise fall back to placeholder
# bindings so this module can still be imported, and let DatabricksStore raise
# a helpful ImportError at construction time instead.
try:
    from databricks_ai_bridge.lakebase import LakebasePool
    from langgraph.store.base import BaseStore, Item, Op, Result
    from langgraph.store.postgres import PostgresStore

    _store_imports_available = True
except ImportError:
    # Placeholders keep the definitions below syntactically valid
    # (e.g. `class DatabricksStore(BaseStore)`) when the extra is absent.
    LakebasePool = object
    PostgresStore = object
    BaseStore = object
    Item = object
    Op = object
    Result = object
    _store_imports_available = False


class DatabricksStore(BaseStore):
"""Provides APIs for working with long-term memory on Databricks using Lakebase.
Extends LangGraph BaseStore interface using Databricks Lakebase for connection pooling.

Operations borrow a connection from the pool, create a short-lived PostgresStore,
execute the operation, and return the connection to the pool.
"""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: instead of directly subclassing, I wrapped PostgresStore to simplify the agent-authoring CUJ — the wrapper uses the store and handles connections for the user

Previously:

with self.pool.connection() as conn:
        store = PostgresStore(conn=conn)
        results = store.search(namespace, query=query, limit=5)

if we subclassed it would look like:

with self.pool.connection() as conn:
        store = DatabricksStore(conn=conn)
        results = store.search(namespace, query=query, limit=5)

current

results = self.store.search(namespace, query=query, limit=5)

let me know if we prefer to subclass instead

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach looks good to me, hopefully advanced connection management isn't needed most of the time + it's good to hide the LakebasePool details

def __init__(
    self,
    *,
    instance_name: str,
    workspace_client: Optional[WorkspaceClient] = None,
    **pool_kwargs: Any,
) -> None:
    """Configure the store without opening any database connections.

    Args:
        instance_name: Name of the Lakebase database instance to connect to.
        workspace_client: Optional ``WorkspaceClient`` forwarded to
            ``LakebasePool`` when the pool is created.
        **pool_kwargs: Additional keyword arguments forwarded to
            ``LakebasePool``.

    Raises:
        ImportError: If the optional memory dependencies
            (``databricks-langchain[memory]``) are not installed.
    """
    if not _store_imports_available:
        raise ImportError(
            "DatabricksStore requires databricks-langchain[memory]. "
            "Install with: pip install 'databricks-langchain[memory]'"
        )

    # Store initialization parameters for lazy initialization, otherwise
    # if we directly initialize pool during deployment it will fail
    self._instance_name = instance_name
    self._workspace_client = workspace_client
    self._pool_kwargs = pool_kwargs
    self._lakebase: Optional[LakebasePool] = None
    self._pool = None

def _ensure_initialized(self) -> None:
    """Create the LakebasePool on first use, after deployment is ready."""
    if self._lakebase is not None:
        return
    self._lakebase = LakebasePool(
        instance_name=self._instance_name,
        workspace_client=self._workspace_client,
        **self._pool_kwargs,
    )
    self._pool = self._lakebase.pool

def _with_store(self, fn, *args, **kwargs):
    """Run ``fn`` against a short-lived PostgresStore on a pooled connection.

    Borrows a connection from the pool, wraps it in a PostgresStore, invokes
    ``fn(store, *args, **kwargs)``, and returns the connection to the pool
    when the ``with`` block exits.
    """
    self._ensure_initialized()
    with self._pool.connection() as conn:
        return fn(PostgresStore(conn=conn), *args, **kwargs)

def setup(self) -> None:
    """Set up necessary persistent storage (delegates to PostgresStore.setup())."""

    def _run_setup(store):
        return store.setup()

    return self._with_store(_run_setup)

def batch(self, ops: Iterable[Op]) -> list[Result]:
    """Execute a batch of operations synchronously.

    This is the core abstract method of BaseStore; the convenience
    operations inherited from BaseStore (get, put, search, delete,
    list_namespaces) all funnel into this batch() call.
    """

    def _run_batch(store):
        return store.batch(ops)

    return self._with_store(_run_batch)

async def abatch(self, ops: Iterable[Op]) -> list[Result]:
    """Execute a batch of operations asynchronously.

    This is the second abstract method required by BaseStore.
    Currently delegates to sync batch() - for true async support,
    would need async-compatible connection pooling.
    """
    # NOTE(review): this runs the synchronous pool round-trip on the event
    # loop thread, blocking it for the duration of the batch; acceptable as
    # a stopgap until an async-capable pool is available.
    return self.batch(ops)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we are going to provide standard functions like _normalize_ns_label and _user_namespace, would it make sense to put them here as well? just to reduce the total LoC in the example

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also a comment on the example: can we copy the langgraph example in using output_to_responses_items_stream https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent?language=LangGraph#responsesagent-examples

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes to #1 for refactoring those functions related to the key-value storage -- for #2, I believe we're already doing this in the long-term example here - let me know if this is what you were thinking: https://github.com/databricks-eng/universe/pull/1443034/files?w=1#diff-fbb7ade56e8f48e42bc41e5a4672a52e301579fd4f7dc2f7cd0f1e5a30a0babbR463

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good for #1 -- can we add some unit tests for the behavior we expect?

75 changes: 75 additions & 0 deletions integrations/langchain/tests/unit_tests/test_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

from unittest.mock import MagicMock

import pytest

# Skip this whole module when the optional "memory" dependencies are not
# installed; DatabricksStore raises ImportError without them.
pytest.importorskip("psycopg")
pytest.importorskip("psycopg_pool")
pytest.importorskip("langgraph.checkpoint.postgres")

# Imported after the skip guards so collection doesn't fail on missing extras.
from databricks_ai_bridge import lakebase

from databricks_langchain import DatabricksStore


class TestConnectionPool:
    """Minimal stand-in for ``psycopg_pool.ConnectionPool`` used by the tests.

    Instances double as the pool "class": calling the instance records the
    ``conninfo`` it was constructed with and returns the instance itself, and
    ``connection()`` yields a fixed connection object via a context manager.
    """

    # Prevent pytest from collecting this helper: its name starts with "Test"
    # but it has an __init__, which would otherwise trigger a
    # PytestCollectionWarning during collection.
    __test__ = False

    def __init__(self, connection_value="conn"):
        # Object handed out when a connection is "borrowed" from the pool.
        self.connection_value = connection_value
        # Captured by __call__ so tests can assert on the DSN that was built.
        self.conninfo = ""

    def __call__(
        self,
        *,
        conninfo,
        connection_class=None,
        **kwargs,
    ):
        """Stand in for the ConnectionPool constructor; return this instance."""
        self.conninfo = conninfo
        return self

    def connection(self):
        """Return a context manager yielding the configured connection value."""

        class _Ctx:
            def __init__(self, outer):
                self.outer = outer

            def __enter__(self):
                return self.outer.connection_value

            def __exit__(self, exc_type, exc, tb):
                pass

        return _Ctx(self)


def test_databricks_store_configures_lakebase(monkeypatch):
    """DatabricksStore should lazily build a LakebasePool with the expected DSN."""
    fake_conn = MagicMock()
    pool_stub = TestConnectionPool(connection_value=fake_conn)
    monkeypatch.setattr(lakebase, "ConnectionPool", pool_stub)

    from langgraph.store.postgres import PostgresStore

    monkeypatch.setattr(PostgresStore, "setup", MagicMock())

    # Workspace stub: credential/DNS lookups succeed, service-principal lookup
    # fails so the user-name path is exercised.
    ws = MagicMock()
    ws.database.generate_database_credential.return_value = MagicMock(token="stub-token")
    ws.database.get_database_instance.return_value.read_write_dns = "db-host"
    ws.current_service_principal.me.side_effect = RuntimeError("no sp")
    ws.current_user.me.return_value = MagicMock(user_name="[email protected]")

    store = DatabricksStore(
        instance_name="lakebase-instance",
        workspace_client=ws,
    )

    # First operation triggers lazy pool creation.
    store.setup()

    expected_dsn = (
        "dbname=databricks_postgres [email protected] host=db-host port=5432 sslmode=require"
    )
    assert pool_stub.conninfo == expected_dsn
    assert isinstance(store, DatabricksStore)
    assert store._lakebase.pool == pool_stub

    with store._lakebase.connection() as conn:
        assert conn == fake_conn