-
Notifications
You must be signed in to change notification settings - Fork 38
DatabricksStore: PostgresStore Wrapper SDK #227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9951257
fff1c1e
bb45b1c
bff5c76
272c62d
888a924
2dfabef
5215733
0ea050c
848a7cf
ec3444b
4d1c74c
b036124
7872846
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import hashlib | ||
| import re | ||
| from typing import Any, Iterable, Optional | ||
|
|
||
| from databricks.sdk import WorkspaceClient | ||
|
|
||
# Optional dependencies: long-term memory support requires the "memory"
# extra. Fall back to placeholder bindings so this module still imports
# cleanly and DatabricksStore can raise a helpful error at construction time.
try:
    from databricks_ai_bridge.lakebase import LakebasePool
    from langgraph.store.base import BaseStore, Item, Op, Result
    from langgraph.store.postgres import PostgresStore

    _store_imports_available = True
except ImportError:
    # Placeholders keep the class definition and type annotations valid
    # even when the optional packages are missing.
    LakebasePool = object
    PostgresStore = object
    BaseStore = object
    Item = object
    Op = object
    Result = object
    _store_imports_available = False
|
|
||
|
|
||
class DatabricksStore(BaseStore):
    """Provides APIs for working with long-term memory on Databricks using Lakebase.

    Implements the LangGraph BaseStore interface on top of Databricks
    Lakebase connection pooling. Each operation checks a connection out of
    the pool, wraps it in a short-lived PostgresStore, runs against that
    store, and then hands the connection back to the pool.
    """

    @staticmethod
    def normalize_namespace_label(s: Optional[str]) -> str:
        """Normalize a string so it can safely be used as a namespace label.

        Lowercases the input, rewrites ``@`` as ``-at-``, collapses runs of
        characters outside ``[a-z0-9_-]`` into single dashes, and truncates
        over-long results with a hash suffix.

        Args:
            s: The string to normalize (e.g., email address, user_id)

        Returns:
            Normalized string safe for namespace usage

        Example:
            >>> normalize_namespace_label("[email protected]")
            'user-at-example-com'
            >>> normalize_namespace_label("")
            'anon'
        """
        SAFE_NS_MAX = 64

        if not s:
            return "anon"

        label = s.strip().lower().replace("@", "-at-")
        # Dots and other punctuation become dashes; repeated dashes collapse.
        label = re.sub(r"[^a-z0-9_-]+", "-", label)
        label = re.sub(r"-{2,}", "-", label).strip("-") or "anon"

        if len(label) > SAFE_NS_MAX:
            # Truncate, then append a 16-hex-char SHA-256 digest so distinct
            # long inputs stay distinct; total length is exactly SAFE_NS_MAX.
            digest = hashlib.sha256(label.encode()).hexdigest()[:16]
            label = f"{label[: SAFE_NS_MAX - 17]}-{digest}"
        return label

    @staticmethod
    def namespace(identifier: str, prefix: str = "users") -> tuple[str, ...]:
        """Build a namespace tuple from a normalized identifier.

        Args:
            identifier: The identifier to normalize (e.g., user_id, email, entity_name)
            prefix: The namespace prefix (default: "users")

        Returns:
            Tuple of (prefix, normalized_identifier) for use as namespace

        Example:
            >>> namespace("[email protected]")
            ('users', 'email-at-databricks-com')
            >>> namespace("session-123", prefix="sessions")
            ('sessions', 'session-123')
        """
        label = DatabricksStore.normalize_namespace_label(identifier)
        return (prefix, label)

    def __init__(
        self,
        *,
        instance_name: str,
        workspace_client: Optional[WorkspaceClient] = None,
        **pool_kwargs: Any,
    ) -> None:
        """Create a store backed by the named Lakebase database instance.

        Args:
            instance_name: Name of the Lakebase database instance.
            workspace_client: Optional pre-configured WorkspaceClient;
                when omitted, LakebasePool resolves credentials itself.
            **pool_kwargs: Extra keyword arguments forwarded to LakebasePool.

        Raises:
            ImportError: If the optional memory dependencies are not installed.
        """
        if not _store_imports_available:
            raise ImportError(
                "DatabricksStore requires databricks-langchain[memory]. "
                "Install with: pip install 'databricks-langchain[memory]'"
            )

        self._lakebase: LakebasePool = LakebasePool(
            instance_name=instance_name,
            workspace_client=workspace_client,
            **pool_kwargs,
        )

    def _with_store(self, fn, *args, **kwargs):
        """Run ``fn`` against a PostgresStore bound to a pooled connection.

        Checks a connection out of the Lakebase pool, builds a short-lived
        PostgresStore around it, invokes ``fn(store, *args, **kwargs)``, and
        returns the connection to the pool when the context exits.
        """
        with self._lakebase.connection() as conn:
            return fn(PostgresStore(conn=conn), *args, **kwargs)

    def setup(self) -> None:
        """Instantiate the store, setting up necessary persistent storage."""
        return self._with_store(lambda store: store.setup())

    def batch(self, ops: Iterable[Op]) -> list[Result]:
        """Execute a batch of operations synchronously.

        This is the core method required by BaseStore: the inherited
        convenience operations (get, put, search, delete, list_namespaces)
        all funnel through this batch() call.
        """
        return self._with_store(lambda store: store.batch(ops))

    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
        """Execute a batch of operations asynchronously.

        This is the second abstract method required by BaseStore.
        NOTE: it currently delegates to the synchronous batch(); true async
        support would need an async-compatible connection pool.
        """
        return self.batch(ops)
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we are going to provide standard functions like
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also, a comment on the example: can we copy the LangGraph example in, using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes to #1 for refactoring those functions related to the key-value storage -- for #2, I believe we're already doing this in the long-term example here - let me know if this is what you were thinking: https://github.com/databricks-eng/universe/pull/1443034/files?w=1#diff-fbb7ade56e8f48e42bc41e5a4672a52e301579fd4f7dc2f7cd0f1e5a30a0babbR463
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds good for #1 -- can we add some unit tests for the behavior we expect? |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,163 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from unittest.mock import MagicMock | ||
|
|
||
| import pytest | ||
|
|
||
# Skip this entire test module when the optional Postgres/LangGraph stack
# (the "memory" extra) is not installed in the environment.
pytest.importorskip("psycopg")
pytest.importorskip("psycopg_pool")
pytest.importorskip("langgraph.checkpoint.postgres")
|
|
||
| from databricks_ai_bridge import lakebase | ||
|
|
||
| from databricks_langchain import DatabricksStore | ||
|
|
||
|
|
||
class TestConnectionPool:
    """Stand-in for ``psycopg_pool.ConnectionPool`` used by these tests.

    Records the conninfo string it was "constructed" with and hands out a
    fixed connection object through a context manager — the small slice of
    the ConnectionPool API that LakebasePool exercises.
    """

    # Prevent pytest from attempting to collect this helper: its name starts
    # with "Test" but it has an __init__, which otherwise triggers a
    # "cannot collect test class" warning during collection.
    __test__ = False

    def __init__(self, connection_value="conn"):
        # Object yielded by connection(); tests typically pass a MagicMock.
        self.connection_value = connection_value
        # conninfo captured from the most recent __call__.
        self.conninfo = ""

    def __call__(
        self,
        *,
        conninfo,
        connection_class=None,
        **kwargs,
    ):
        """Mimic the ConnectionPool constructor: capture conninfo, return self."""
        self.conninfo = conninfo
        return self

    def connection(self):
        """Return a context manager yielding the configured connection value."""

        class _Ctx:
            def __init__(self, outer):
                self.outer = outer

            def __enter__(self):
                return self.outer.connection_value

            def __exit__(self, exc_type, exc, tb):
                pass

        return _Ctx(self)
