Skip to content
Open
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ jobs:
run: |
pytest tests/databricks_ai_bridge/test_lakebase.py
pytest integrations/langchain/tests/unit_tests/test_checkpoint.py
pytest integrations/langchain/tests/unit_tests/test_store.py

langchain_cross_version_test:
runs-on: ubuntu-latest
Expand Down
2 changes: 2 additions & 0 deletions integrations/langchain/src/databricks_langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from databricks_langchain.checkpoint import CheckpointSaver
from databricks_langchain.embeddings import DatabricksEmbeddings
from databricks_langchain.genie import GenieAgent
from databricks_langchain.store import DatabricksStore
from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool
from databricks_langchain.vectorstores import DatabricksVectorSearch

Expand All @@ -29,6 +30,7 @@
"ChatDatabricks",
"CheckpointSaver",
"DatabricksEmbeddings",
"DatabricksStore",
"DatabricksVectorSearch",
"GenieAgent",
"VectorSearchRetrieverTool",
Expand Down
125 changes: 125 additions & 0 deletions integrations/langchain/src/databricks_langchain/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations

import asyncio
import hashlib
import re
from typing import Any, Iterable, Optional

from databricks.sdk import WorkspaceClient

# Long-term memory support is an optional extra
# (databricks-langchain[memory]). When any of these packages is missing,
# bind `object` placeholders so this module still imports cleanly;
# DatabricksStore.__init__ raises a helpful ImportError instead.
try:
    from databricks_ai_bridge.lakebase import LakebasePool
    from langgraph.store.base import BaseStore, Item, Op, Result
    from langgraph.store.postgres import PostgresStore

    _store_imports_available = True
except ImportError:
    # Placeholders keep every referenced name (including the BaseStore
    # base class below) resolvable without the optional dependencies.
    LakebasePool = PostgresStore = BaseStore = Item = Op = Result = object
    _store_imports_available = False


class DatabricksStore(BaseStore):
    """Provides APIs for working with long-term memory on Databricks using Lakebase.

    Extends the LangGraph ``BaseStore`` interface, using a Databricks Lakebase
    connection pool for Postgres connectivity.

    Each operation borrows a connection from the pool, creates a short-lived
    ``PostgresStore`` bound to that connection, executes the operation, and
    returns the connection to the pool.
    """

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: instead of directly subclassing, I wrapped PostgresStore to simplify the agent-authoring CUJ — the store manages connections on behalf of the user

Previously:

with self.pool.connection() as conn:
        store = PostgresStore(conn=conn)
        results = store.search(namespace, query=query, limit=5)

if we subclassed it would look like:

with self.pool.connection() as conn:
        store = DatabricksStore(conn=conn)
        results = store.search(namespace, query=query, limit=5)

current

results = self.store.search(namespace, query=query, limit=5)

let me know if we prefer to subclass instead

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach looks good to me, hopefully advanced connection management isn't needed most of the time + it's good to hide the LakebasePool details

@staticmethod
def normalize_namespace_label(s: Optional[str]) -> str:
"""Normalize a string for use as a namespace label.
Converts to lowercase, replaces @ with -at-, removes invalid characters,
and truncates with hash if too long.
Args:
s: The string to normalize (e.g., email address, user_id)
Returns:
Normalized string safe for namespace usage
Example:
>>> normalize_namespace_label("[email protected]")
'user-at-example-com'
>>> normalize_namespace_label("")
'anon'
"""
SAFE_NS_MAX = 64

if not s:
return "anon"
x = s.strip().lower().replace("@", "-at-")
x = re.sub(r"[^a-z0-9_-]+", "-", x) # removes dots and punctuation
x = re.sub(r"-{2,}", "-", x).strip("-") or "anon"
if len(x) > SAFE_NS_MAX:
head = x[: SAFE_NS_MAX - 17]
tail = hashlib.sha256(x.encode()).hexdigest()[:16]
x = f"{head}-{tail}"
return x

@staticmethod
def namespace(identifier: str, prefix: str = "users") -> tuple[str, ...]:
    """Build a namespace tuple from a normalized identifier.

    Args:
        identifier: The identifier to normalize (e.g. user id, email,
            entity name).
        prefix: The namespace prefix. Defaults to ``"users"``.

    Returns:
        A ``(prefix, normalized_identifier)`` tuple for use as a store
        namespace.

    Example:
        >>> namespace("[email protected]")
        ('users', 'email-at-databricks-com')
        >>> namespace("session-123", prefix="sessions")
        ('sessions', 'session-123')
    """
    label = DatabricksStore.normalize_namespace_label(identifier)
    return (prefix, label)

