Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
2a7c2f4
Apply patch [skip ci]
open-swe Aug 12, 2025
409fe9f
Apply patch [skip ci]
open-swe Aug 12, 2025
67d05fc
Apply patch [skip ci]
open-swe Aug 12, 2025
f605f91
Apply patch [skip ci]
open-swe Aug 12, 2025
aba42b2
Apply patch [skip ci]
open-swe Aug 12, 2025
ffb2e89
Apply patch [skip ci]
open-swe Aug 12, 2025
66a3f2e
Apply patch [skip ci]
open-swe Aug 12, 2025
ab61a33
Apply patch [skip ci]
open-swe Aug 12, 2025
380636f
Apply patch [skip ci]
open-swe Aug 12, 2025
a9bf29e
Apply patch [skip ci]
open-swe Aug 12, 2025
f1cc9ba
Apply patch [skip ci]
open-swe Aug 12, 2025
a13c686
Apply patch [skip ci]
open-swe Aug 12, 2025
bcbd047
Apply patch [skip ci]
open-swe Aug 12, 2025
e572722
Apply patch [skip ci]
open-swe Aug 12, 2025
be8656a
Apply patch [skip ci]
open-swe Aug 12, 2025
d4349a3
Apply patch [skip ci]
open-swe Aug 12, 2025
71f2023
Apply patch [skip ci]
open-swe Aug 12, 2025
c487e54
Apply patch [skip ci]
open-swe Aug 12, 2025
7e669a9
Apply patch [skip ci]
open-swe Aug 12, 2025
d6077d2
Apply patch [skip ci]
open-swe Aug 12, 2025
e9d4894
Apply patch [skip ci]
open-swe Aug 12, 2025
54bf2b9
Apply patch [skip ci]
open-swe Aug 12, 2025
a6f0c69
Apply patch [skip ci]
open-swe Aug 12, 2025
3af8a61
Apply patch [skip ci]
open-swe Aug 12, 2025
2167a6b
Apply patch [skip ci]
open-swe Aug 12, 2025
8838282
Apply patch [skip ci]
open-swe Aug 12, 2025
4bd0c1a
Apply patch [skip ci]
open-swe Aug 12, 2025
c6d82fb
Apply patch [skip ci]
open-swe Aug 12, 2025
bd8d186
Apply patch [skip ci]
open-swe Aug 12, 2025
71bcfdd
Empty commit to trigger CI
open-swe Aug 12, 2025
f6fb060
locking
sydney-runkle Aug 12, 2025
86e1bb7
more docs
sydney-runkle Aug 12, 2025
1398d07
docs
sydney-runkle Aug 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
477 changes: 30 additions & 447 deletions README.md

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ authors = [
]
license = { text = "MIT" }
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.11"
dependencies = [
"langgraph>=0.2.6",
"langgraph>=0.6.0,<0.7.0",
"langchain-openai>=0.1.22",
"langchain-anthropic>=0.1.23",
"langchain>=0.2.14",
"langchain-fireworks>=0.1.7",
"python-dotenv>=1.0.1",
"langchain-elasticsearch>=0.2.2,<0.3.0",
"langchain-pinecone>=0.1.3,<0.2.0",
"langchain-elasticsearch>=0.3.0,<0.4.0",
"langchain-pinecone>=0.2.0,<0.3.0",
"msgspec>=0.18.6",
"langchain-mongodb>=0.1.9",
"langchain-cohere>=0.2.4",
Expand Down
56 changes: 33 additions & 23 deletions src/retrieval_graph/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

from __future__ import annotations

from dataclasses import dataclass, field, fields
from typing import Annotated, Any, Literal, Optional, Type, TypeVar

from langchain_core.runnables import RunnableConfig, ensure_config
import os
from dataclasses import dataclass, field
from typing import Annotated, Any, Literal

from retrieval_graph import prompts

Expand All @@ -19,7 +18,9 @@ class IndexConfiguration:
retriever provider choice, and search parameters.
"""

user_id: str = field(metadata={"description": "Unique identifier for the user."})
user_id: str = field(
default="", metadata={"description": "Unique identifier for the user."}
)

embedding_model: Annotated[
str,
Expand Down Expand Up @@ -48,26 +49,19 @@ class IndexConfiguration:
},
)

@classmethod
def from_runnable_config(
cls: Type[T], config: Optional[RunnableConfig] = None
) -> T:
"""Create an IndexConfiguration instance from a RunnableConfig object.

Args:
cls (Type[T]): The class itself.
config (Optional[RunnableConfig]): The configuration object to use.
def __post_init__(self) -> None:
"""Populate fields from environment variables if not already set."""
# Only populate from environment variables if the field is not already set
if not self.user_id:
self.user_id = os.environ.get("USER_ID", "")

Returns:
T: An instance of IndexConfiguration with the specified configuration.
"""
config = ensure_config(config)
configurable = config.get("configurable") or {}
_fields = {f.name for f in fields(cls) if f.init}
return cls(**{k: v for k, v in configurable.items() if k in _fields})
if self.embedding_model == "openai/text-embedding-3-small":
self.embedding_model = os.environ.get(
"EMBEDDING_MODEL", "openai/text-embedding-3-small"
)


T = TypeVar("T", bound=IndexConfiguration)
if self.retriever_provider == "elastic":
self.retriever_provider = os.environ.get("RETRIEVER_PROVIDER", "elastic") # type: ignore


@dataclass(kw_only=True)
Expand Down Expand Up @@ -99,3 +93,19 @@ class Configuration(IndexConfiguration):
"description": "The language model used for processing and refining queries. Should be in the form: provider/model-name."
},
)

def __post_init__(self) -> None:
"""Populate fields from environment variables if not already set."""
# Call parent's __post_init__ first
super().__post_init__()

# Only populate from environment variables if the field is using the default value
if self.response_model == "anthropic/claude-3-5-sonnet-20240620":
self.response_model = os.environ.get(
"RESPONSE_MODEL", "anthropic/claude-3-5-sonnet-20240620"
)

if self.query_model == "anthropic/claude-3-haiku-20240307":
self.query_model = os.environ.get(
"QUERY_MODEL", "anthropic/claude-3-haiku-20240307"
)
32 changes: 15 additions & 17 deletions src/retrieval_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.runtime import Runtime

from retrieval_graph import retrieval
from retrieval_graph.configuration import Configuration
Expand All @@ -31,7 +31,7 @@ class SearchQuery(BaseModel):


async def generate_query(
state: State, *, config: RunnableConfig
state: State, *, runtime: Runtime[Configuration]
) -> dict[str, list[str]]:
"""Generate a search query based on the current state and configuration.

Expand All @@ -41,7 +41,7 @@ async def generate_query(

Args:
state (State): The current state containing messages and other information.
config (RunnableConfig | None, optional): Configuration for the query generation process.
runtime (Runtime[Configuration]): Runtime context containing configuration.

Returns:
dict[str, list[str]]: A dictionary with a 'queries' key containing a list of generated queries.
Expand All @@ -57,7 +57,7 @@ async def generate_query(
human_input = get_message_text(messages[-1])
return {"queries": [human_input]}
else:
configuration = Configuration.from_runnable_config(config)
configuration = runtime.context
# Feel free to customize the prompt, model, and other logic!
prompt = ChatPromptTemplate.from_messages(
[
Expand All @@ -74,17 +74,16 @@ async def generate_query(
"messages": state.messages,
"queries": "\n- ".join(state.queries),
"system_time": datetime.now(tz=timezone.utc).isoformat(),
},
config,
}
)
generated = cast(SearchQuery, await model.ainvoke(message_value, config))
generated = cast(SearchQuery, await model.ainvoke(message_value))
return {
"queries": [generated.query],
}


async def retrieve(
state: State, *, config: RunnableConfig
state: State, *, runtime: Runtime[Configuration]
) -> dict[str, list[Document]]:
"""Retrieve documents based on the latest query in the state.

Expand All @@ -94,22 +93,22 @@ async def retrieve(

Args:
state (State): The current state containing queries and the retriever.
config (RunnableConfig | None, optional): Configuration for the retrieval process.
runtime (Runtime[Configuration]): Runtime context containing configuration.

Returns:
dict[str, list[Document]]: A dictionary with a single key "retrieved_docs"
containing a list of retrieved Document objects.
"""
with retrieval.make_retriever(config) as retriever:
response = await retriever.ainvoke(state.queries[-1], config)
with retrieval.make_retriever(runtime) as retriever:
response = await retriever.ainvoke(state.queries[-1])
return {"retrieved_docs": response}


async def respond(
state: State, *, config: RunnableConfig
state: State, *, runtime: Runtime[Configuration]
) -> dict[str, list[BaseMessage]]:
"""Call the LLM powering our "agent"."""
configuration = Configuration.from_runnable_config(config)
configuration = runtime.context
# Feel free to customize the prompt, model, and other logic!
prompt = ChatPromptTemplate.from_messages(
[
Expand All @@ -125,18 +124,17 @@ async def respond(
"messages": state.messages,
"retrieved_docs": retrieved_docs,
"system_time": datetime.now(tz=timezone.utc).isoformat(),
},
config,
}
)
response = await model.ainvoke(message_value, config)
response = await model.ainvoke(message_value)
# We return a list, because this will get added to the existing list
return {"messages": [response]}


# Define a new graph (It's just a pipe)


builder = StateGraph(State, input=InputState, config_schema=Configuration)
builder = StateGraph(State, input=InputState, context_schema=Configuration)

builder.add_node(generate_query)
builder.add_node(retrieve)
Expand Down
22 changes: 10 additions & 12 deletions src/retrieval_graph/index_graph.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
"""This "graph" simply exposes an endpoint for a user to upload docs to be indexed."""

from typing import Optional, Sequence
from typing import Sequence

from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.runtime import Runtime

from retrieval_graph import retrieval
from retrieval_graph.configuration import IndexConfiguration
from retrieval_graph.state import IndexState


def ensure_docs_have_user_id(
docs: Sequence[Document], config: RunnableConfig
docs: Sequence[Document], runtime: Runtime[IndexConfiguration]
) -> list[Document]:
"""Ensure that all documents have a user_id in their metadata.

docs (Sequence[Document]): A sequence of Document objects to process.
config (RunnableConfig): A configuration object containing the user_id.
runtime (Runtime[IndexConfiguration]): Runtime context containing configuration.

Returns:
list[Document]: A new list of Document objects with updated metadata.
"""
user_id = config["configurable"]["user_id"]
user_id = runtime.context.user_id
return [
Document(
page_content=doc.page_content, metadata={**doc.metadata, "user_id": user_id}
Expand All @@ -32,7 +32,7 @@ def ensure_docs_have_user_id(


async def index_docs(
state: IndexState, *, config: Optional[RunnableConfig] = None
state: IndexState, *, runtime: Runtime[IndexConfiguration]
) -> dict[str, str]:
"""Asynchronously index documents in the given state using the configured retriever.

Expand All @@ -42,12 +42,10 @@ async def index_docs(

Args:
state (IndexState): The current state containing documents and retriever.
config (Optional[RunnableConfig]): Configuration for the indexing process.
runtime (Runtime[IndexConfiguration]): Runtime context containing configuration.
"""
if not config:
raise ValueError("Configuration required to run index_docs.")
with retrieval.make_retriever(config) as retriever:
stamped_docs = ensure_docs_have_user_id(state.docs, config)
with retrieval.make_retriever(runtime) as retriever:
stamped_docs = ensure_docs_have_user_id(state.docs, runtime)

await retriever.aadd_documents(stamped_docs)
return {"docs": "delete"}
Expand All @@ -56,7 +54,7 @@ async def index_docs(
# Define a new graph


builder = StateGraph(IndexState, config_schema=IndexConfiguration)
builder = StateGraph(IndexState, context_schema=IndexConfiguration)
builder.add_node(index_docs)
builder.add_edge("__start__", "index_docs")
# Finally, we compile it!
Expand Down
8 changes: 3 additions & 5 deletions src/retrieval_graph/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from typing import Generator

from langchain_core.embeddings import Embeddings
from langchain_core.runnables import RunnableConfig
from langchain_core.vectorstores import VectorStoreRetriever
from langgraph.runtime import get_runtime

from retrieval_graph.configuration import Configuration, IndexConfiguration

Expand Down Expand Up @@ -105,11 +105,9 @@ def make_mongodb_retriever(


@contextmanager
def make_retriever(
config: RunnableConfig,
) -> Generator[VectorStoreRetriever, None, None]:
def make_retriever() -> Generator[VectorStoreRetriever, None, None]:
"""Create a retriever for the agent, based on the current configuration."""
configuration = IndexConfiguration.from_runnable_config(config)
configuration = get_runtime(IndexConfiguration).context
embedding_model = make_text_encoder(configuration.embedding_model)
user_id = configuration.user_id
if not user_id:
Expand Down
17 changes: 6 additions & 11 deletions tests/integration_tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import uuid

import pytest
from langchain_core.runnables import RunnableConfig
from langsmith import expect, unit

from retrieval_graph import graph, index_graph
Expand All @@ -14,27 +13,23 @@ async def test_retrieval_graph() -> None:
user_id = "test__" + uuid.uuid4().hex
other_user_id = "test__" + uuid.uuid4().hex

config = RunnableConfig(
configurable={"user_id": user_id, "retriever_provider": "elastic-local"}
)
context = {"user_id": user_id, "retriever_provider": "elastic-local"}

result = await index_graph.ainvoke({"docs": simple_doc}, config)
result = await index_graph.ainvoke({"docs": simple_doc}, context=context)
expect(result["docs"]).against(lambda x: not x) # we delete after the end

res = await graph.ainvoke(
{"messages": [("user", "Where do cats perform synchronized swimming routes?")]},
config,
context=context,
)
response = str(res["messages"][-1].content)
expect(response.lower()).to_contain("bowl")

res = await graph.ainvoke(
{"messages": [("user", "Where do cats perform synchronized swimming routes?")]},
{
"configurable": {
"user_id": other_user_id,
"retriever_provider": "elastic-local",
}
context={
"user_id": other_user_id,
"retriever_provider": "elastic-local",
},
)
response = str(res["messages"][-1].content)
Expand Down
Loading
Loading