Skip to content

Commit 8177876

Browse files
authored
Support <think> tags (#117)
1 parent 66cb51b commit 8177876

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

src/api/models/bedrock.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
268268
response = await self._invoke_bedrock(chat_request, stream=True)
269269
message_id = self.generate_message_id()
270270
stream = response.get("stream")
271+
self.think_emitted = False
271272
async for chunk in self._async_iterate(stream):
272273
args = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk}
273274
stream_response = self._create_response_stream(**args)
@@ -288,6 +289,7 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
288289

289290
# return an [DONE] message at the end.
290291
yield self.stream_response_to_bytes()
292+
self.think_emitted = False # Cleanup
291293
except Exception as e:
292294
logger.error("Stream error for model %s: %s", chat_request.model, str(e))
293295
error_event = Error(error=ErrorMessage(message=str(e)))
@@ -634,6 +636,9 @@ def _create_response(
634636
logger.warning(
635637
"Unknown tag in message content " + ",".join(c.keys())
636638
)
639+
if message.reasoning_content:
640+
message.content = f"<think>{message.reasoning_content}</think>{message.content}"
641+
message.reasoning_content = None
637642

638643
response = ChatResponse(
639644
id=message_id,
@@ -702,11 +707,19 @@ def _create_response_stream(
702707
content=delta["text"],
703708
)
704709
elif "reasoningContent" in delta:
705-
# ignore "signature" in the delta.
706710
if "text" in delta["reasoningContent"]:
707-
message = ChatResponseMessage(
708-
reasoning_content=delta["reasoningContent"]["text"],
709-
)
711+
content = delta["reasoningContent"]["text"]
712+
if not self.think_emitted:
713+
# Port of "content_block_start" with "thinking"
714+
content = "<think>" + content
715+
self.think_emitted = True
716+
message = ChatResponseMessage(content=content)
717+
elif "signature" in delta["reasoningContent"]:
718+
# Port of "signature_delta"
719+
if self.think_emitted:
720+
message = ChatResponseMessage(content="\n </think> \n\n")
721+
else:
722+
return None # Ignore signature if no <think> started
710723
else:
711724
# tool use
712725
index = chunk["contentBlockDelta"]["contentBlockIndex"] - 1

0 commit comments

Comments
 (0)