@@ -102,6 +102,61 @@ def get_config_blocks(cls) -> dict:
102
102
"performanceConfig" : PerformanceConfigBlock ,
103
103
}
104
104
105
+ @staticmethod
106
+ def _convert_consecutive_user_messages_to_guarded_text (
107
+ messages : List [AllMessageValues ], optional_params : dict
108
+ ) -> List [AllMessageValues ]:
109
+ """
110
+ Convert consecutive user messages at the end to guarded_text type if guardrailConfig is present
111
+ and no guarded_text is already present in those messages.
112
+ """
113
+ # Check if guardrailConfig is present
114
+ if "guardrailConfig" not in optional_params :
115
+ return messages
116
+
117
+ # Find all consecutive user messages at the end
118
+ consecutive_user_message_indices = []
119
+ for i in range (len (messages ) - 1 , - 1 , - 1 ):
120
+ if messages [i ].get ("role" ) == "user" :
121
+ consecutive_user_message_indices .append (i )
122
+ else :
123
+ break
124
+
125
+ if not consecutive_user_message_indices :
126
+ return messages
127
+
128
+ # Process each consecutive user message
129
+ messages_copy = copy .deepcopy (messages )
130
+ for user_message_index in consecutive_user_message_indices :
131
+ user_message = messages_copy [user_message_index ]
132
+ content = user_message .get ("content" , [])
133
+
134
+ if isinstance (content , list ):
135
+ has_guarded_text = any (
136
+ isinstance (item , dict ) and item .get ("type" ) == "guarded_text"
137
+ for item in content
138
+ )
139
+ if has_guarded_text :
140
+ continue # Skip this message if it already has guarded_text
141
+
142
+ # Convert text elements to guarded_text
143
+ new_content = []
144
+ for item in content :
145
+ if isinstance (item , dict ) and item .get ("type" ) == "text" :
146
+ new_item = {"type" : "guarded_text" , "text" : item ["text" ]} # type: ignore
147
+ new_content .append (new_item )
148
+ else :
149
+ new_content .append (item )
150
+
151
+ messages_copy [user_message_index ]["content" ] = new_content # type: ignore
152
+ elif isinstance (content , str ):
153
+ # If content is a string, convert it to guarded_text
154
+ messages_copy [user_message_index ]["content" ] = [ # type: ignore
155
+ {"type" : "guarded_text" , "text" : content } # type: ignore
156
+ ]
157
+
158
+ return messages_copy
159
+
105
160
@classmethod
106
161
def get_config (cls ):
107
162
return {
@@ -769,6 +824,11 @@ async def _async_transform_request(
769
824
headers : Optional [dict ] = None ,
770
825
) -> RequestObject :
771
826
messages , system_content_blocks = self ._transform_system_message (messages )
827
+
828
+ # Convert last user message to guarded_text if guardrailConfig is present
829
+ messages = self ._convert_consecutive_user_messages_to_guarded_text (
830
+ messages , optional_params
831
+ )
772
832
## TRANSFORMATION ##
773
833
774
834
_data : CommonRequestObject = self ._transform_request_helper (
@@ -821,6 +881,11 @@ def _transform_request(
821
881
) -> RequestObject :
822
882
messages , system_content_blocks = self ._transform_system_message (messages )
823
883
884
+ # Convert last user message to guarded_text if guardrailConfig is present
885
+ messages = self ._convert_consecutive_user_messages_to_guarded_text (
886
+ messages , optional_params
887
+ )
888
+
824
889
_data : CommonRequestObject = self ._transform_request_helper (
825
890
model = model ,
826
891
system_content_blocks = system_content_blocks ,
0 commit comments