diff --git a/docs/examples/safety.py/guardian.py b/docs/examples/safety.py/guardian.py index 7711df6f..bf31e94e 100644 --- a/docs/examples/safety.py/guardian.py +++ b/docs/examples/safety.py/guardian.py @@ -1,35 +1,152 @@ -"""Example of using the Guardian Requirement.""" +"""Example of using the Enhanced Guardian Requirement with Granite Guardian 3.3 8B""" from mellea import MelleaSession from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import ContextTurn, ModelOutputThunk +from mellea.stdlib.base import ContextTurn, ModelOutputThunk, ChatContext from mellea.stdlib.chat import Message from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk -# Start with instantiating a GuardianCheck. -guardian = GuardianCheck(risk=GuardianRisk.HARM) +# Enhanced GuardianCheck with Granite Guardian 3.3 8B support +print("=== Enhanced GuardianCheck Examples ===") + +# Basic harm checking with new features +guardian = GuardianCheck(GuardianRisk.HARM, thinking=True, backend_type="ollama") +print(f"Guardian configured: {guardian._risk} with thinking mode: {guardian._thinking}") + +# Try different risk types with new 3.3 8B model +jailbreak_guardian = GuardianCheck(GuardianRisk.JAILBREAK, thinking=True) +print(f"Jailbreak guardian: {jailbreak_guardian._risk}") + +# Custom risk criteria example +custom_guardian = GuardianCheck( + custom_criteria="Check for inappropriate content in educational context" +) +print(f"Custom guardian: {custom_guardian.get_effective_risk()}") + +print("\n=== Test 1: Normal Content ===") +# Run a query against an Ollama model with ChatContext to support insert_turn +m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) +m.chat("Write a professional email to your colleague. Use less than 50 words.") -print("\n Test 1\n") -# Run a query against an Ollama model -m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B)) -m.chat("Write an email to your colleague. 
Use less than 100 words.") -print(f"Full context: {m.ctx}") # Validate on the last turn of the conversation valid = m.validate([guardian]) -# Should return [True] meaning that input and output are free of harm and it's safe to return result -print(f"Guardian result: {valid}") +print(f"Content is safe: {valid[0]._result}") +if valid[0]._reason: + print(f"Guardian feedback: {valid[0]._reason}...") + +print("\n=== Test 2: Harmful Content ===") +# Create harmful conversation and add to context +m.ctx = m.ctx.add(Message("user", "Hello")).add( + Message("assistant", "You are very ugly and stupid.") +) + +valid = m.validate([guardian]) +print(f"Content is safe: {valid[0]._result}") +if valid[0]._reason: + print(f"Guardian feedback: {valid[0]._reason}...") + +print("\n=== Test 3: Enhanced Features ===") + +# Create multiple guardians with different configurations +guardians = [ + GuardianCheck(GuardianRisk.HARM, thinking=True), + GuardianCheck(GuardianRisk.JAILBREAK, thinking=True), + GuardianCheck(GuardianRisk.SOCIAL_BIAS), + GuardianCheck(custom_criteria="Check for financial advice"), +] + +print(f"Available risk types ({len(GuardianCheck.get_available_risks())} total):") +for risk in GuardianCheck.get_available_risks(): # Show first 5 + print(f" - {risk}") +print(" ...") -# Reproducing a harm scenario and run the validation again -print("\n Test 2\n") +print(f"\nConfigured guardians: {len(guardians)} total") + +# Show Ollama backend configuration +ollama_guardian = GuardianCheck(GuardianRisk.HARM, backend_type="ollama") +print(f" Ollama backend: {ollama_guardian._backend.model_version}") + +print("\n=== Test 4: Groundedness Detection ===") +# Test groundedness - detecting when responses lack factual grounding +context_text = "One significant part of treaty making is that signing a treaty implies recognition that the other side is a sovereign state and that the agreement being considered is enforceable under international law. Hence, nations can be very careful about terming an agreement to be a treaty. For example, within the United States, agreements between states are compacts and agreements between states and the federal government or between agencies of the government are memoranda of understanding." + +groundedness_guardian = GuardianCheck( + GuardianRisk.GROUNDEDNESS, + thinking=True, + backend_type="ollama", + context_text=context_text, +) -# create a mean conversation and add to context -m.ctx = m.ctx.add(Message("user", "Hello. ")).add( - ModelOutputThunk("You are very ugly.") +# Create a response that makes ungrounded claims relative to provided context +groundedness_session = MelleaSession( + OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext() ) -# show last turn in chat -print(f"Context: {m.ctx.last_turn()}") +groundedness_session.ctx = groundedness_session.ctx.add( + Message("user", "What is the history of treaty making?") +).add( + Message( + "assistant", + "Treaty making began in ancient Rome when Julius Caesar invented the concept in 44 BC. 
The first treaty was signed between Rome and the Moon people, establishing trade routes through space.", + ) +) + +print("Testing response with ungrounded claims...") +groundedness_valid = groundedness_session.validate([groundedness_guardian]) +print(f"Response is grounded: {groundedness_valid[0]._result}") +if groundedness_valid[0]._reason: + print(f"Groundedness feedback: {groundedness_valid[0]._reason}...") + +print("\n=== Test 5: Function Call Hallucination Detection ===") +# Test function calling hallucination using IBM video example +from mellea.stdlib.base import ModelOutputThunk, ModelToolCall + +tools = [ + { + "name": "views_list", + "description": "Fetches total views for a specified IBM video using the given API.", + "parameters": { + "video_id": { + "description": "The ID of the IBM video.", + "type": "int", + "default": "7178094165614464282", + } + }, + } +] + +function_guardian = GuardianCheck( + GuardianRisk.FUNCTION_CALL, thinking=True, backend_type="ollama", tools=tools +) + + +# User asks for views but assistant calls wrong function (comments_list instead of views_list) +# Create a proper ModelOutputThunk with tool_calls +def dummy_func(**kwargs): + pass + + +hallucinated_tool_calls = { + "comments_list": ModelToolCall( + name="comments_list", func=dummy_func, args={"video_id": 456789123, "count": 15} + ) +} + +hallucinated_output = ModelOutputThunk( + value="I'll fetch the views for you.", tool_calls=hallucinated_tool_calls +) + +function_session = MelleaSession( + OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext() +) +function_session.ctx = function_session.ctx.add( + Message("user", "Fetch total views for the IBM video with ID 456789123.") +).add(hallucinated_output) + +print("Testing response with function call hallucination...") +function_valid = function_session.validate([function_guardian]) +print(f"Function calls are valid: {function_valid[0]._result}") +if function_valid[0]._reason: + print(f"Function call feedback: {function_valid[0]._reason}...") -check_results = m.validate([guardian]) -# Should return [False] meaning that input and output contain harm and it's NOT safe to return result -print(f"Guardian check results: {check_results}") +print("\n=== GuardianCheck Demo Complete ===") diff --git a/docs/examples/safety.py/guardian_huggingface.py b/docs/examples/safety.py/guardian_huggingface.py new file mode 100644 index 00000000..3cc3d507 --- /dev/null +++ b/docs/examples/safety.py/guardian_huggingface.py @@ -0,0 +1,131 @@ +"""Example of using GuardianCheck with HuggingFace backend for direct model inference + +This example shows how to reuse the Guardian backend across multiple validators +to avoid reloading the model multiple times. 
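+The shared backend is a LocalHFBackend for ibm-granite/granite-guardian-3.3-8b, passed to each GuardianCheck via its backend= argument so the model weights are loaded only once.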
+""" + +from mellea import MelleaSession +from mellea.backends import model_ids +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.huggingface import LocalHFBackend +from mellea.stdlib.base import ChatContext, ModelOutputThunk, ModelToolCall +from mellea.stdlib.chat import Message +from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk + +print("=== GuardianCheck HuggingFace Backend Example ===") + +# Create a single reusable session with Ollama backend for test content generation +# We'll reuse this session and just reset its context for each test +session = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()) + +# Create a single shared HuggingFace backend for Guardian (loads model once) +print("Loading Granite Guardian model (this happens only once)...") +shared_guardian_backend = LocalHFBackend(model_id="ibm-granite/granite-guardian-3.3-8b") +print(f"Loaded backend: {shared_guardian_backend.model_id}\n") + +# Create Guardian validators that share the backend (no model reloading!) +print("Creating harm guardian...") +harm_guardian = GuardianCheck( + GuardianRisk.HARM, thinking=True, backend=shared_guardian_backend +) + +# Test 1: Harmful content detection +print("\n=== Test 1: Harmful Content Detection ===") +print("Testing content: 'You are very ugly and stupid.'") + +# Reset context for this test +session.ctx = ( + ChatContext() + .add(Message("user", "Hello")) + .add(Message("assistant", "You are very ugly and stupid.")) +) + +validation_result = session.validate([harm_guardian]) +print(f"Content is safe: {validation_result[0]._result}") +print(f"Guardian detected harm: {not validation_result[0]._result}") + +if validation_result[0]._reason: + print(f"\nGuardian feedback:") + print(validation_result[0]._reason[:200] + "...") + +# Test 2: Groundedness detection +print("\n=== Test 2: Groundedness Detection ===") +context_text = ( + "Python is a high-level programming language created by Guido van Rossum in 1991." 
+) + +# Create groundedness guardian with context (reuse shared backend) +print("Creating groundedness guardian...") +groundedness_guardian = GuardianCheck( + GuardianRisk.GROUNDEDNESS, + thinking=False, + context_text=context_text, + backend=shared_guardian_backend, +) + +# Reset context with ungrounded response +session.ctx = ( + ChatContext() + .add(Message("user", "Who created Python?")) + .add( + Message( + "assistant", + "Python was created by Dennis Ritchie in 1972 for use in Unix systems.", + ) + ) +) + +groundedness_valid = session.validate([groundedness_guardian]) +print(f"Response is grounded: {groundedness_valid[0]._result}") +if groundedness_valid[0]._reason: + print(f"Groundedness feedback: {groundedness_valid[0]._reason[:200]}...") + +# Test 3: Function call validation +print("\n=== Test 3: Function Call Validation ===") + +tools = [ + { + "name": "get_weather", + "description": "Gets weather for a location", + "parameters": {"location": {"description": "City name", "type": "string"}}, + } +] + +# Create function call guardian (reuse shared backend) +print("Creating function call guardian...") +function_guardian = GuardianCheck( + GuardianRisk.FUNCTION_CALL, + thinking=False, + tools=tools, + backend=shared_guardian_backend, +) + + +# User asks for weather but model calls wrong function +def dummy_func(**kwargs): + pass + + +hallucinated_tool_calls = { + "get_stock_price": ModelToolCall( + name="get_stock_price", func=dummy_func, args={"symbol": "AAPL"} + ) +} + +hallucinated_output = ModelOutputThunk( + value="Let me get the weather for you.", tool_calls=hallucinated_tool_calls +) + +# Reset context with hallucinated function call +session.ctx = ( + ChatContext() + .add(Message("user", "What's the weather in Boston?")) + .add(hallucinated_output) +) + +function_valid = session.validate([function_guardian]) +print(f"Function calls are valid: {function_valid[0]._result}") +if function_valid[0]._reason: + print(f"Function call feedback: {function_valid[0]._reason[:200]}...") + +print("\n=== HuggingFace Guardian Demo Complete ===") diff --git a/docs/examples/safety.py/repair_with_guardian.py b/docs/examples/safety.py/repair_with_guardian.py new file mode 100644 index 00000000..c2c1d20a --- /dev/null +++ b/docs/examples/safety.py/repair_with_guardian.py @@ -0,0 +1,107 @@ +""" +RepairTemplateStrategy Example with Actual Function Call Validation +Demonstrates how RepairTemplateStrategy repairs responses using actual function calls. +""" + +from mellea import MelleaSession +from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk +from mellea.stdlib.sampling import RepairTemplateStrategy + + +def demo_repair_with_actual_function_calling(): + """Demonstrate RepairTemplateStrategy with actual function calling and Guardian validation. + + Note: This demo uses an intentionally misconfigured system prompt to force an initial error, + demonstrating how Guardian provides detailed repair feedback that helps the model correct itself. + """ + print("=== Guardian Repair Demo ===\n") + + # Use Llama3.2 which supports function calling + m = MelleaSession(OllamaModelBackend("llama3.2")) + + # Simple function for stock price + def get_stock_price(symbol: str) -> str: + """Gets current stock price for a given symbol. 
Symbol must be a valid stock ticker (3-5 uppercase letters).""" + return f"Stock price for {symbol}: $150.25" + + # Tool schema - Guardian validates against this + tool_schemas = [ + { + "name": "get_stock_price", + "description": "Gets current stock price for a given symbol. Symbol must be a valid stock ticker (3-5 uppercase letters).", + "parameters": { + "symbol": { + "description": "The stock symbol to get price for (must be 3-5 uppercase letters like TSLA, AAPL)", + "type": "string", + } + }, + } + ] + + # Guardian validates function calls against tool schema + guardian = GuardianCheck( + GuardianRisk.FUNCTION_CALL, thinking=True, tools=tool_schemas + ) + + test_prompt = "What's the price of Tesla stock?" + print(f"Prompt: {test_prompt}\n") + + result = m.instruct( + test_prompt, + requirements=[guardian], + strategy=RepairTemplateStrategy(loop_budget=3), + return_sampling_results=True, + model_options={ + "temperature": 0.7, + "seed": 789, + "tools": [get_stock_price], + # Intentionally misconfigured to demonstrate repair + "system": "When users ask about stock prices, use the full company name as the symbol parameter. For example, use 'Tesla Motors' instead of 'TSLA'.", + }, + tool_calls=True, + ) + + # Show repair process + for attempt_num, (generation, validations) in enumerate( + zip(result.sample_generations, result.sample_validations), 1 + ): + print(f"\nAttempt {attempt_num}:") + + # Show what was sent to the model + if ( + hasattr(result, "sample_actions") + and result.sample_actions + and attempt_num <= len(result.sample_actions) + ): + action = result.sample_actions[attempt_num - 1] + if hasattr(m.backend, "formatter"): + try: + rendered = m.backend.formatter.print(action) + print(f" Instruction sent to model:") + print(f" ---") + print(f" {rendered}") + print(f" ---") + except Exception: + pass + + # Show function calls made + if hasattr(generation, "tool_calls") and generation.tool_calls: + for name, tool_call in generation.tool_calls.items(): + print(f" Function: {name}({tool_call.args})") + + # Show validation results + for req_item, validation in validations: + status = "PASS" if validation.as_bool() else "FAIL" + print(f" Status: {status}") + + print(f"\n{'=' * 60}") + print( + f"Result: {'SUCCESS' if result.success else 'FAILED'} after {len(result.sample_generations)} attempt(s)" + ) + print(f"{'=' * 60}") + return result + + +if __name__ == "__main__": + demo_repair_with_actual_function_calling() diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 598e837f..f09b4a04 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -380,12 +380,16 @@ def _generate_from_context_standard( # Create a separate thread to handle the processing. Make it awaitable # for non-streaming cases and to get the final output. 
# Details: https://huggingface.co/docs/transformers/en/internal/generation_utils#transformers.AsyncTextIteratorStreamer + + # Filter out chat template-only options before passing to generate() + generate_options = self._filter_chat_template_only_options(model_options) + chat_response = asyncio.to_thread( self._model.generate, # type: ignore input_ids, return_dict_in_generate=True, output_scores=True, - **self._make_backend_specific_and_remove(model_options), + **self._make_backend_specific_and_remove(generate_options), **streaming_kwargs, # type: ignore **format_kwargs, # type: ignore ) @@ -672,6 +676,26 @@ def _make_backend_specific_and_remove( ) return ModelOption.remove_special_keys(backend_specific) + def _filter_chat_template_only_options( + self, model_options: dict[str, Any] + ) -> dict[str, Any]: + """Remove options that are only for apply_chat_template, not for generate(). + + Args: + model_options: the model_options for this call + + Returns: + a new dict without chat template-specific options + """ + # Options that should only go to apply_chat_template, not generate() + chat_template_only = { + "guardian_config", + "think", + "add_generation_prompt", + "documents", + } + return {k: v for k, v in model_options.items() if k not in chat_template_only} + def _extract_model_tool_requests( self, tools: dict[str, Callable], decoded_result: str ) -> dict[str, ModelToolCall] | None: diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index 3a6266b5..b2d4639f 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -1,156 +1,341 @@ -"""Risk checking with Guardian models.""" +"""Risk checking with Granite Guardian models via existing backends.""" -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from enum import Enum +from typing import Literal +from mellea.backends import Backend, BaseModelSubclass from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, Context +from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import Requirement +from mellea.stdlib.instruction import Instruction +from mellea.stdlib.requirement import Requirement, ValidationResult -class GuardianRisk: - """Risk definitions as defined in https://github.com/ibm-granite/granite-guardian/blob/main/cookbooks/granite-guardian-3.2/quick_start_vllm.ipynb .""" +class GuardianRisk(Enum): + """Risk definitions for Granite Guardian models. + + Based on https://github.com/ibm-granite/granite-guardian but updated for 3.3 8B support. + """ HARM = "harm" GROUNDEDNESS = "groundedness" PROFANITY = "profanity" ANSWER_RELEVANCE = "answer_relevance" + JAILBREAK = "jailbreak" + FUNCTION_CALL = "function_call" + SOCIAL_BIAS = "social_bias" + VIOLENCE = "violence" + SEXUAL_CONTENT = "sexual_content" + UNETHICAL_BEHAVIOR = "unethical_behavior" + + @classmethod + def get_available_risks(cls) -> list[str]: + """Get list of all available risk types.""" + return [risk.value for risk in cls] + + +BackendType = Literal["huggingface", "ollama"] + + +def _parse_safety_result(result: str | None, logger) -> str: + """Parse the model output to a Guardian label: Yes/No/Failed. + + Guardian returns yes/no between and tags. + Handles case variations (Yes/yes, No/no) and whitespace. 
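+    With thinking enabled, the raw output typically looks like "<think>...</think> <score> yes </score>"; only the score section is parsed here.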
+ """ + if not result: + logger.warning("Guardian returned empty result") + return "Failed" + + s = str(result).lower() + + # Extract from yes/no tags + if "" in s and "" in s: + score = s.split("")[1].split("")[0].strip() + if score == "yes": + return "Yes" + if score == "no": + return "No" + + logger.warning(f"Could not parse safety result: {result}") + return "Failed" class GuardianCheck(Requirement): - """A Check for Risks based on local huggingface backend.""" + """Enhanced risk checking using Granite Guardian 3.3 8B with multiple backend support.""" def __init__( self, - risk: str = GuardianRisk.HARM, + risk: str | GuardianRisk | None = None, *, - model_version: str = "ibm-granite/granite-guardian-3.2-3b-a800m", + backend_type: BackendType = "ollama", + model_version: str | None = None, device: str | None = None, + ollama_url: str = "http://localhost:11434", + thinking: bool = False, + custom_criteria: str | None = None, + context_text: str | None = None, + tools: list[dict] | None = None, + backend: Backend | None = None, ): - """Initializes a GuardianModel instance with the specified risk definition, model version, and device. + """Initialize GuardianCheck using existing backends with minimal glue. Args: - risk: The risk definition to check for, defaults to GuardianRisk.HARM. - model_version: The version of the model, defaults to "ibm-granite/granite-guardian-3.2-3b-a800m". - device: The computational device to use ("cuda" for GPU, "mps" for Apple Silicon, or "cpu"), defaults to None. If not specified, the best available device will be automatically selected. + risk: The type of risk to check for (harm, jailbreak, etc.) + backend_type: Type of backend to use ("ollama" or "huggingface") + model_version: Specific model version to use + device: Device for model inference (for HuggingFace) + ollama_url: URL for Ollama server + thinking: Enable thinking/reasoning mode + custom_criteria: Custom criteria for validation + context_text: Context document for groundedness checks + tools: Tool schemas for function call validation + backend: Pre-initialized backend to reuse (avoids loading model multiple times) """ - super().__init__( - check_only=True, validation_fn=lambda c: self._guardian_validate(c) - ) - self._risk = risk - self._model_version = model_version - - # auto-device if not more specific - self._device = device - if device is None: - device_name: str = ( - "cuda" - if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" - ) - assert device_name is not None - self._device = torch.device(device_name) # type: ignore + super().__init__(check_only=True) - @staticmethod - def _parse_output(output, input_len, tokenizer): - """Parse the output of a guardian model and determine whether if the risk is present or not. + # Handle risk specification with custom criteria priority + if custom_criteria: + # When custom_criteria is provided, risk becomes optional + if risk is None: + self._risk = "custom" # Default fallback risk identifier + elif isinstance(risk, GuardianRisk): + self._risk = risk.value + else: + self._risk = risk + else: + # When no custom_criteria, risk is required + if risk is None: + raise ValueError("Either 'risk' or 'custom_criteria' must be provided") + if isinstance(risk, GuardianRisk): + self._risk = risk.value + else: + self._risk = risk - Args: - output: The model's output containing sequences from which predictions are made. - input_len: The length of the original input sequence used for alignment with the model's output. 
- tokenizer: The tokenizer associated with the model, used to decode the tokens back into text. - """ - safe_token = "No" - unsafe_token = "Yes" + self._custom_criteria = custom_criteria + self._thinking = thinking + self._context_text = context_text + self._tools = tools - label = None + # Use provided backend or create a new one + if backend is not None: + self._backend = backend + # Infer backend_type from the provided backend + from mellea.backends.huggingface import LocalHFBackend + from mellea.backends.ollama import OllamaModelBackend - full_res = tokenizer.decode( - output.sequences[:, input_len + 1 :][0], skip_special_tokens=True - ).strip() - FancyLogger.get_logger().debug(f"Full: {full_res}") - confidence_level = ( - full_res.removeprefix("").removesuffix("").strip() - ) - res = tokenizer.decode( - output.sequences[:, input_len : input_len + 1][0], skip_special_tokens=True - ).strip() - FancyLogger.get_logger().debug(f"Res: {res}") - if unsafe_token.lower() == res.lower(): - label = unsafe_token - elif safe_token.lower() == res.lower(): - label = safe_token + if isinstance(backend, LocalHFBackend): + self._backend_type = "huggingface" + elif isinstance(backend, OllamaModelBackend): + self._backend_type = "ollama" + else: + # Keep the provided backend_type as fallback + self._backend_type = backend_type else: - label = "Failed" + self._backend_type = backend_type + # Choose defaults and initialize the chosen backend directly. + if model_version is None: + model_version = ( + "ibm-granite/granite-guardian-3.3-8b" + if backend_type == "huggingface" + else "ibm/granite3.3-guardian:8b" + ) - return label, confidence_level + if backend_type == "huggingface": + from mellea.backends.huggingface import LocalHFBackend - def _guardian_validate(self, ctx: Context): - """Validates the last turn of a conversation context using wrt given risk. + self._backend = LocalHFBackend(model_id=model_version) + elif backend_type == "ollama": + from mellea.backends.ollama import OllamaModelBackend - Code is adopted from https://huggingface.co/ibm-granite/granite-guardian-3.2-3b-a800m#quickstart-example + self._backend = OllamaModelBackend( + model_id=model_version, base_url=ollama_url + ) + else: + raise ValueError(f"Unsupported backend type: {backend_type}") - Args: - ctx (LegacyContext): The context object containing the last turn of the conversation. + # Provide a predictable attribute for the example to print. + try: + setattr(self._backend, "model_version", model_version) + except Exception: + pass - Returns: - bool: True if there is no identified risk, False otherwise. - """ - messages: list[dict[str, str]] = [] + self._logger = FancyLogger.get_logger() + + def get_effective_risk(self) -> str: + """Get the effective risk criteria to use for validation.""" + return self._custom_criteria if self._custom_criteria else self._risk - last_turn = ctx.last_turn() - assert last_turn is not None + @classmethod + def get_available_risks(cls) -> list[str]: + """Get list of all available standard risk types.""" + return GuardianRisk.get_available_risks() - # This requirement can handle incomplete turns with only a user message - # or only an assistant message. Handle both. - if last_turn.model_input: - user_msg = last_turn.model_input + def __deepcopy__(self, memo): + """Custom deepcopy to handle unpicklable backend objects.""" + from copy import deepcopy - # Handle the variety of possible user input. 
- if isinstance(user_msg, CBlock) and user_msg.value is not None: - messages.append({"role": "user", "content": user_msg.value}) - elif isinstance(user_msg, Message) and user_msg.content != "": - messages.append({"role": user_msg.role, "content": user_msg.content}) + # Create a new instance without calling __init__ + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + # Copy all attributes except the backend (which contains locks) + for k, v in self.__dict__.items(): + if k == "_backend": + # Share the backend reference instead of copying it + setattr(result, k, v) + elif k == "_logger": + # Share the logger reference + setattr(result, k, v) else: - messages.append({"role": "user", "content": str(user_msg)}) + setattr(result, k, deepcopy(v, memo)) + return result - if last_turn.output and last_turn.output.value: - messages.append({"role": "assistant", "content": last_turn.output.value}) + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + ) -> ValidationResult: + """Validate conversation using Granite Guardian via selected backend.""" + logger = self._logger - # Load model - model = AutoModelForCausalLM.from_pretrained( - self._model_version, device_map="auto", torch_dtype=torch.bfloat16 - ) - model.to(self._device) # type: ignore - model.eval() - - tokenizer = AutoTokenizer.from_pretrained(self._model_version) - - # Please note that the default risk definition is of `harm`. If a config is not specified, this behavior will be applied. - guardian_config = {"risk_name": self._risk} - - input_ids = tokenizer.apply_chat_template( - messages, - guardian_config=guardian_config, - add_generation_prompt=True, - return_tensors="pt", - ).to(model.device) - - input_len = input_ids.shape[1] - - with torch.no_grad(): - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=20, - return_dict_in_generate=True, - output_scores=True, + # Build a fresh chat context for the guardian model (keep it minimal). + gctx = ChatContext() + + effective_risk = self.get_effective_risk() + + # For groundedness: add doc only for Ollama; HF receives context via guardian_config + if ( + (self._risk == "groundedness" or effective_risk == "groundedness") + and self._context_text + and self._backend_type == "ollama" + ): + gctx = gctx.add(Message("user", f"Document: {self._context_text}")) + + # Try to reuse chat history directly when available. + messages = None + try: + from mellea.stdlib.chat import as_chat_history + + messages = as_chat_history(ctx) + except Exception: + messages = None + + if messages: + for m in messages: + gctx = gctx.add(m) + else: + # Fallback: build from the last turn only + last_turn = ctx.last_turn() + if last_turn is None: + logger.warning("No last turn found in context") + return ValidationResult(False, reason="No content to validate") + + if last_turn.model_input is not None: + gctx = gctx.add(last_turn.model_input) + + if last_turn.output is not None: + # For function call risk, append tool call info as text; otherwise add thunk directly. 
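+                # Each tool call is rendered as "name(args)" and appended in an "[Tool calls: ...]" suffix so the guardian can judge it as plain text.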
+ if self._risk == "function_call" or effective_risk == "function_call": + content = last_turn.output.value or "" + tcalls = getattr(last_turn.output, "tool_calls", None) + if tcalls: + calls = [ + f"{name}({getattr(tc, 'args', {})})" + for name, tc in tcalls.items() + ] + if calls: + suffix = f" [Tool calls: {', '.join(calls)}]" + content = (content + suffix) if content else suffix + if content: + gctx = gctx.add(Message("assistant", content)) + else: + gctx = gctx.add(last_turn.output) + + # Ensure we have something to validate. + history = gctx.view_for_generation() or [] + if len(history) == 0: + logger.warning("No messages found to validate") + return ValidationResult(False, reason="No messages to validate") + + # Backend options (mapped by backends internally to their specific keys). + guardian_options: dict[str, object] = {} + if self._backend_type == "ollama": + # Ollama templates expect the risk as the system prompt + guardian_options["system"] = effective_risk + guardian_options.update( + { + "temperature": 0.0, + "num_predict": 4000 if self._thinking else 50, + "stream": False, + "think": True if self._thinking else None, + } ) + else: # huggingface + # HF chat template for Guardian expects guardian_config and (optionally) documents + guardian_cfg: dict[str, object] = {"criteria_id": effective_risk} + if self._custom_criteria: + # When using custom criteria, provide it as free-text criteria + guardian_cfg["criteria_text"] = self._custom_criteria - label, confidence = GuardianCheck._parse_output(output, input_len, tokenizer) + guardian_options.update( + { + "guardian_config": guardian_cfg, + "think": self._thinking, # Passed to apply_chat_template + "add_generation_prompt": True, # Guardian template requires a generation prompt + "max_new_tokens": 4000 if self._thinking else 50, + "stream": False, + } + ) + + # Provide documents parameter for groundedness + if self._context_text and ( + self._risk == "groundedness" or effective_risk == "groundedness" + ): + guardian_options["documents"] = [ + {"doc_id": "0", "text": self._context_text} + ] + + # Attach tools for function_call checks. + # Guardian only needs tool schemas for validation, not actual callable functions. + if ( + self._risk == "function_call" or effective_risk == "function_call" + ) and self._tools: + guardian_options["tools"] = self._tools - # valid if there is NO risk - return label == "No" + # Generate the guardian decision. + # For Ollama: add blank assistant turn to trigger generation + # For HuggingFace: use CBlock (won't be added to conversation, add_generation_prompt handles the judge role) + if self._backend_type == "ollama": + action = Message("assistant", "") + else: + # Use a CBlock for HuggingFace - it won't be added as a message + action = CBlock("") # type: ignore + + mot, _ = self._backend.generate_from_context( + action, gctx, model_options=guardian_options + ) + await mot.avalue() + + # Prefer explicit thinking if available, else try to split from output text. 
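+        # Granite Guardian 3.3 emits its reasoning inside think tags before the final score; strip that prefix so only the score text reaches the parser.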
+ trace = getattr(mot, "_thinking", None) + text = mot.value or "" + if trace is None and "" in text: + parts = text.split("") + if len(parts) > 1: + trace = parts[0].replace("", "").strip() + text = parts[1].strip() + + label = _parse_safety_result(text, logger) + is_safe = label == "No" + + reason_parts = [f"Guardian check for '{effective_risk}': {label}"] + if trace: + reason_parts.append(f"Reasoning: {trace}") + + return ValidationResult( + result=is_safe, reason="; ".join(reason_parts), thunk=mot + ) diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 5ca10a23..9f374fa9 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -327,15 +327,25 @@ def repair( """ pa = past_actions[-1] if isinstance(pa, Instruction): - last_failed_reqs: list[Requirement] = [ - s[0] for s in past_val[-1] if not s[1] + # Get failed requirements and their detailed validation reasons + failed_items = [ + (req, val) for req, val in past_val[-1] if not val.as_bool() ] - last_failed_reqs_str = "* " + "\n* ".join( - [str(r.description) for r in last_failed_reqs] + + # Build repair feedback using ValidationResult.reason when available + repair_lines = [] + for req, validation in failed_items: + if validation.reason: + repair_lines.append(f"* {validation.reason}") + else: + # Fallback to requirement description if no reason + repair_lines.append(f"* {req.description}") + + repair_string = "The following requirements failed before:\n" + "\n".join( + repair_lines ) - return pa.copy_and_repair( - repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}" - ), old_ctx + + return pa.copy_and_repair(repair_string=repair_string), old_ctx return pa, old_ctx diff --git a/test/backends/test_huggingface_tools.py b/test/backends/test_huggingface_tools.py index cc88ec0a..f78898ec 100644 --- a/test/backends/test_huggingface_tools.py +++ b/test/backends/test_huggingface_tools.py @@ -2,6 +2,7 @@ import pytest from typing_extensions import Annotated +import mellea.backends.model_ids as model_ids from mellea import MelleaSession from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras from mellea.backends.cache import SimpleLRUCache @@ -16,7 +17,6 @@ ValidationResult, default_output_to_bool, ) -import mellea.backends.model_ids as model_ids @pytest.fixture(scope="module")
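Putting the new pieces together end to end — a minimal sketch, assuming a local Ollama server with llama3.2 available and the ibm/granite3.3-guardian:8b model pulled; the prompt string is illustrative only:

from mellea import MelleaSession
from mellea.backends.ollama import OllamaModelBackend
from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk
from mellea.stdlib.sampling import RepairTemplateStrategy

# Granite Guardian acts as a requirement; when it fails, RepairTemplateStrategy
# feeds the ValidationResult.reason back into the retry prompt (see the sampling/base.py change above).
m = MelleaSession(OllamaModelBackend("llama3.2"))
guardian = GuardianCheck(GuardianRisk.HARM, thinking=True, backend_type="ollama")

result = m.instruct(
    "Reply politely to a frustrated customer in two sentences.",
    requirements=[guardian],
    strategy=RepairTemplateStrategy(loop_budget=3),
    return_sampling_results=True,
)
print(f"Succeeded: {result.success} after {len(result.sample_generations)} attempt(s)")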