|
|
||
|
|
||
def test_databricks_store_configures_lakebase(monkeypatch):
    """DatabricksStore should wire workspace credentials into the pool conninfo."""
    fake_conn = MagicMock()
    pool_stub = TestConnectionPool(connection_value=fake_conn)
    monkeypatch.setattr(lakebase, "ConnectionPool", pool_stub)

    from langgraph.store.postgres import PostgresStore

    monkeypatch.setattr(PostgresStore, "setup", MagicMock())

    # Workspace stub: no service principal, a human user, a stubbed
    # database credential, and an instance whose read/write DNS is "db-host".
    ws = MagicMock()
    ws.database.generate_database_credential.return_value = MagicMock(token="stub-token")
    ws.database.get_database_instance.return_value.read_write_dns = "db-host"
    ws.current_service_principal.me.side_effect = RuntimeError("no sp")
    ws.current_user.me.return_value = MagicMock(user_name="[email protected]")

    store = DatabricksStore(
        instance_name="lakebase-instance",
        workspace_client=ws,
    )

    expected_conninfo = (
        "dbname=databricks_postgres [email protected] host=db-host port=5432 sslmode=require"
    )
    assert pool_stub.conninfo == expected_conninfo
    assert isinstance(store, DatabricksStore)

    # Connections borrowed from the store come straight from the stub pool.
    with store._lakebase.connection() as conn:
        assert conn == fake_conn
