Skip to content

Commit 7d796e9

Browse files
committed
add 'other' to invalid topics in every case
1 parent 084cf0c commit 7d796e9

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

guardrails/validators/on_topic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,9 @@ def get_topic_zero_shot(
220220
score = result["scores"][0] # type: ignore
221221
return topic, score # type: ignore
222222

223-
def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
223+
def validate(
224+
self, value: str, metadata: Optional[Dict[str, Any]]
225+
) -> ValidationResult:
224226
valid_topics = set(self._valid_topics)
225227
invalid_topics = set(self._invalid_topics)
226228

@@ -236,7 +238,7 @@ def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
236238

237239
# Add 'other' to the invalid topics list
238240
if "other" not in invalid_topics:
239-
self._invalid_topics.append("other")
241+
invalid_topics.add("other")
240242

241243
# Combine valid and invalid topics
242244
candidate_topics = valid_topics.union(invalid_topics)

tests/integration_tests/validators/test_on_topic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def test_validate_invalid_topic_cpu_disable_llm(self):
2424
valid_topics=["sports", "politics"],
2525
disable_classifier=False,
2626
disable_llm=True,
27-
model_threshold=0.6,
2827
)
2928
text = "This is an article about music."
3029
expected_result = FailResult(error_message="Most relevant topic is other.")

0 commit comments

Comments
 (0)