@@ -103,67 +103,59 @@ def get_config_blocks(cls) -> dict:
103
103
}
104
104
105
105
@staticmethod
106
- def _convert_last_user_message_to_guarded_text (
106
+ def _convert_consecutive_user_messages_to_guarded_text (
107
107
messages : List [AllMessageValues ], optional_params : dict
108
108
) -> List [AllMessageValues ]:
109
109
"""
110
- Convert the last user message to guarded_text type if guardrailConfig is present
111
- and no guarded_text is already present in the last user message .
110
+ Convert consecutive user messages at the end to guarded_text type if guardrailConfig is present
111
+ and no guarded_text is already present in those messages .
112
112
"""
113
113
# Check if guardrailConfig is present
114
114
if "guardrailConfig" not in optional_params :
115
115
return messages
116
116
117
- # Find the last user message
118
- last_user_message = None
119
- last_user_message_index = - 1
117
+ # Find all consecutive user messages at the end
118
+ consecutive_user_message_indices = []
120
119
for i in range (len (messages ) - 1 , - 1 , - 1 ):
121
120
if messages [i ].get ("role" ) == "user" :
122
- last_user_message = messages [ i ]
123
- last_user_message_index = i
121
+ consecutive_user_message_indices . append ( i )
122
+ else :
124
123
break
125
124
126
- if last_user_message is None :
125
+ if not consecutive_user_message_indices :
127
126
return messages
128
127
129
- # Check if the last user message already has guarded_text
130
- content = last_user_message .get ("content" , [])
131
- if isinstance (content , list ):
132
- has_guarded_text = any (
133
- isinstance (item , dict ) and item .get ("type" ) == "guarded_text"
134
- for item in content
135
- )
136
- if has_guarded_text :
137
- return messages
138
-
139
- # Convert text elements to guarded_text
140
- new_content = []
141
- for item in content :
142
- if isinstance (item , dict ) and item .get ("type" ) == "text" :
143
- new_item = {
144
- "type" : "guarded_text" ,
145
- "text" : item ["text" ]
146
- }
147
- new_content .append (new_item )
148
- else :
149
- new_content .append (item )
150
-
151
- # Create a copy of messages and update the last user message
152
- messages_copy = copy .deepcopy (messages )
153
- messages_copy [last_user_message_index ]["content" ] = new_content
154
- return messages_copy
155
- elif isinstance (content , str ):
156
- # If content is a string, convert it to guarded_text
157
- messages_copy = copy .deepcopy (messages )
158
- messages_copy [last_user_message_index ]["content" ] = [
159
- {
160
- "type" : "guarded_text" ,
161
- "text" : content
162
- }
163
- ]
164
- return messages_copy
165
-
166
- return messages
128
+ # Process each consecutive user message
129
+ messages_copy = copy .deepcopy (messages )
130
+ for user_message_index in consecutive_user_message_indices :
131
+ user_message = messages_copy [user_message_index ]
132
+ content = user_message .get ("content" , [])
133
+
134
+ if isinstance (content , list ):
135
+ has_guarded_text = any (
136
+ isinstance (item , dict ) and item .get ("type" ) == "guarded_text"
137
+ for item in content
138
+ )
139
+ if has_guarded_text :
140
+ continue # Skip this message if it already has guarded_text
141
+
142
+ # Convert text elements to guarded_text
143
+ new_content = []
144
+ for item in content :
145
+ if isinstance (item , dict ) and item .get ("type" ) == "text" :
146
+ new_item = {"type" : "guarded_text" , "text" : item ["text" ]} # type: ignore
147
+ new_content .append (new_item )
148
+ else :
149
+ new_content .append (item )
150
+
151
+ messages_copy [user_message_index ]["content" ] = new_content # type: ignore
152
+ elif isinstance (content , str ):
153
+ # If content is a string, convert it to guarded_text
154
+ messages_copy [user_message_index ]["content" ] = [ # type: ignore
155
+ {"type" : "guarded_text" , "text" : content } # type: ignore
156
+ ]
157
+
158
+ return messages_copy
167
159
168
160
@classmethod
169
161
def get_config (cls ):
@@ -832,9 +824,11 @@ async def _async_transform_request(
832
824
headers : Optional [dict ] = None ,
833
825
) -> RequestObject :
834
826
messages , system_content_blocks = self ._transform_system_message (messages )
835
-
827
+
836
828
# Convert last user message to guarded_text if guardrailConfig is present
837
- messages = self ._convert_last_user_message_to_guarded_text (messages , optional_params )
829
+ messages = self ._convert_consecutive_user_messages_to_guarded_text (
830
+ messages , optional_params
831
+ )
838
832
## TRANSFORMATION ##
839
833
840
834
_data : CommonRequestObject = self ._transform_request_helper (
@@ -888,7 +882,9 @@ def _transform_request(
888
882
messages , system_content_blocks = self ._transform_system_message (messages )
889
883
890
884
# Convert last user message to guarded_text if guardrailConfig is present
891
- messages = self ._convert_last_user_message_to_guarded_text (messages , optional_params )
885
+ messages = self ._convert_consecutive_user_messages_to_guarded_text (
886
+ messages , optional_params
887
+ )
892
888
893
889
_data : CommonRequestObject = self ._transform_request_helper (
894
890
model = model ,
@@ -1346,4 +1342,4 @@ def should_fake_stream(
1346
1342
###################################################################
1347
1343
if "ai21" in model :
1348
1344
return True
1349
- return False
1345
+ return False
0 commit comments