Commit 517e9c5

feat: Add Granite Guardian 3.3 8B with updated examples, function call validation, and repair with reason (#167)
* Add Granite Guardian 3.3 8B with dual backends and function call validation
  - Enhanced GuardianCheck with HuggingFace and Ollama backends
  - Added thinking mode support for detailed reasoning traces
  - Implemented actual function calling validation with RepairTemplateStrategy that consumes reasoning in the repair process
  - Added groundedness and function call hallucination detection examples
* restore updates from upstream main.
* refactor to use mellea hf and ollama backends.
* feat: add reason to repair string.
* successful run of examples
* cleanup
* cleanup
* fix fc example.
* fix hf example.
* guardian_config as passthrough in hf backend.
* guardian_config as passthrough in hf backend.
* simpler gg hf example.
* pass think to hf backend.
* pass think to hf backend.
* pass add_generation_prompt to hf backend.
* don't pass add_generation_prompt to hf generate.
* better construction of messages for guardian.
* better construction of messages for guardian.
* chore: fixing some ruff issues

---------

Co-authored-by: Avinash Balakrishnan <[email protected]>
1 parent 519a35a commit 517e9c5

7 files changed: +718 -144 lines changed
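At a glance, the two example diffs below follow one pattern. Here is a minimal sketch assembled from the code in this commit; the Ollama backend and the DEEPSEEK_R1_8B generation model are simply the choices the examples make, not requirements:

from mellea import MelleaSession
from mellea.backends import model_ids
from mellea.backends.ollama import OllamaModelBackend
from mellea.stdlib.base import ChatContext
from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk

# thinking=True asks Granite Guardian 3.3 for a reasoning trace alongside its verdict
guardian = GuardianCheck(GuardianRisk.HARM, thinking=True, backend_type="ollama")

m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext())
m.chat("Write a professional email to your colleague.")

result = m.validate([guardian])[0]  # one result per requirement passed in
print(result._result)  # True when the last turn is judged safe
print(result._reason)  # reasoning trace, populated when thinking=True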
Lines changed: 138 additions & 21 deletions
@@ -1,35 +1,152 @@
-"""Example of using the Guardian Requirement."""
+"""Example of using the Enhanced Guardian Requirement with Granite Guardian 3.3 8B"""
 
 from mellea import MelleaSession
 from mellea.backends import model_ids
 from mellea.backends.ollama import OllamaModelBackend
-from mellea.stdlib.base import ContextTurn, ModelOutputThunk
+from mellea.stdlib.base import ContextTurn, ModelOutputThunk, ChatContext
 from mellea.stdlib.chat import Message
 from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk
 
-# Start with instantiating a GuardianCheck.
-guardian = GuardianCheck(risk=GuardianRisk.HARM)
+# Enhanced GuardianCheck with Granite Guardian 3.3 8B support
+print("=== Enhanced GuardianCheck Examples ===")
+
+# Basic harm checking with new features
+guardian = GuardianCheck(GuardianRisk.HARM, thinking=True, backend_type="ollama")
+print(f"Guardian configured: {guardian._risk} with thinking mode: {guardian._thinking}")
+
+# Try different risk types with new 3.3 8B model
+jailbreak_guardian = GuardianCheck(GuardianRisk.JAILBREAK, thinking=True)
+print(f"Jailbreak guardian: {jailbreak_guardian._risk}")
+
+# Custom risk criteria example
+custom_guardian = GuardianCheck(
+    custom_criteria="Check for inappropriate content in educational context"
+)
+print(f"Custom guardian: {custom_guardian.get_effective_risk()}")
+
+print("\n=== Test 1: Normal Content ===")
+# Run a query against an Ollama model with ChatContext to support insert_turn
+m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext())
+m.chat("Write a professional email to your colleague. Use less than 50 words.")
 
-print("\n Test 1\n")
-# Run a query against an Ollama model
-m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B))
-m.chat("Write an email to your colleague. Use less than 100 words.")
-print(f"Full context: {m.ctx}")
 # Validate on the last turn of the conversation
 valid = m.validate([guardian])
-# Should return [True] meaning that input and output are free of harm and it's safe to return result
-print(f"Guardian result: {valid}")
+print(f"Content is safe: {valid[0]._result}")
+if valid[0]._reason:
+    print(f"Guardian feedback: {valid[0]._reason}...")
+
+print("\n=== Test 2: Harmful Content ===")
+# Create harmful conversation and add to context
+m.ctx = m.ctx.add(Message("user", "Hello")).add(
+    Message("assistant", "You are very ugly and stupid.")
+)
+
+valid = m.validate([guardian])
+print(f"Content is safe: {valid[0]._result}")
+if valid[0]._reason:
+    print(f"Guardian feedback: {valid[0]._reason}...")
+
+print("\n=== Test 3: Enhanced Features ===")
+
+# Create multiple guardians with different configurations
+guardians = [
+    GuardianCheck(GuardianRisk.HARM, thinking=True),
+    GuardianCheck(GuardianRisk.JAILBREAK, thinking=True),
+    GuardianCheck(GuardianRisk.SOCIAL_BIAS),
+    GuardianCheck(custom_criteria="Check for financial advice"),
+]
+
+print(f"Available risk types ({len(GuardianCheck.get_available_risks())} total):")
+for risk in GuardianCheck.get_available_risks():  # Show first 5
+    print(f" - {risk}")
+print(" ...")
 
-# Reproducing a harm scenario and run the validation again
-print("\n Test 2\n")
+print(f"\nConfigured guardians: {len(guardians)} total")
+
+# Show Ollama backend configuration
+ollama_guardian = GuardianCheck(GuardianRisk.HARM, backend_type="ollama")
+print(f" Ollama backend: {ollama_guardian._backend.model_version}")
+
+print("\n=== Test 4: Groundedness Detection ===")
+# Test groundedness - detecting when responses lack factual grounding
+context_text = "One significant part of treaty making is that signing a treaty implies recognition that the other side is a sovereign state and that the agreement being considered is enforceable under international law. Hence, nations can be very careful about terming an agreement to be a treaty. For example, within the United States, agreements between states are compacts and agreements between states and the federal government or between agencies of the government are memoranda of understanding."
+
+groundedness_guardian = GuardianCheck(
+    GuardianRisk.GROUNDEDNESS,
+    thinking=True,
+    backend_type="ollama",
+    context_text=context_text,
+)
 
-# create a mean conversation and add to context
-m.ctx = m.ctx.add(Message("user", "Hello. ")).add(
-    ModelOutputThunk("You are very ugly.")
+# Create a response that makes ungrounded claims relative to provided context
+groundedness_session = MelleaSession(
+    OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()
 )
-# show last turn in chat
-print(f"Context: {m.ctx.last_turn()}")
+groundedness_session.ctx = groundedness_session.ctx.add(
+    Message("user", "What is the history of treaty making?")
+).add(
+    Message(
+        "assistant",
+        "Treaty making began in ancient Rome when Julius Caesar invented the concept in 44 BC. The first treaty was signed between Rome and the Moon people, establishing trade routes through space.",
+    )
+)
+
+print("Testing response with ungrounded claims...")
+groundedness_valid = groundedness_session.validate([groundedness_guardian])
+print(f"Response is grounded: {groundedness_valid[0]._result}")
+if groundedness_valid[0]._reason:
+    print(f"Groundedness feedback: {groundedness_valid[0]._reason}...")
+
+print("\n=== Test 5: Function Call Hallucination Detection ===")
+# Test function calling hallucination using IBM video example
+from mellea.stdlib.base import ModelOutputThunk, ModelToolCall
+
+tools = [
+    {
+        "name": "views_list",
+        "description": "Fetches total views for a specified IBM video using the given API.",
+        "parameters": {
+            "video_id": {
+                "description": "The ID of the IBM video.",
+                "type": "int",
+                "default": "7178094165614464282",
+            }
+        },
+    }
+]
+
+function_guardian = GuardianCheck(
+    GuardianRisk.FUNCTION_CALL, thinking=True, backend_type="ollama", tools=tools
+)
+
+
+# User asks for views but assistant calls wrong function (comments_list instead of views_list)
+# Create a proper ModelOutputThunk with tool_calls
+def dummy_func(**kwargs):
+    pass
+
+
+hallucinated_tool_calls = {
+    "comments_list": ModelToolCall(
+        name="comments_list", func=dummy_func, args={"video_id": 456789123, "count": 15}
+    )
+}
+
+hallucinated_output = ModelOutputThunk(
+    value="I'll fetch the views for you.", tool_calls=hallucinated_tool_calls
+)
+
+function_session = MelleaSession(
+    OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()
+)
+function_session.ctx = function_session.ctx.add(
+    Message("user", "Fetch total views for the IBM video with ID 456789123.")
+).add(hallucinated_output)
+
+print("Testing response with function call hallucination...")
+function_valid = function_session.validate([function_guardian])
+print(f"Function calls are valid: {function_valid[0]._result}")
+if function_valid[0]._reason:
+    print(f"Function call feedback: {function_valid[0]._reason}...")
 
-check_results = m.validate([guardian])
-# Should return [False] meaning that input and output contain harm and it's NOT safe to return result
-print(f"Guardian check results: {check_results}")
+print("\n=== GuardianCheck Demo Complete ===")
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+"""Example of using GuardianCheck with HuggingFace backend for direct model inference
+
+This example shows how to reuse the Guardian backend across multiple validators
+to avoid reloading the model multiple times.
+"""
+
+from mellea import MelleaSession
+from mellea.backends import model_ids
+from mellea.backends.ollama import OllamaModelBackend
+from mellea.backends.huggingface import LocalHFBackend
+from mellea.stdlib.base import ChatContext, ModelOutputThunk, ModelToolCall
+from mellea.stdlib.chat import Message
+from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk
+
+print("=== GuardianCheck HuggingFace Backend Example ===")
+
+# Create a single reusable session with Ollama backend for test content generation
+# We'll reuse this session and just reset its context for each test
+session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext())
+
+# Create a single shared HuggingFace backend for Guardian (loads model once)
+print("Loading Granite Guardian model (this happens only once)...")
+shared_guardian_backend = LocalHFBackend(model_id="ibm-granite/granite-guardian-3.3-8b")
+print(f"Loaded backend: {shared_guardian_backend.model_id}\n")
+
+# Create Guardian validators that share the backend (no model reloading!)
+print("Creating harm guardian...")
+harm_guardian = GuardianCheck(
+    GuardianRisk.HARM, thinking=True, backend=shared_guardian_backend
+)
+
+# Test 1: Harmful content detection
+print("\n=== Test 1: Harmful Content Detection ===")
+print("Testing content: 'You are very ugly and stupid.'")
+
+# Reset context for this test
+session.ctx = (
+    ChatContext()
+    .add(Message("user", "Hello"))
+    .add(Message("assistant", "You are very ugly and stupid."))
+)
+
+validation_result = session.validate([harm_guardian])
+print(f"Content is safe: {validation_result[0]._result}")
+print(f"Guardian detected harm: {not validation_result[0]._result}")
+
+if validation_result[0]._reason:
+    print(f"\nGuardian feedback:")
+    print(validation_result[0]._reason[:200] + "...")
+
+# Test 2: Groundedness detection
+print("\n=== Test 2: Groundedness Detection ===")
+context_text = (
+    "Python is a high-level programming language created by Guido van Rossum in 1991."
+)
+
+# Create groundedness guardian with context (reuse shared backend)
+print("Creating groundedness guardian...")
+groundedness_guardian = GuardianCheck(
+    GuardianRisk.GROUNDEDNESS,
+    thinking=False,
+    context_text=context_text,
+    backend=shared_guardian_backend,
+)
+
+# Reset context with ungrounded response
+session.ctx = (
+    ChatContext()
+    .add(Message("user", "Who created Python?"))
+    .add(
+        Message(
+            "assistant",
+            "Python was created by Dennis Ritchie in 1972 for use in Unix systems.",
+        )
+    )
+)
+
+groundedness_valid = session.validate([groundedness_guardian])
+print(f"Response is grounded: {groundedness_valid[0]._result}")
+if groundedness_valid[0]._reason:
+    print(f"Groundedness feedback: {groundedness_valid[0]._reason[:200]}...")
+
+# Test 3: Function call validation
+print("\n=== Test 3: Function Call Validation ===")
+
+tools = [
+    {
+        "name": "get_weather",
+        "description": "Gets weather for a location",
+        "parameters": {"location": {"description": "City name", "type": "string"}},
+    }
+]
+
+# Create function call guardian (reuse shared backend)
+print("Creating function call guardian...")
+function_guardian = GuardianCheck(
+    GuardianRisk.FUNCTION_CALL,
+    thinking=False,
+    tools=tools,
+    backend=shared_guardian_backend,
+)
+
+
+# User asks for weather but model calls wrong function
+def dummy_func(**kwargs):
+    pass
+
+
+hallucinated_tool_calls = {
+    "get_stock_price": ModelToolCall(
+        name="get_stock_price", func=dummy_func, args={"symbol": "AAPL"}
+    )
+}
+
+hallucinated_output = ModelOutputThunk(
+    value="Let me get the weather for you.", tool_calls=hallucinated_tool_calls
+)
+
+# Reset context with hallucinated function call
+session.ctx = (
+    ChatContext()
+    .add(Message("user", "What's the weather in Boston?"))
+    .add(hallucinated_output)
+)
+
+function_valid = session.validate([function_guardian])
+print(f"Function calls are valid: {function_valid[0]._result}")
+if function_valid[0]._reason:
+    print(f"Function call feedback: {function_valid[0]._reason[:200]}...")
+
+print("\n=== HuggingFace Guardian Demo Complete ===")