def __init__(
    self,
    *,
    instance_name: str,
    workspace_client: Optional[WorkspaceClient] = None,
    **pool_kwargs: Any,
) -> None:
    """Create a store backed by a Lakebase connection pool.

    Args:
        instance_name: Name of the Lakebase database instance.
        workspace_client: Optional ``WorkspaceClient`` forwarded to
            ``LakebasePool``.
        **pool_kwargs: Extra keyword arguments forwarded to ``LakebasePool``.

    Raises:
        ImportError: If the optional memory dependencies are not installed.
    """
    # Fail fast with install guidance when the [memory] extra is missing.
    if not _store_imports_available:
        raise ImportError(
            "DatabricksStore requires databricks-langchain[memory]. "
            "Install with: pip install 'databricks-langchain[memory]'"
        )

    pool: LakebasePool = LakebasePool(
        instance_name=instance_name,
        workspace_client=workspace_client,
        **pool_kwargs,
    )
    self._lakebase = pool

def _with_store(self, fn, *args, **kwargs):
    """Run *fn* against a short-lived ``PostgresStore``.

    Borrows a connection from the Lakebase pool, wraps it in a
    ``PostgresStore``, invokes ``fn(store, *args, **kwargs)``, and returns
    the connection to the pool when the context exits.
    """
    with self._lakebase.connection() as conn:
        return fn(PostgresStore(conn=conn), *args, **kwargs)

def setup(self) -> None:
    """Instantiate the store, setting up necessary persistent storage."""
    # Delegate to the underlying PostgresStore's setup on a borrowed connection.
    def _run_setup(store):
        return store.setup()

    return self._with_store(_run_setup)

def batch(self, ops: Iterable[Op]) -> list[Result]:
    """Execute a batch of operations synchronously.

    This is the core method required by ``BaseStore``; the derived
    operations (get, put, search, delete, list_namespaces) inherited from
    ``BaseStore`` all funnel through it.
    """
    def _run_batch(store):
        return store.batch(ops)

    return self._with_store(_run_batch)

async def abatch(self, ops: Iterable[Op]) -> list[Result]:
    """Execute a batch of operations asynchronously.

    This is the second abstract method required by ``BaseStore``. The
    underlying Lakebase pool is synchronous, so the sync ``batch()`` is
    dispatched to a worker thread via ``asyncio.to_thread`` instead of
    being called inline — calling it directly would block the event loop
    for the whole database round-trip. True native-async support would
    require an async-compatible connection pool.
    """
    # NOTE(review): assumes the pool's connection() is safe to use from a
    # worker thread (psycopg_pool's ConnectionPool is) — confirm for
    # LakebasePool before relying on concurrent abatch calls.
    return await asyncio.to_thread(self.batch, ops)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we are going to provide standard functions like _normalize_ns_label and _user_namespace, would it make sense to put them here as well? just to reduce the total LoC in the example

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also a comment on the example: can we copy the langgraph example in using output_to_responses_items_stream https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent?language=LangGraph#responsesagent-examples

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes to #1 for refactoring those functions related to the key-value storage -- for #2, I believe we're already doing this in the long-term example here - let me know if this is what you were thinking: https://github.com/databricks-eng/universe/pull/1443034/files?w=1#diff-fbb7ade56e8f48e42bc41e5a4672a52e301579fd4f7dc2f7cd0f1e5a30a0babbR463

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good for #1 -- can we add some unit tests for the behavior we expect?

163 changes: 163 additions & 0 deletions integrations/langchain/tests/unit_tests/test_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from __future__ import annotations

from unittest.mock import MagicMock

import pytest

pytest.importorskip("psycopg")
pytest.importorskip("psycopg_pool")
pytest.importorskip("langgraph.checkpoint.postgres")

from databricks_ai_bridge import lakebase

from databricks_langchain import DatabricksStore


class TestConnectionPool:
    """Minimal stand-in for a psycopg ``ConnectionPool`` used in tests.

    Records the ``conninfo`` it is "constructed" with and hands out a fixed
    connection object through a context manager, mimicking the small slice
    of the real pool API exercised here.
    """

    def __init__(self, connection_value="conn"):
        # The object yielded by connection(); tests may inject a MagicMock.
        self.connection_value = connection_value
        self.conninfo = ""

    def __call__(self, *, conninfo, connection_class=None, **kwargs):
        # Acts as the pool class: capture the conninfo that was built and
        # return this same instance so assertions can inspect it later.
        self.conninfo = conninfo
        return self

    def connection(self):
        outer = self

        class _ConnectionContext:
            def __enter__(self):
                return outer.connection_value

            def __exit__(self, exc_type, exc, tb):
                return None

        return _ConnectionContext()


