-"""Example of using the Guardian Requirement."""
+"""Example of using the Enhanced Guardian Requirement with Granite Guardian 3.3 8B."""

from mellea import MelleaSession
from mellea.backends import model_ids
from mellea.backends.ollama import OllamaModelBackend
-from mellea.stdlib.base import ContextTurn, ModelOutputThunk
+from mellea.stdlib.base import ChatContext, ModelOutputThunk, ModelToolCall
from mellea.stdlib.chat import Message
from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk

-# Start with instantiating a GuardianCheck.
-guardian = GuardianCheck(risk=GuardianRisk.HARM)
+# Enhanced GuardianCheck with Granite Guardian 3.3 8B support
+print("=== Enhanced GuardianCheck Examples ===")
+
+# Basic harm checking with new features
+guardian = GuardianCheck(GuardianRisk.HARM, thinking=True, backend_type="ollama")
+print(f"Guardian configured: {guardian._risk} with thinking mode: {guardian._thinking}")
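+# With thinking=True the guardian is asked to produce a reasoning trace alongside its
+# verdict; in this example that explanation surfaces via ValidationResult._reason below.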
+
+# Try different risk types with the new 3.3 8B model
+jailbreak_guardian = GuardianCheck(GuardianRisk.JAILBREAK, thinking=True)
+print(f"Jailbreak guardian: {jailbreak_guardian._risk}")
+
+# Custom risk criteria example
+custom_guardian = GuardianCheck(
+    custom_criteria="Check for inappropriate content in educational context"
+)
+print(f"Custom guardian: {custom_guardian.get_effective_risk()}")
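+# When custom_criteria is set it stands in for a preset GuardianRisk, and
+# get_effective_risk() reports whichever criteria are actually in force.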
+
+print("\n=== Test 1: Normal Content ===")
+# Run a query against an Ollama model with a ChatContext (needed to insert turns later)
+m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext())
+m.chat("Write a professional email to your colleague. Use less than 50 words.")

-print("\n Test 1\n")
-# Run a query against an Ollama model
-m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B))
-m.chat("Write an email to your colleague. Use less than 100 words.")
-print(f"Full context: {m.ctx}")
# Validate on the last turn of the conversation
valid = m.validate([guardian])
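+# validate() returns one ValidationResult per requirement passed in; ._result holds
+# the boolean verdict and ._reason any explanation the guardian produced.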
-# Should return [True] meaning that input and output are free of harm and it's safe to return result
-print(f"Guardian result: {valid}")
+print(f"Content is safe: {valid[0]._result}")
+if valid[0]._reason:
+    print(f"Guardian feedback: {valid[0]._reason}...")
+
+print("\n=== Test 2: Harmful Content ===")
+# Create a harmful conversation and add it to the context
+m.ctx = m.ctx.add(Message("user", "Hello")).add(
+    Message("assistant", "You are very ugly and stupid.")
+)
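+# Note the reassignment: ChatContext.add() returns a new context rather than mutating
+# the existing one, so the updated conversation must be stored back on the session.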
+
+valid = m.validate([guardian])
+print(f"Content is safe: {valid[0]._result}")
+if valid[0]._reason:
+    print(f"Guardian feedback: {valid[0]._reason}...")
+
+print("\n=== Test 3: Enhanced Features ===")
+
+# Create multiple guardians with different configurations
+guardians = [
+    GuardianCheck(GuardianRisk.HARM, thinking=True),
+    GuardianCheck(GuardianRisk.JAILBREAK, thinking=True),
+    GuardianCheck(GuardianRisk.SOCIAL_BIAS),
+    GuardianCheck(custom_criteria="Check for financial advice"),
+]
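+# These can also be applied together in a single call, e.g. m.validate(guardians),
+# since validate() accepts a list of requirements.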
+
+print(f"Available risk types ({len(GuardianCheck.get_available_risks())} total):")
+for risk in GuardianCheck.get_available_risks()[:5]:  # Show the first 5
+    print(f" - {risk}")
+print(" ...")

-# Reproducing a harm scenario and run the validation again
-print("\n Test 2\n")
+print(f"\nConfigured guardians: {len(guardians)} total")
+
+# Show Ollama backend configuration
+ollama_guardian = GuardianCheck(GuardianRisk.HARM, backend_type="ollama")
+print(f" Ollama backend: {ollama_guardian._backend.model_version}")
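+# backend_type="ollama" routes guardian checks through a locally served Ollama model;
+# model_version here reports which model the backend was configured with.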
+
+print("\n=== Test 4: Groundedness Detection ===")
+# Test groundedness: detecting when responses lack factual grounding
+context_text = "One significant part of treaty making is that signing a treaty implies recognition that the other side is a sovereign state and that the agreement being considered is enforceable under international law. Hence, nations can be very careful about terming an agreement to be a treaty. For example, within the United States, agreements between states are compacts and agreements between states and the federal government or between agencies of the government are memoranda of understanding."
+
+groundedness_guardian = GuardianCheck(
+    GuardianRisk.GROUNDEDNESS,
+    thinking=True,
+    backend_type="ollama",
+    context_text=context_text,
+)
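+# context_text supplies the reference document for the groundedness check: the guardian
+# judges the assistant's reply against this passage rather than against world knowledge.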

-# create a mean conversation and add to context
-m.ctx = m.ctx.add(Message("user", "Hello. ")).add(
-    ModelOutputThunk("You are very ugly.")
+# Create a response that makes ungrounded claims relative to the provided context
+groundedness_session = MelleaSession(
+    OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()
)
-# show last turn in chat
-print(f"Context: {m.ctx.last_turn()}")
+groundedness_session.ctx = groundedness_session.ctx.add(
+    Message("user", "What is the history of treaty making?")
+).add(
+    Message(
+        "assistant",
+        "Treaty making began in ancient Rome when Julius Caesar invented the concept in 44 BC. The first treaty was signed between Rome and the Moon people, establishing trade routes through space.",
+    )
+)
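+# The fabricated claims above are unsupported by context_text, so the guardian should
+# flag this reply as ungrounded.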
+
+print("Testing response with ungrounded claims...")
+groundedness_valid = groundedness_session.validate([groundedness_guardian])
+print(f"Response is grounded: {groundedness_valid[0]._result}")
+if groundedness_valid[0]._reason:
+    print(f"Groundedness feedback: {groundedness_valid[0]._reason}...")
+
+print("\n=== Test 5: Function Call Hallucination Detection ===")
+# Test function-call hallucination using an IBM video example
+
+tools = [
+    {
+        "name": "views_list",
+        "description": "Fetches total views for a specified IBM video using the given API.",
+        "parameters": {
+            "video_id": {
+                "description": "The ID of the IBM video.",
+                "type": "int",
+                "default": "7178094165614464282",
+            }
+        },
+    }
+]
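+# The tools schema declares which functions are actually available, giving the guardian
+# a ground truth to judge generated tool calls against.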
+
+function_guardian = GuardianCheck(
+    GuardianRisk.FUNCTION_CALL, thinking=True, backend_type="ollama", tools=tools
+)
+
+
+# The user asks for views, but the assistant calls the wrong function
+# (comments_list instead of views_list).
+# Create a proper ModelOutputThunk with tool_calls
+def dummy_func(**kwargs):
+    pass
+
+
+hallucinated_tool_calls = {
+    "comments_list": ModelToolCall(
+        name="comments_list", func=dummy_func, args={"video_id": 456789123, "count": 15}
+    )
+}
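+# "comments_list" does not appear in the tools schema above, so this tool call is a
+# hallucination the guardian should catch.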
+
+hallucinated_output = ModelOutputThunk(
+    value="I'll fetch the views for you.", tool_calls=hallucinated_tool_calls
+)
+
+function_session = MelleaSession(
+    OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()
+)
+function_session.ctx = function_session.ctx.add(
+    Message("user", "Fetch total views for the IBM video with ID 456789123.")
+).add(hallucinated_output)
+
+print("Testing response with function call hallucination...")
+function_valid = function_session.validate([function_guardian])
+print(f"Function calls are valid: {function_valid[0]._result}")
+if function_valid[0]._reason:
+    print(f"Function call feedback: {function_valid[0]._reason}...")

-check_results = m.validate([guardian])
-# Should return [False] meaning that input and output contain harm and it's NOT safe to return result
-print(f"Guardian check results: {check_results}")
+print("\n=== GuardianCheck Demo Complete ===")