Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,28 @@ def __init__(
@classmethod
@contextmanager
def from_conn_string(
cls, conn_string: str, *, pipeline: bool = False
cls,
conn_string: str,
*,
pipeline: bool = False,
prepare_threshold: int | None = 0,
) -> Iterator[PostgresSaver]:
"""Create a new PostgresSaver instance from a connection string.

Args:
conn_string: The Postgres connection info string.
pipeline: whether to use Pipeline
prepare_threshold: Threshold for prepared statements. Set to None to disable
prepared statements (required for external connection poolers).

Returns:
PostgresSaver: A new PostgresSaver instance.
"""
with Connection.connect(
conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row
conn_string,
autocommit=True,
prepare_threshold=prepare_threshold,
row_factory=dict_row,
) as conn:
if pipeline:
with conn.pipeline() as pipe:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,25 @@ async def from_conn_string(
*,
pipeline: bool = False,
serde: SerializerProtocol | None = None,
prepare_threshold: int | None = 0,
) -> AsyncIterator[AsyncPostgresSaver]:
"""Create a new AsyncPostgresSaver instance from a connection string.

Args:
conn_string: The Postgres connection info string.
pipeline: whether to use AsyncPipeline
serde: Custom serializer for checkpoint data.
prepare_threshold: Threshold for prepared statements. Set to None to disable
prepared statements (required for external connection poolers).

Returns:
AsyncPostgresSaver: A new AsyncPostgresSaver instance.
"""
async with await AsyncConnection.connect(
conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row
conn_string,
autocommit=True,
prepare_threshold=prepare_threshold,
row_factory=dict_row,
) as conn:
if pipeline:
async with conn.pipeline() as pipe:
Expand Down
22 changes: 19 additions & 3 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,19 +209,28 @@ def __init__(
@classmethod
@contextmanager
def from_conn_string(
cls, conn_string: str, *, pipeline: bool = False
cls,
conn_string: str,
*,
pipeline: bool = False,
prepare_threshold: int | None = 0,
) -> Iterator["ShallowPostgresSaver"]:
"""Create a new ShallowPostgresSaver instance from a connection string.

Args:
conn_string: The Postgres connection info string.
pipeline: whether to use Pipeline
prepare_threshold: Threshold for prepared statements. Set to None to disable
prepared statements (required for external connection poolers).

Returns:
ShallowPostgresSaver: A new ShallowPostgresSaver instance.
"""
with Connection.connect(
conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row
conn_string,
autocommit=True,
prepare_threshold=prepare_threshold,
row_factory=dict_row,
) as conn:
if pipeline:
with conn.pipeline() as pipe:
Expand Down Expand Up @@ -574,18 +583,25 @@ async def from_conn_string(
*,
pipeline: bool = False,
serde: SerializerProtocol | None = None,
prepare_threshold: int | None = 0,
) -> AsyncIterator["AsyncShallowPostgresSaver"]:
"""Create a new AsyncShallowPostgresSaver instance from a connection string.

Args:
conn_string: The Postgres connection info string.
pipeline: whether to use AsyncPipeline
serde: Custom serializer for checkpoint data.
prepare_threshold: Threshold for prepared statements. Set to None to disable
prepared statements (required for external connection poolers).

Returns:
AsyncShallowPostgresSaver: A new AsyncShallowPostgresSaver instance.
"""
async with await AsyncConnection.connect(
conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row
conn_string,
autocommit=True,
prepare_threshold=prepare_threshold,
row_factory=dict_row,
) as conn:
if pipeline:
async with conn.pipeline() as pipe:
Expand Down
116 changes: 116 additions & 0 deletions libs/checkpoint-postgres/tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,119 @@ def patched_load_checkpoint_tuple(value):

checkpoint = await saver.aget_tuple(config)
assert checkpoint.checkpoint["channel_values"] == {}


@pytest.mark.asyncio
async def test_prepare_threshold_default():
    """Round-trip a checkpoint through AsyncPostgresSaver with the default prepare_threshold (0)."""
    db_name = f"test_{uuid4().hex[:16]}"
    # provision an isolated throwaway database for this test
    async with await AsyncConnection.connect(
        DEFAULT_POSTGRES_URI, autocommit=True
    ) as admin:
        await admin.execute(f"CREATE DATABASE {db_name}")
    try:
        async with AsyncPostgresSaver.from_conn_string(
            DEFAULT_POSTGRES_URI + db_name
        ) as checkpointer:
            await checkpointer.setup()
            config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
            await checkpointer.aput(
                config,
                create_checkpoint(empty_checkpoint(), {}, 1),
                {"source": "test", "step": 1},
                {},
            )
            loaded = await checkpointer.aget_tuple(config)
            assert loaded is not None
            assert loaded.metadata["source"] == "test"
    finally:
        # tear the throwaway database back down
        async with await AsyncConnection.connect(
            DEFAULT_POSTGRES_URI, autocommit=True
        ) as admin:
            await admin.execute(f"DROP DATABASE {db_name}")


@pytest.mark.asyncio
async def test_prepare_threshold_none():
    """Round-trip a checkpoint with prepare_threshold=None (pooler-friendly mode)."""
    db_name = f"test_{uuid4().hex[:16]}"
    # provision an isolated throwaway database for this test
    async with await AsyncConnection.connect(
        DEFAULT_POSTGRES_URI, autocommit=True
    ) as admin:
        await admin.execute(f"CREATE DATABASE {db_name}")
    try:
        async with AsyncPostgresSaver.from_conn_string(
            DEFAULT_POSTGRES_URI + db_name, prepare_threshold=None
        ) as checkpointer:
            await checkpointer.setup()
            config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
            await checkpointer.aput(
                config,
                create_checkpoint(empty_checkpoint(), {}, 1),
                {"source": "test", "step": 1},
                {},
            )
            loaded = await checkpointer.aget_tuple(config)
            assert loaded is not None
            assert loaded.metadata["source"] == "test"
    finally:
        # tear the throwaway database back down
        async with await AsyncConnection.connect(
            DEFAULT_POSTGRES_URI, autocommit=True
        ) as admin:
            await admin.execute(f"DROP DATABASE {db_name}")


@pytest.mark.asyncio
async def test_prepare_threshold_custom_value():
    """Round-trip a checkpoint with a non-default prepare_threshold value."""
    db_name = f"test_{uuid4().hex[:16]}"
    # provision an isolated throwaway database for this test
    async with await AsyncConnection.connect(
        DEFAULT_POSTGRES_URI, autocommit=True
    ) as admin:
        await admin.execute(f"CREATE DATABASE {db_name}")
    try:
        async with AsyncPostgresSaver.from_conn_string(
            DEFAULT_POSTGRES_URI + db_name, prepare_threshold=5
        ) as checkpointer:
            await checkpointer.setup()
            config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
            await checkpointer.aput(
                config,
                create_checkpoint(empty_checkpoint(), {}, 1),
                {"source": "test", "step": 1},
                {},
            )
            loaded = await checkpointer.aget_tuple(config)
            assert loaded is not None
            assert loaded.metadata["source"] == "test"
    finally:
        # tear the throwaway database back down
        async with await AsyncConnection.connect(
            DEFAULT_POSTGRES_URI, autocommit=True
        ) as admin:
            await admin.execute(f"DROP DATABASE {db_name}")


@pytest.mark.asyncio
async def test_shallow_prepare_threshold_none():
    """Round-trip a checkpoint through AsyncShallowPostgresSaver with prepare_threshold=None."""
    db_name = f"test_{uuid4().hex[:16]}"
    # provision an isolated throwaway database for this test
    async with await AsyncConnection.connect(
        DEFAULT_POSTGRES_URI, autocommit=True
    ) as admin:
        await admin.execute(f"CREATE DATABASE {db_name}")
    try:
        async with AsyncShallowPostgresSaver.from_conn_string(
            DEFAULT_POSTGRES_URI + db_name, prepare_threshold=None
        ) as checkpointer:
            await checkpointer.setup()
            config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
            await checkpointer.aput(
                config,
                create_checkpoint(empty_checkpoint(), {}, 1),
                {"source": "test", "step": 1},
                {},
            )
            loaded = await checkpointer.aget_tuple(config)
            assert loaded is not None
            assert loaded.metadata["source"] == "test"
    finally:
        # tear the throwaway database back down
        async with await AsyncConnection.connect(
            DEFAULT_POSTGRES_URI, autocommit=True
        ) as admin:
            await admin.execute(f"DROP DATABASE {db_name}")
96 changes: 96 additions & 0 deletions libs/checkpoint-postgres/tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,99 @@ def patched_load_checkpoint_tuple(value):

checkpoint = saver.get_tuple(config)
assert checkpoint.checkpoint["channel_values"] == {}


def test_prepare_threshold_default():
    """Round-trip a checkpoint through PostgresSaver with the default prepare_threshold (0)."""
    db_name = f"test_{uuid4().hex[:16]}"
    # provision an isolated throwaway database for this test
    with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as admin:
        admin.execute(f"CREATE DATABASE {db_name}")
    try:
        with PostgresSaver.from_conn_string(
            DEFAULT_POSTGRES_URI + db_name
        ) as checkpointer:
            checkpointer.setup()
            config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
            checkpointer.put(
                config,
                create_checkpoint(empty_checkpoint(), {}, 1),
                {"source": "test", "step": 1},
                {},
            )
            loaded = checkpointer.get_tuple(config)
            assert loaded is not None
            assert loaded.metadata["source"] == "test"
    finally:
        # tear the throwaway database back down
        with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as admin:
            admin.execute(f"DROP DATABASE {db_name}")


def test_prepare_threshold_none():
    """Round-trip a checkpoint with prepare_threshold=None (pooler-friendly mode)."""
    db_name = f"test_{uuid4().hex[:16]}"
    # provision an isolated throwaway database for this test
    with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as admin:
        admin.execute(f"CREATE DATABASE {db_name}")
    try:
        with PostgresSaver.from_conn_string(
            DEFAULT_POSTGRES_URI + db_name, prepare_threshold=None
        ) as checkpointer:
            checkpointer.setup()
            config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
            checkpointer.put(
                config,
                create_checkpoint(empty_checkpoint(), {}, 1),
                {"source": "test", "step": 1},
                {},
            )
            loaded = checkpointer.get_tuple(config)
            assert loaded is not None
            assert loaded.metadata["source"] == "test"
    finally:
        # tear the throwaway database back down
        with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as admin:
            admin.execute(f"DROP DATABASE {db_name}")


def test_prepare_threshold_custom_value():
    """Round-trip a checkpoint with a non-default prepare_threshold value."""
    db_name = f"test_{uuid4().hex[:16]}"
    # provision an isolated throwaway database for this test
    with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as admin:
        admin.execute(f"CREATE DATABASE {db_name}")
    try:
        with PostgresSaver.from_conn_string(
            DEFAULT_POSTGRES_URI + db_name, prepare_threshold=5
        ) as checkpointer:
            checkpointer.setup()
            config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
            checkpointer.put(
                config,
                create_checkpoint(empty_checkpoint(), {}, 1),
                {"source": "test", "step": 1},
                {},
            )
            loaded = checkpointer.get_tuple(config)
            assert loaded is not None
            assert loaded.metadata["source"] == "test"
    finally:
        # tear the throwaway database back down
        with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as admin:
            admin.execute(f"DROP DATABASE {db_name}")


def test_shallow_prepare_threshold_none():
    """Round-trip a checkpoint through ShallowPostgresSaver with prepare_threshold=None."""
    db_name = f"test_{uuid4().hex[:16]}"
    # provision an isolated throwaway database for this test
    with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as admin:
        admin.execute(f"CREATE DATABASE {db_name}")
    try:
        with ShallowPostgresSaver.from_conn_string(
            DEFAULT_POSTGRES_URI + db_name, prepare_threshold=None
        ) as checkpointer:
            checkpointer.setup()
            config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
            checkpointer.put(
                config,
                create_checkpoint(empty_checkpoint(), {}, 1),
                {"source": "test", "step": 1},
                {},
            )
            loaded = checkpointer.get_tuple(config)
            assert loaded is not None
            assert loaded.metadata["source"] == "test"
    finally:
        # tear the throwaway database back down
        with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as admin:
            admin.execute(f"DROP DATABASE {db_name}")