def test_databricks_store_configures_lakebase(monkeypatch):
    """DatabricksStore builds its Lakebase pool with the expected conninfo."""
    stub_conn = MagicMock()
    fake_pool = TestConnectionPool(connection_value=stub_conn)
    monkeypatch.setattr(lakebase, "ConnectionPool", fake_pool)

    from langgraph.store.postgres import PostgresStore

    monkeypatch.setattr(PostgresStore, "setup", MagicMock())

    workspace = MagicMock()
    workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token")
    workspace.database.get_database_instance.return_value.read_write_dns = "db-host"
    # No service principal available -> the interactive user is used instead.
    workspace.current_service_principal.me.side_effect = RuntimeError("no sp")
    workspace.current_user.me.return_value = MagicMock(user_name="[email protected]")

    store = DatabricksStore(
        instance_name="lakebase-instance",
        workspace_client=workspace,
    )

    expected_conninfo = (
        "dbname=databricks_postgres [email protected] "
        "host=db-host port=5432 sslmode=require"
    )
    assert fake_pool.conninfo == expected_conninfo
    assert isinstance(store, DatabricksStore)

    # The pool the store holds must hand back our stubbed connection.
    with store._lakebase.connection() as conn:
        assert conn is stub_conn


class TestDatabricksStoreNamespace:
    """Unit tests for the DatabricksStore.namespace() static method."""

    def test_namespace_with_email(self):
        """Test namespace normalization with a typical email address."""
        result = DatabricksStore.namespace("[email protected]")
        assert result == ("users", "first-last-at-databricks-com")

    def test_namespace_with_uppercase(self):
        """Test that uppercase letters are converted to lowercase."""
        # NOTE(review): input must actually contain uppercase letters to
        # exercise case folding — the previous input was already lowercase.
        result = DatabricksStore.namespace("[email protected]")
        assert result == ("users", "first-last-at-databricks-com")

    def test_namespace_with_empty_identifier(self):
        """Test that empty identifier returns 'anon'."""
        result = DatabricksStore.namespace("")
        assert result == ("users", "anon")

    def test_namespace_with_whitespace_only(self):
        """Test that whitespace-only identifier returns 'anon'."""
        result = DatabricksStore.namespace("   ")
        assert result == ("users", "anon")

    def test_namespace_with_custom_prefix(self):
        """Test namespace with a custom prefix."""
        result = DatabricksStore.namespace("user123", prefix="agents")
        assert result == ("agents", "user123")

    def test_namespace_with_special_characters(self):
        """Test that special characters are replaced with dashes."""
        result = DatabricksStore.namespace("user!name@test#site.com")
        assert result == ("users", "user-name-at-test-site-com")

    def test_namespace_with_leading_trailing_special_chars(self):
        """Test that leading/trailing dashes are stripped."""
        # Leading dots and trailing bangs both normalize to dashes, which
        # are then stripped from the ends of the label.
        result = DatabricksStore.namespace("[email protected]!!!")
        assert result == ("users", "user-at-test-com")

    def test_namespace_with_underscores_and_hyphens(self):
        """Test that underscores and hyphens are preserved."""
        result = DatabricksStore.namespace("user_name-123")
        assert result == ("users", "user_name-123")

    def test_namespace_with_numbers(self):
        """Test that numbers are preserved."""
        result = DatabricksStore.namespace("[email protected]")
        assert result == ("users", "user123-at-test456-com")

    def test_namespace_with_long_identifier(self):
        """Test that long identifiers are truncated with hash suffix."""
        # Create an identifier longer than 64 characters.
        long_identifier = "a" * 70 + "@example.com"
        result = DatabricksStore.namespace(long_identifier)

        # Should be exactly 64 characters: 47-char head + "-" + 16-char hash.
        assert len(result[1]) == 64

        # Should start with the truncated original and end with a hash.
        assert result[1].startswith("a" * 47)
        assert "-" in result[1]

        # Verify it's a valid namespace tuple.
        assert result[0] == "users"
        assert isinstance(result, tuple)
        assert len(result) == 2

    def test_namespace_with_at_and_special_characters(self):
        """Test identifier with at and special characters."""
        result = DatabricksStore.namespace("@@@###$$$")
        assert result == ("users", "at-at-at")

    def test_namespace_with_mixed_valid_invalid_chars(self):
        """Test identifier with mix of valid and invalid characters."""
        result = DatabricksStore.namespace("test$user%123@site&domain.com")
        assert result == ("users", "test-user-123-at-site-domain-com")

    def test_namespace_with_unicode_characters(self):
        """Test that unicode characters are removed or replaced."""
        # NOTE(review): previous literal was a malformed "\x" escape
        # (SyntaxError); "\xe9" is é, which normalizes to a dash.
        result = DatabricksStore.namespace("user\xe9[email protected]")
        assert result[0] == "users"
        assert "at-test-com" in result[1]

    def test_namespace_returns_tuple(self):
        """Test that namespace always returns a tuple."""
        result = DatabricksStore.namespace("[email protected]")
        assert isinstance(result, tuple)
        assert len(result) == 2
        assert isinstance(result[0], str)
        assert isinstance(result[1], str)