@@ -48,7 +48,9 @@ class TopicClassificationOutput(BaseModel):
4848class TopicClassificationPrompt (
4949 PydanticPrompt [TopicClassificationInput , TopicClassificationOutput ]
5050):
51- instruction = "Given a set of topics classify if the topic falls into any of the given reference topics."
51+ instruction = (
52+ "Given a set of topics classify if the topic falls into any of the given reference topics."
53+ )
5254 input_model = TopicClassificationInput
5355 output_model = TopicClassificationOutput
5456 examples = [
@@ -66,7 +68,9 @@ class TopicClassificationPrompt(
6668
6769
6870class TopicRefusedPrompt (PydanticPrompt [TopicRefusedInput , TopicRefusedOutput ]):
69- instruction : str = "Given a topic, classify if the AI refused to answer the question about the topic."
71+ instruction : str = (
72+ "Given a topic, classify if the AI refused to answer the question about the topic."
73+ )
7074 input_model = TopicRefusedInput
7175 output_model = TopicRefusedOutput
7276 examples = [
@@ -93,10 +97,10 @@ class TopicRefusedPrompt(PydanticPrompt[TopicRefusedInput, TopicRefusedOutput]):
9397 ]
9498
9599
96- class TopicExtractionPrompt (
97- PydanticPrompt [ TopicExtractionInput , TopicExtractionOutput ]
98- ):
99- instruction : str = "Given an interaction between Human, Tool and AI, extract the topics from Human's input."
100+ class TopicExtractionPrompt (PydanticPrompt [ TopicExtractionInput , TopicExtractionOutput ]):
101+ instruction : str = (
102+ "Given an interaction between Human, Tool and AI, extract the topics from Human's input."
103+ )
100104 input_model = TopicExtractionInput
101105 output_model = TopicExtractionOutput
102106 examples = [
@@ -143,14 +147,10 @@ class TopicAdherenceScore(MetricWithLLM, MultiTurnMetric):
143147 topic_classification_prompt : PydanticPrompt = TopicClassificationPrompt ()
144148 topic_refused_prompt : PydanticPrompt = TopicRefusedPrompt ()
145149
146- async def _multi_turn_ascore (
147- self , sample : MultiTurnSample , callbacks : Callbacks
148- ) -> float :
150+ async def _multi_turn_ascore (self , sample : MultiTurnSample , callbacks : Callbacks ) -> float :
149151 assert self .llm is not None , "LLM must be set"
150152 assert isinstance (sample .user_input , list ), "Sample user_input must be a list"
151- assert isinstance (
152- sample .reference_topics , list
153- ), "Sample reference_topics must be a list"
153+ assert isinstance (sample .reference_topics , list ), "Sample reference_topics must be a list"
154154 user_input = sample .pretty_repr ()
155155
156156 prompt_input = TopicExtractionInput (user_input = user_input )
@@ -166,9 +166,7 @@ async def _multi_turn_ascore(
166166 data = prompt_input , llm = self .llm , callbacks = callbacks
167167 )
168168 topic_answered_verdict .append (response .refused_to_answer )
169- topic_answered_verdict = np .array (
170- [not answer for answer in topic_answered_verdict ]
171- )
169+ topic_answered_verdict = np .array ([not answer for answer in topic_answered_verdict ])
172170
173171 prompt_input = TopicClassificationInput (
174172 reference_topics = sample .reference_topics , topics = topics
@@ -183,13 +181,13 @@ async def _multi_turn_ascore(
183181 false_negatives = sum (~ topic_answered_verdict & topic_classifications )
184182
185183 if self .mode == "precision" :
186- return true_positives / (true_positives + false_positives )
184+ return true_positives / (true_positives + false_positives + 1e-10 )
187185 elif self .mode == "recall" :
188- return true_positives / (true_positives + false_negatives )
186+ return true_positives / (true_positives + false_negatives + 1e-10 )
189187 else :
190- precision = true_positives / (true_positives + false_positives )
191- recall = true_positives / (true_positives + false_negatives )
192- return 2 * (precision * recall ) / (precision + recall )
188+ precision = true_positives / (true_positives + false_positives + 1e-10 )
189+ recall = true_positives / (true_positives + false_negatives + 1e-10 )
190+ return 2 * (precision * recall ) / (precision + recall + 1e-10 )
193191
194192 async def _ascore (self , row : t .Dict , callbacks : Callbacks ) -> float :
195193 return await self ._multi_turn_ascore (MultiTurnSample (** row ), callbacks )
0 commit comments