Commit ea9e94b

[tests]: add NeMo Guardrails with retrieval rail to test matrix (#328)

* Add NeMo Guardrails to test matrix

1 parent: fd7df61

File tree: 9 files changed, +203 −31 lines

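At a glance: the commit adds a nemo_guardrails case to the LangChain compatibility suite, backed by a new test module that registers a retrieval ("rag") action as a NeMo Guardrails rail; it reshapes the llm fixtures so each one declares whether it can drive NeMo; and it pins the nemoguardrails dependency in the e2e pyproject files. A sketch of the resulting matrix (axis values are from the diff below):

# Parametrized axes after this commit; pytest crosses all of them.
test_cases = ["rag_custom_chain", "conversational_rag", "trulens", "nemo_guardrails"]
vector_stores = ["astra_db", "cassandra"]
# Each combination is further crossed with every (embedding, llm) fixture pair;
# the nemo_guardrails case only runs where the llm fixture carries a non-None
# "nemo_config" (in this commit, gpt-4 via the openai engine).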

.github/workflows/_run_e2e_tests.yml

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ jobs:
       ASTRA_DB_ID: "${{ steps.astra-db.outputs.db_id }}"
       ASTRA_DB_ENV: "${{ inputs.astradb-env }}"
       OPEN_AI_KEY: "${{ secrets.E2E_TESTS_OPEN_AI_KEY }}"
+      OPENAI_API_KEY: "${{ secrets.E2E_TESTS_OPEN_AI_KEY }}"
       AZURE_OPEN_AI_KEY: "${{ secrets.E2E_TESTS_AZURE_OPEN_AI_KEY }}"
       AZURE_OPEN_AI_ENDPOINT: "${{ secrets.E2E_TESTS_AZURE_OPEN_AI_ENDPOINT }}"
       AZURE_BLOB_STORAGE_CONNECTION_STRING: "${{ secrets.E2E_TESTS_AZURE_BLOB_STORAGE_CONNECTION_STRING }}"
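
The same secret is now exported under both names. A likely reason (inferred; the commit does not say): the harness reads OPEN_AI_KEY explicitly, while the OpenAI client that NeMo Guardrails constructs internally picks its key up from the standard OPENAI_API_KEY variable. A local-run sketch under that assumption:

import os

# Assumption: LLMRails builds the OpenAI LLM itself, and that client only
# looks at OPENAI_API_KEY, so the key must also exist under the standard name.
os.environ.setdefault("OPENAI_API_KEY", os.environ.get("OPEN_AI_KEY", ""))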

ragstack-e2e-tests/e2e_tests/langchain/nemo_guardrails.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+from e2e_tests.langchain.rag_application import (
+    BASIC_QA_PROMPT,
+    SAMPLE_DATA,
+)
+
+from langchain.schema.vectorstore import VectorStore
+from langchain.schema.output_parser import StrOutputParser
+from langchain.schema.retriever import BaseRetriever
+from langchain.prompts import PromptTemplate
+from langchain.llms.base import BaseLLM
+
+
+from nemoguardrails import RailsConfig, LLMRails
+from nemoguardrails.actions.actions import ActionResult
+
+
+def _config(engine, model) -> str:
+    return f"""
+    models:
+      - type: main
+        engine: {engine}
+        model: {model}
+    """
+
+
+def _colang() -> str:
+    return """
+    define user express greeting
+      "Hi, how are you?"
+
+    define user ask about product
+      "What was MyFakeProductForTesting?"
+      "When was MyFakeProductForTesting first released?"
+      "What capabilities does MyFakeProductForTesting have?"
+      "What is MyFakeProductForTesting's best feature?"
+
+    define bot express greeting
+      "Hello! I hope to answer all your questions!"
+
+    define flow greeting
+      user express greeting
+      bot express greeting
+
+    define flow answer product question
+      user ask about product
+      $answer = execute rag()
+      bot $answer
+    """
+
+
+class NeMoRag:
+    def __init__(self, retriever):
+        self.retriever = retriever
+
+    async def rag_using_lc(self, context: dict, llm: BaseLLM) -> ActionResult:
+        """
+        Defines the custom rag action
+        """
+        user_message = context.get("last_user_message")
+        context_updates = {}
+
+        # Use your pre-defined AstraDB Vector Store as the retriever
+        relevant_documents = await self.retriever.aget_relevant_documents(user_message)
+        relevant_chunks = "\n".join(
+            [chunk.page_content for chunk in relevant_documents]
+        )
+
+        # Use a custom prompt template
+        prompt_template = PromptTemplate.from_template(BASIC_QA_PROMPT)
+        input_variables = {"question": user_message, "context": relevant_chunks}
+
+        # Create LCEL chain
+        chain = prompt_template | llm | StrOutputParser()
+        answer = await chain.ainvoke(input_variables)
+
+        return ActionResult(return_value=answer, context_updates=context_updates)
+
+    def init(self, app: LLMRails):
+        app.register_action(self.rag_using_lc, "rag")
+
+
+def _try_runnable_rails(config: RailsConfig, retriever: BaseRetriever) -> None:
+    # LLM is created internally to rails using the provided config
+    rails = LLMRails(config)
+    processor = NeMoRag(retriever)
+    processor.init(rails)
+
+    response = rails.generate(
+        messages=[
+            {
+                "role": "user",
+                "content": "Hi, how are you?",
+            }
+        ]
+    )
+    assert "Hello! I hope to answer all your questions" in response["content"]
+
+    response = rails.generate(
+        messages=[
+            {
+                "role": "user",
+                "content": "When was MyFakeProductForTesting first released?",
+            }
+        ]
+    )
+    assert "2020" in response["content"]
+
+
+def run_nemo_guardrails(vector_store: VectorStore, config: dict[str, str]) -> None:
+    vector_store.add_texts(SAMPLE_DATA)
+    retriever = vector_store.as_retriever()
+
+    model_config = _config(config["engine"], config["model"])
+    rails_config = RailsConfig.from_content(
+        yaml_content=model_config, colang_content=_colang()
+    )
+    _try_runnable_rails(config=rails_config, retriever=retriever)
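
A minimal sketch of driving this module outside the CI harness. The FAISS store and OpenAI embeddings below are illustrative assumptions; the suite itself wires in Astra DB or Cassandra stores via fixtures:

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

from e2e_tests.langchain.nemo_guardrails import run_nemo_guardrails

# run_nemo_guardrails ingests SAMPLE_DATA itself; the store only has to exist.
store = FAISS.from_texts(["placeholder"], OpenAIEmbeddings())
# {"engine": "openai", "model": "gpt-4"} is the only pairing the commit enables.
run_nemo_guardrails(vector_store=store, config={"engine": "openai", "model": "gpt-4"})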

ragstack-e2e-tests/e2e_tests/langchain/rag_application.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 
 
 BASIC_QA_PROMPT = """
-Answer the question based only on the supplied context. If you don't know the answer, say you don't know the answer.
+Answer the question based only on the supplied context. If you don't know the answer, say the following: "I don't know the answer".
 Context: {context}
 Question: {question}
 Your answer:
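
The reworded fallback pins an exact phrase, presumably so assertions can match it deterministically (intent inferred, not stated in the commit). For reference, the new rag action consumes this prompt via PromptTemplate.from_template; a quick formatting sketch with illustrative values:

from langchain.prompts import PromptTemplate

from e2e_tests.langchain.rag_application import BASIC_QA_PROMPT

prompt = PromptTemplate.from_template(BASIC_QA_PROMPT)
print(
    prompt.format(
        context="MyFakeProductForTesting was first released in 2020.",
        question="When was MyFakeProductForTesting first released?",
    )
)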

ragstack-e2e-tests/e2e_tests/langchain/test_compatibility_rag.py

Lines changed: 75 additions & 28 deletions
@@ -15,6 +15,7 @@
 )
 from e2e_tests.langchain.trulens import run_trulens_evaluation
 from e2e_tests.test_utils import get_local_resource_path
+from e2e_tests.langchain.nemo_guardrails import run_nemo_guardrails
 
 from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatVertexAI, BedrockChat
 from langchain.embeddings import (
@@ -59,17 +60,27 @@ def _chat_openai(**kwargs) -> ChatOpenAI:
 
 @pytest.fixture
 def openai_gpt35turbo_llm():
-    return _chat_openai(model="gpt-3.5-turbo", streaming=False)
+    # NeMo guardrails fails for this model with the given prompts.
+    model = "gpt-3.5-turbo"
+    return {"llm": _chat_openai(model=model, streaming=False), "nemo_config": None}
 
 
 @pytest.fixture
 def openai_gpt4_llm():
-    return _chat_openai(model="gpt-4", streaming=False)
+    model = "gpt-4"
+    return {
+        "llm": _chat_openai(model=model, streaming=False),
+        "nemo_config": {"engine": "openai", "model": model},
+    }
 
 
 @pytest.fixture
 def openai_gpt4_llm_streaming():
-    return _chat_openai(model="gpt-4", streaming=True)
+    model = "gpt-4"
+    return {
+        "llm": _chat_openai(model=model, streaming=True),
+        "nemo_config": {"engine": "openai", "model": model},
+    }
 
 
 def _openai_embeddings(**kwargs) -> OpenAIEmbeddings:
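
The hunks above and below change every llm fixture to the same dict shape; spelled out as a sketch (gpt-4 on the openai engine is the only pairing enabled here, all other fixtures set nemo_config to None):

# Contract each llm fixture now satisfies:
resolved_llm = {
    "llm": None,  # stand-in for the ChatOpenAI / BedrockChat / ... instance
    "nemo_config": {"engine": "openai", "model": "gpt-4"},  # or None -> case skipped
}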
@@ -95,13 +106,16 @@ def openai_3large_embedding():
 def azure_openai_gpt35turbo_llm():
     # model is configurable because it can be different from the deployment
     # but the targeting model must be gpt-35-turbo
-    return AzureChatOpenAI(
-        azure_deployment=get_required_env("AZURE_OPEN_AI_CHAT_MODEL_DEPLOYMENT"),
-        openai_api_base=get_required_env("AZURE_OPEN_AI_ENDPOINT"),
-        openai_api_key=get_required_env("AZURE_OPEN_AI_KEY"),
-        openai_api_type="azure",
-        openai_api_version="2023-07-01-preview",
-    )
+    return {
+        "llm": AzureChatOpenAI(
+            azure_deployment=get_required_env("AZURE_OPEN_AI_CHAT_MODEL_DEPLOYMENT"),
+            openai_api_base=get_required_env("AZURE_OPEN_AI_ENDPOINT"),
+            openai_api_key=get_required_env("AZURE_OPEN_AI_KEY"),
+            openai_api_type="azure",
+            openai_api_version="2023-07-01-preview",
+        ),
+        "nemo_config": None,
+    }
 
 
 @pytest.fixture
@@ -123,7 +137,7 @@ def azure_openai_ada002_embedding():
 
 @pytest.fixture
 def vertex_bison_llm():
-    return ChatVertexAI(model_name="chat-bison")
+    return {"llm": ChatVertexAI(model_name="chat-bison"), "nemo_config": None}
 
 
 @pytest.fixture
@@ -137,21 +151,30 @@ def _bedrock_chat(**kwargs) -> BedrockChat:
 
 @pytest.fixture
 def bedrock_anthropic_claudev2_llm():
-    return _bedrock_chat(
-        model_id="anthropic.claude-v2",
-    )
+    return {
+        "llm": _bedrock_chat(
+            model_id="anthropic.claude-v2",
+        ),
+        "nemo_config": None,
+    }
 
 
 @pytest.fixture
 def bedrock_mistral_mistral7b_llm():
-    return _bedrock_chat(
-        model_id="mistral.mistral-7b-instruct-v0:2",
-    )
+    return {
+        "llm": _bedrock_chat(
+            model_id="mistral.mistral-7b-instruct-v0:2",
+        ),
+        "nemo_config": None,
+    }
 
 
 @pytest.fixture
 def bedrock_meta_llama2_llm():
-    return _bedrock_chat(model_id="meta.llama2-13b-chat-v1")
+    return {
+        "llm": _bedrock_chat(model_id="meta.llama2-13b-chat-v1"),
+        "nemo_config": None,
+    }
 
 
 @pytest.fixture
@@ -172,11 +195,14 @@ def bedrock_cohere_embedding():
 
 @pytest.fixture
 def huggingface_hub_flant5xxl_llm():
-    return HuggingFaceHub(
-        repo_id="google/flan-t5-xxl",
-        huggingfacehub_api_token=get_required_env("HUGGINGFACE_HUB_KEY"),
-        model_kwargs={"temperature": 1, "max_length": 256},
-    )
+    return {
+        "llm": HuggingFaceHub(
+            repo_id="google/flan-t5-xxl",
+            huggingfacehub_api_token=get_required_env("HUGGINGFACE_HUB_KEY"),
+            model_kwargs={"temperature": 1, "max_length": 256},
+        ),
+        "nemo_config": None,
+    }
 
 
 @pytest.fixture
@@ -190,7 +216,7 @@ def huggingface_hub_minilml6v2_embedding():
 @pytest.fixture
 def nvidia_aifoundation_nvolveqa40k_embedding():
     get_required_env("NVIDIA_API_KEY")
-    from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
+    from langchain_nvidia_ai_endpoints.embeddings import NVIDIAEmbeddings
 
     return NVIDIAEmbeddings(model="playground_nvolveqa_40k")
 
@@ -200,14 +226,17 @@ def nvidia_aifoundation_mixtral8x7b_llm():
     get_required_env("NVIDIA_API_KEY")
     from langchain_nvidia_ai_endpoints import ChatNVIDIA
 
-    return ChatNVIDIA(model="playground_mixtral_8x7b")
+    return {"llm": ChatNVIDIA(model="playground_mixtral_8x7b"), "nemo_config": None}
 
 
 @pytest.mark.parametrize(
     "test_case",
-    ["rag_custom_chain", "conversational_rag", "trulens"],
+    ["rag_custom_chain", "conversational_rag", "trulens", "nemo_guardrails"],
+)
+@pytest.mark.parametrize(
+    "vector_store",
+    ["astra_db", "cassandra"],
 )
-@pytest.mark.parametrize("vector_store", ["astra_db", "cassandra"])
 @pytest.mark.parametrize(
     "embedding,llm",
     [
@@ -243,8 +272,15 @@ def test_rag(test_case, vector_store, embedding, llm, request, record_property):
     )
 
 
-def _run_test(test_case: str, vector_store_context, embedding, llm, record_property):
+def _run_test(
+    test_case: str,
+    vector_store_context,
+    embedding,
+    resolved_llm,
+    record_property,
+):
     vector_store = vector_store_context.new_langchain_vector_store(embedding=embedding)
+    llm = resolved_llm["llm"]
     if test_case == "rag_custom_chain":
         run_rag_custom_chain(
             vector_store=vector_store, llm=llm, record_property=record_property
@@ -256,8 +292,19 @@ def _run_test(test_case: str, vector_store_context, embedding, llm, record_property):
             chat_memory=vector_store_context.new_langchain_chat_memory(),
             record_property=record_property,
         )
+    # TODO: Add record property
     elif test_case == "trulens":
         run_trulens_evaluation(vector_store=vector_store, llm=llm)
+    elif test_case == "nemo_guardrails":
+        config = resolved_llm["nemo_config"]
+        if config:
+            # NeMo creates the LLM internally using the config
+            run_nemo_guardrails(
+                vector_store=vector_store,
+                config=config,
+            )
+        else:
+            pytest.skip("Skipping NeMo test for this configuration")
     else:
         raise ValueError(f"Unknown test case: {test_case}")
 
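With that dispatch in place, opting another provider into the NeMo case is just a fixture change. A hypothetical example (nothing below is part of this commit; the engine string must be one NeMo Guardrails supports):

@pytest.fixture
def openai_gpt35turbo_16k_llm():  # hypothetical fixture, for illustration only
    model = "gpt-3.5-turbo-16k"
    return {
        "llm": _chat_openai(model=model, streaming=False),
        "nemo_config": {"engine": "openai", "model": model},
    }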
ragstack-e2e-tests/e2e_tests/test_utils/astradb_vector_store_handler.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def await_ongoing_deletions_completed(self):
         Blocks until all ongoing deletions are completed.
         """
         while self.semaphore._value != self.max_workers:
-            logging.info(
+            logging.debug(
                 f"{self.max_workers - self.semaphore._value} deletions still running, waiting to complete"
             )
             time.sleep(1)

ragstack-e2e-tests/pyproject.langchain.toml

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ pillow = "^10.2.0"
 testcontainers = "^3.7.1"
 python-dotenv = "^1.0.1"
 trulens-eval = "^0.21.0"
+nemoguardrails = "^0.8.0"
 
 # From LangChain optional deps, needed by WebBaseLoader
 beautifulsoup4 = "^4"

ragstack-e2e-tests/pyproject.llamaindex.toml

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ pillow = "^10.2.0"
 testcontainers = "^3.7.1"
 python-dotenv = "^1.0.1"
 trulens-eval = "^0.21.0"
+nemoguardrails = "^0.7.1"
 
 # From LangChain optional deps, needed by WebBaseLoader
 beautifulsoup4 = "^4"

ragstack-e2e-tests/pyproject.ragstack-ai.toml

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ pillow = "^10.2.0"
 testcontainers = "^3.7.1"
 python-dotenv = "^1.0.1"
 trulens-eval = "^0.21.0"
+nemoguardrails = "^0.8.0"
 
 # From LangChain optional deps, needed by WebBaseLoader
 beautifulsoup4 = "^4"

ragstack-e2e-tests/pyproject.toml

Lines changed: 5 additions & 1 deletion
@@ -20,11 +20,15 @@ pillow = "^10.2.0"
 testcontainers = "^3.7.1"
 python-dotenv = "^1.0.1"
 trulens-eval = "^0.21.0"
+nemoguardrails = "^0.8.0"
 
 # From LangChain optional deps, needed by WebBaseLoader
 beautifulsoup4 = "^4"
 
-ragstack-ai = { path = "../", develop = false, extras = ["langchain-google", "langchain-nvidia"]}
+ragstack-ai = { path = "../", develop = false, extras = [
+    "langchain-google",
+    "langchain-nvidia",
+] }
 # we need this specific feature from cassio: https://github.com/CassioML/cassio/pull/128
 cassio = "~0.1.4"
 