|
|
||
|
|
||
class TestDatabricksStoreNamespace:
    """Unit tests for the DatabricksStore.namespace() static method."""

    def test_namespace_with_email(self):
        """A typical email address is normalized into the label."""
        ns = DatabricksStore.namespace("[email protected]")
        assert ns == ("users", "first-last-at-databricks-com")

    def test_namespace_with_uppercase(self):
        """Uppercase input is lowercased."""
        ns = DatabricksStore.namespace("[email protected]")
        assert ns == ("users", "first-last-at-databricks-com")

    def test_namespace_with_empty_identifier(self):
        """An empty identifier falls back to 'anon'."""
        ns = DatabricksStore.namespace("")
        assert ns == ("users", "anon")

    def test_namespace_with_whitespace_only(self):
        """A whitespace-only identifier falls back to 'anon'."""
        ns = DatabricksStore.namespace(" ")
        assert ns == ("users", "anon")

    def test_namespace_with_custom_prefix(self):
        """A caller-supplied prefix replaces the default 'users'."""
        ns = DatabricksStore.namespace("user123", prefix="agents")
        assert ns == ("agents", "user123")

    def test_namespace_with_special_characters(self):
        """Special characters collapse into single dashes."""
        ns = DatabricksStore.namespace("user!name@test#site.com")
        assert ns == ("users", "user-name-at-test-site-com")

    def test_namespace_with_leading_trailing_special_chars(self):
        """Dashes produced at either end of the label are stripped."""
        ns = DatabricksStore.namespace("[email protected]!!!")
        assert ns == ("users", "user-at-test-com")

    def test_namespace_with_underscores_and_hyphens(self):
        """Underscores and hyphens survive normalization untouched."""
        ns = DatabricksStore.namespace("user_name-123")
        assert ns == ("users", "user_name-123")

    def test_namespace_with_numbers(self):
        """Digits survive normalization untouched."""
        ns = DatabricksStore.namespace("[email protected]")
        assert ns == ("users", "user123-at-test456-com")

    def test_namespace_with_long_identifier(self):
        """Identifiers past the 64-char cap are truncated with a hash suffix."""
        # Build an identifier longer than 64 characters.
        long_identifier = "a" * 70 + "@example.com"
        ns = DatabricksStore.namespace(long_identifier)

        # The normalized label is capped at exactly 64 characters.
        assert len(ns[1]) == 64

        # It keeps a truncated head of the original plus a dashed hash tail.
        assert ns[1].startswith("a" * 47)
        assert "-" in ns[1]

        # And it is still a well-formed two-element namespace tuple.
        assert ns[0] == "users"
        assert isinstance(ns, tuple)
        assert len(ns) == 2

    def test_namespace_with_at_and_special_characters(self):
        """An identifier made only of @ and punctuation still normalizes."""
        ns = DatabricksStore.namespace("@@@###$$$")
        assert ns == ("users", "at-at-at")

    def test_namespace_with_mixed_valid_invalid_chars(self):
        """A mix of valid and invalid characters normalizes cleanly."""
        ns = DatabricksStore.namespace("test$user%123@site&domain.com")
        assert ns == ("users", "test-user-123-at-site-domain-com")

    def test_namespace_with_unicode_characters(self):
        """Non-ASCII characters are removed or replaced."""
        ns = DatabricksStore.namespace("user\[email protected]")  # user with é
        assert ns[0] == "users"
        assert "at-test-com" in ns[1]

    def test_namespace_returns_tuple(self):
        """namespace() always yields a (str, str) tuple."""
        ns = DatabricksStore.namespace("[email protected]")
        assert isinstance(ns, tuple)
        assert len(ns) == 2
        assert isinstance(ns[0], str)
        assert isinstance(ns[1], str)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note: instead of directly subclassing, I wrapped PostgresStore instead to simplify the agent authoring CUJ using the store/handle connections for the user
Previously:
if we subclassed it would look like:
current
let me know if we prefer to subclass instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This approach looks good to me, hopefully advanced connection management isn't needed most of the time + it's good to hide the LakebasePool details