@@ -1,7 +1,6 @@
 # Standard
 from typing import Optional
 import asyncio
-import copy

 # Third Party
 from fastapi import Request
@@ -45,33 +44,21 @@ def __post_process_result(self, responses, scores, detection_type): |
         # this is guided by the prompt definition of the model, so we expect llama_guard to adhere to it
         # at least for Llama-Guard-3 (latest at the time of writing)

-        # NOTE: The concept of "choice" doesn't exist for content type detector API, so
-        # we will essentially flatten out the responses, so different categories in 1 choice
-        # will also look like another choice.
+        # In this function we remove the "safety" categories from the output and later
+        # move them to evidences.

         new_choices = []
         new_scores = []

+        # NOTE: we are flattening out choices here as different categories
         for i, choice in enumerate(responses.choices):
             content = choice.message.content
             if self.UNSAFE_TOKEN in content:
-                # We will create multiple results for each unsafe category
-                # in addition to "unsafe" as a category itself
-                # NOTE: need to deepcopy, otherwise, choice will get overwritten
-                unsafe_choice = copy.deepcopy(choice)
-                unsafe_choice.message.content = self.UNSAFE_TOKEN
-
-                new_choices.append(unsafe_choice)
+                # Reason for reassigning the content:
+                # We want to remove the safety category from the content
+                choice.message.content = self.UNSAFE_TOKEN
+                new_choices.append(choice)
                 new_scores.append(scores[i])
-
-                # Fetch categories as the last line in the response available in csv format
-                for category in content.splitlines()[-1].split(","):
-                    category_choice = copy.deepcopy(choice)
-                    category_choice.message.content = category
-                    new_choices.append(category_choice)
-                    # NOTE: currently using same score as "unsafe"
-                    # but we need to see if we can revisit this to get better score
-                    new_scores.append(scores[i])
             else:
                 # "safe" case
                 new_choices.append(choice)
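
Net effect of this hunk: `__post_process_result` now keeps one verdict per choice, overwriting the model text with the bare "unsafe" token (or passing "safe" choices through) instead of fanning each category out into its own choice; the category line is meant to be surfaced as evidences instead. Below is a minimal standalone sketch of that behavior, using stand-in dataclasses in place of the real chat-completion objects; the `UNSAFE_TOKEN` value, the function signature, and the sample content are illustrative assumptions, not the detector's actual API.

```python
# Minimal sketch of the post-processing above; not the detector's real method,
# which operates on self and an OpenAI-style chat-completion response object.
from dataclasses import dataclass

UNSAFE_TOKEN = "unsafe"  # assumed to mirror the detector's UNSAFE_TOKEN constant


@dataclass
class Message:
    content: str


@dataclass
class Choice:
    message: Message


def post_process(choices: list[Choice], scores: list[float]):
    """Keep one verdict per choice; replace unsafe content with the bare token."""
    new_choices, new_scores = [], []
    for i, choice in enumerate(choices):
        if UNSAFE_TOKEN in choice.message.content:
            # Drop the trailing category line (e.g. "S1,S10"); categories are
            # expected to be surfaced separately as evidences.
            choice.message.content = UNSAFE_TOKEN
        # "safe" choices pass through unchanged
        new_choices.append(choice)
        new_scores.append(scores[i])
    return new_choices, new_scores


choices = [Choice(Message("unsafe\nS1,S10")), Choice(Message("safe"))]
print(post_process(choices, [0.9, 0.1]))
```

Note that in the real method the original text is still available through the local `content` variable captured before the overwrite, which is where the category line would be read from.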
@@ -122,16 +109,16 @@ async def content_analysis( |

         # If there is any error, return that; otherwise, return the whole response
         # properly formatted.
-        categorized_results = []
+        processed_result = []
         for result in results:
             # NOTE: we are only sending 1 of the error results
             # and not every error or a cumulative one
             if isinstance(result, ErrorResponse):
                 return result
             else:
                 # Process results to split out safety categories into separate objects
-                categorized_results.append(self.__post_process_result(*result))
+                processed_result.append(self.__post_process_result(*result))

         return ContentsDetectionResponse.from_chat_completion_response(
-            categorized_results, request.contents
+            processed_result, request.contents
         )
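
For context on this second hunk: `content_analysis` short-circuits on the first `ErrorResponse` and otherwise runs every per-content result through `__post_process_result` before building the final response. A rough sketch of that aggregation pattern, with `ErrorResponse`, `aggregate`, and the result tuples as stand-ins for the real types and method:

```python
# Sketch of the aggregation pattern only; the real code awaits per-content
# detections and returns a ContentsDetectionResponse instead of a plain list.
class ErrorResponse:
    """Stand-in for the detector's error type."""


def aggregate(results, post_process):
    processed_result = []
    for result in results:
        # NOTE: only the first error is returned; errors are not accumulated
        if isinstance(result, ErrorResponse):
            return result
        processed_result.append(post_process(*result))
    return processed_result


# All results succeed, so every one is post-processed:
print(aggregate([(1, 2), (3, 4)], post_process=lambda a, b: a + b))  # [3, 7]
# An error short-circuits the whole request:
print(aggregate([ErrorResponse(), (3, 4)], post_process=lambda a, b: a + b))
```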