from __future__ import annotations

import typing as t
from dataclasses import dataclass, field

import numpy as np
from pydantic import BaseModel, Field

from ragas.dataset_schema import SingleTurnSample
from ragas.metrics.base import MetricType, MetricWithLLM, SingleTurnMetric
from ragas.prompt import ImageTextPrompt

if t.TYPE_CHECKING:
    from langchain_core.callbacks import Callbacks


class RelevanceInput(BaseModel):
    """Inputs shown to the judge LLM: the question, the answer, and the retrieved contexts."""

    user_input: str = Field(description="user input")
    response: str = Field(description="response from AI")
    retrieved_contexts: list[str] = Field(
        description="contexts retrieved for the user input"
    )

    def to_string_list(self):
        # Flatten the sample into the list-of-strings format expected by ImageTextPrompt.
        return [
            f"Question: {self.user_input}",
            f"Response: {self.response}",
            "retrieved_contexts: ",
        ] + self.retrieved_contexts


class RelevanceOutput(BaseModel):
    relevance: bool = Field(
        description="boolean indicating whether the response is relevant to the provided context"
    )


class MultiModalRelevancePrompt(ImageTextPrompt[RelevanceInput, RelevanceOutput]):
    # see https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/evaluation/multi_modal/relevancy.py
    instruction = """
Your task is to evaluate whether the response to the query is in line with the images and textual context information provided.
Answer with either True or False.
Answer True if the response to the query is in line with the context information, otherwise answer False.
"""
    input_model = RelevanceInput
    output_model = RelevanceOutput
    examples = [
        (
            RelevanceInput(
                user_input="What is the primary ingredient in a traditional Margherita pizza?",
                response="The primary ingredients in a Margherita pizza are tomatoes, mozzarella cheese, and fresh basil.",
                retrieved_contexts=[
                    "A traditional Margherita pizza consists of a thin crust.",
                    "The main toppings include tomatoes, mozzarella cheese, fresh basil, salt, and olive oil.",
                    "It is one of the simplest and most classic types of pizza.",
                ],
            ),
            RelevanceOutput(relevance=True),
        ),
        (
            RelevanceInput(
                user_input="Who won the Best Actor award at the Oscars in 2021?",
                response="The Best Actor award in 2021 was won by Leonardo DiCaprio.",
                retrieved_contexts=[
                    "The 93rd Academy Awards were held in 2021.",
                    "Anthony Hopkins won the Best Actor award for his role in 'The Father'.",
                    "The event was unique due to COVID-19 restrictions.",
                ],
            ),
            RelevanceOutput(relevance=False),
        ),
    ]


@dataclass
class MultiModalRelevance(MetricWithLLM, SingleTurnMetric):
    """LLM-judged metric: is the response in line with the retrieved (image and text) contexts?"""

    name: str = "relevance_rate"  # type: ignore
    _required_columns: t.Dict[MetricType, t.Set[str]] = field(
        default_factory=lambda: {
            MetricType.SINGLE_TURN: {
                "user_input",
                "response",
                "retrieved_contexts",
            }
        }
    )
    # default_factory avoids sharing one mutable prompt instance across metric objects
    relevance_prompt: ImageTextPrompt = field(
        default_factory=MultiModalRelevancePrompt
    )

    async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
        prompt_input = RelevanceInput(
            user_input=row["user_input"],
            response=row["response"],
            retrieved_contexts=row["retrieved_contexts"],
        )
        assert self.llm is not None, "LLM is not set"
        prompt_response = await self.relevance_prompt.generate(
            data=prompt_input, llm=self.llm, callbacks=callbacks
        )
        if prompt_response is None:
            # The judge produced no parseable output; propagate as NaN.
            return np.nan
        return float(prompt_response.relevance)

    async def _single_turn_ascore(
        self, sample: SingleTurnSample, callbacks: Callbacks
    ) -> float:
        row = sample.to_dict()
        return await self._ascore(row, callbacks)


multimodal_relevance = MultiModalRelevance()
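
# Usage sketch (an illustrative example, not part of the metric itself): score
# one sample with the module-level `multimodal_relevance` instance. The choice
# of evaluator LLM below is an assumption — any ragas-compatible wrapper works;
# here we assume ragas.llms.LangchainLLMWrapper around a multimodal chat model.
if __name__ == "__main__":
    import asyncio

    from langchain_openai import ChatOpenAI
    from ragas.llms import LangchainLLMWrapper

    # Assumed setup: a multimodal-capable chat model wrapped for ragas.
    multimodal_relevance.llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o"))
    sample = SingleTurnSample(
        user_input="Who won the Best Actor award at the Oscars in 2021?",
        response="Anthony Hopkins won the Best Actor award for 'The Father'.",
        retrieved_contexts=[
            "Anthony Hopkins won the Best Actor award for his role in 'The Father'.",
        ],
    )
    score = asyncio.run(multimodal_relevance.single_turn_ascore(sample))
    print(score)  # 1.0 if judged relevant, 0.0 otherwise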