Commit 42b7106

Mention blocked responses in the UI and fix ChatVertexAI to handle the broken history
1 parent 78d6a31

3 files changed: +55 -1 lines changed

instrumentation-genai/opentelemetry-instrumentation-vertexai/examples/langgraph-chatbot-demo/src/langgraph_chatbot_demo/_streamlit_helpers.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -125,6 +125,11 @@ def render_message(message: BaseMessage, trace_id: str | None) -> None:
         if isinstance(message.content, str)
         else message.content[-1]["text"]
     ).strip()
+
+    # Response was probably blocked by a harm category, go check the trace for details
+    if message.response_metadata.get("is_blocked", False):
+        content = ":red[:material/error: Response blocked, try again]"
+
     if not content:
         return
```

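The new branch replaces the empty content of a safety-blocked message with a Streamlit error badge instead of silently skipping it. A minimal sketch of that behaviour, using the blocked-message shape documented in `patched_vertexai.py` below (the message values here are illustrative, not taken from the commit):

```python
from langchain_core.messages import AIMessage

# A safety-blocked reply arrives with empty content and is_blocked metadata
# (see the AIMessage example in patched_vertexai.py further down).
blocked = AIMessage(
    content="",
    response_metadata={"is_blocked": True, "finish_reason": "SAFETY"},
)

content = blocked.content.strip() if isinstance(blocked.content, str) else ""
if blocked.response_metadata.get("is_blocked", False):
    # Streamlit markdown: red material "error" icon plus a retry hint
    content = ":red[:material/error: Response blocked, try again]"

# Before this commit, content stayed empty and render_message returned early,
# so the blocked turn simply disappeared from the chat UI.
```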
instrumentation-genai/opentelemetry-instrumentation-vertexai/examples/langgraph-chatbot-demo/src/langgraph_chatbot_demo/langchain_history.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -4,6 +4,7 @@
 import tempfile
 import pathlib
 from os import environ
+import logging
 from random import getrandbits
 from typing import cast
 from google.cloud.exceptions import NotFound
@@ -21,13 +22,17 @@
 from langgraph.checkpoint.memory import InMemorySaver
 from langgraph.prebuilt import create_react_agent
 from langgraph_chatbot_demo import _streamlit_helpers
+from langgraph_chatbot_demo.patched_vertexai import PatchedChatVertexAI
 from sqlalchemy import Engine, create_engine

 from opentelemetry import trace
 from opentelemetry.trace.span import format_trace_id

 from google.cloud import storage

+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
 _ = """
 Ideas for things to add:
@@ -44,7 +49,7 @@
 _streamlit_helpers.styles()


-model = ChatVertexAI(model="gemini-1.5-flash")
+model = PatchedChatVertexAI(model="gemini-2.0-flash")

 if not st.query_params.get("thread_id"):
     result = model.invoke(
@@ -208,6 +213,7 @@ def get_trace_ids(thread_id: str) -> "dict[str, str]":
 # Invoke the agent
 with st.spinner("Thinking..."):
     res = app.invoke({"messages": [message]}, config=config)
+    logger.debug("agent response", extra={"response": str(res)})

 # Store trace ID for rendering
 trace_ids[message.id or ""] = trace_id
```
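Two behavioural changes here: the chatbot now builds on `PatchedChatVertexAI` (defined in the new file below) and the raw agent response is logged at debug level, which makes blocked replies visible outside the UI. A hedged sketch of what that debug record can surface when Vertex AI blocks a candidate (assumes valid GCP credentials; the prompt and the commented values are illustrative):

```python
import logging

from langgraph_chatbot_demo.patched_vertexai import PatchedChatVertexAI

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

model = PatchedChatVertexAI(model="gemini-2.0-flash")
response = model.invoke("a prompt that trips one of the harm categories")
logger.debug("agent response", extra={"response": str(response)})

# For a blocked candidate the interesting fields live in response_metadata,
# e.g. is_blocked=True and finish_reason="SAFETY" as shown in the comment
# in patched_vertexai.py below; content is typically empty.
```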
instrumentation-genai/opentelemetry-instrumentation-vertexai/examples/langgraph-chatbot-demo/src/langgraph_chatbot_demo/patched_vertexai.py

Lines changed: 43 additions & 0 deletions
```diff
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from typing import Any
+
+from google.cloud.aiplatform_v1.types import (
+    GenerateContentRequest as v1GenerateContentRequest,
+)
+from google.cloud.aiplatform_v1beta1.types import (
+    GenerateContentRequest,
+)
+from langchain_core.messages import (
+    BaseMessage,
+)
+from langchain_google_vertexai import ChatVertexAI
+
+
+class PatchedChatVertexAI(ChatVertexAI):
+    def _prepare_request_gemini(
+        self, messages: list[BaseMessage], *args: Any, **kwargs: Any
+    ) -> v1GenerateContentRequest | GenerateContentRequest:
+        # Filter out any blocked messages with no content, which can appear if you have a
+        # blocked message from finish_reason SAFETY:
+        #
+        # AIMessage(
+        #     content="",
+        #     additional_kwargs={},
+        #     response_metadata={
+        #         "is_blocked": True,
+        #         "safety_ratings": [ ... ],
+        #         "finish_reason": "SAFETY",
+        #     },
+        #     ...
+        # )
+        #
+        # These cause `google.api_core.exceptions.InvalidArgument: 400 Unable to submit request
+        # because it must include at least one parts field`
+
+        messages = [
+            message
+            for message in messages
+            if not message.response_metadata.get("is_blocked", False)
+        ]
+        return super()._prepare_request_gemini(messages, *args, **kwargs)
```
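The override drops the empty, blocked entries before the Gemini request is rebuilt, so re-submitting a history that contains one no longer triggers the 400 error quoted in the comment. A hedged usage sketch (assumes valid Vertex AI credentials; the message texts are made up):

```python
from langchain_core.messages import AIMessage, HumanMessage

from langgraph_chatbot_demo.patched_vertexai import PatchedChatVertexAI

# History as it comes back from the checkpointer after a blocked turn:
# the AIMessage has no content, only the is_blocked metadata.
history = [
    HumanMessage(content="first question"),
    AIMessage(
        content="",
        response_metadata={"is_blocked": True, "finish_reason": "SAFETY"},
    ),
    HumanMessage(content="second question"),
]

model = PatchedChatVertexAI(model="gemini-2.0-flash")
# The blocked message is filtered out in _prepare_request_gemini, so the
# request only contains parts-bearing contents and the API accepts it.
reply = model.invoke(history)
```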
