Skip to content

Commit 9a43a61

Browse files
authored
feat: add thinking token support to lc (#2056)
* feat: add thinking token support to lc
* fix: ci
* fix: ci
1 parent 717c06b commit 9a43a61

File tree

1 file changed

+72
-19
lines changed

1 file changed

+72
-19
lines changed

backend/chainlit/langchain/callbacks.py

Lines changed: 72 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import json
21
import time
32
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
43
from uuid import UUID
@@ -170,7 +169,31 @@ def _convert_message_dict(
170169
if function_call:
171170
msg["function_call"] = function_call
172171
else:
173-
msg["content"] = kwargs.get("content", "")
172+
content = kwargs.get("content")
173+
if isinstance(content, list):
174+
tool_calls = []
175+
content_parts = []
176+
for item in content:
177+
if item.get("type") == "tool_use":
178+
tool_calls.append(
179+
{
180+
"id": item.get("id"),
181+
"type": "function",
182+
"function": {
183+
"name": item.get("name"),
184+
"arguments": item.get("input"),
185+
},
186+
}
187+
)
188+
elif item.get("type") == "text":
189+
content_parts.append({"type": "text", "text": item.get("text")})
190+
191+
if tool_calls:
192+
msg["tool_calls"] = tool_calls
193+
if content_parts:
194+
msg["content"] = content_parts # type: ignore
195+
else:
196+
msg["content"] = content # type: ignore
174197

175198
return msg
176199

@@ -182,6 +205,7 @@ def _convert_message(
182205
return self._convert_message_dict(
183206
message,
184207
)
208+
185209
function_call = message.additional_kwargs.get("function_call")
186210

187211
msg = GenerationMessage(
@@ -199,7 +223,32 @@ def _convert_message(
199223
if function_call:
200224
msg["function_call"] = function_call
201225
else:
202-
msg["content"] = message.content # type: ignore
226+
if isinstance(message.content, list):
227+
tool_calls = []
228+
content_parts = []
229+
for item in message.content:
230+
if isinstance(item, str):
231+
continue
232+
if item.get("type") == "tool_use":
233+
tool_calls.append(
234+
{
235+
"id": item.get("id"),
236+
"type": "function",
237+
"function": {
238+
"name": item.get("name"),
239+
"arguments": item.get("input"),
240+
},
241+
}
242+
)
243+
elif item.get("type") == "text":
244+
content_parts.append({"type": "text", "text": item.get("text")})
245+
246+
if tool_calls:
247+
msg["tool_calls"] = tool_calls
248+
if content_parts:
249+
msg["content"] = content_parts # type: ignore
250+
else:
251+
msg["content"] = message.content # type: ignore
203252

204253
return msg
205254

@@ -236,7 +285,12 @@ def _build_llm_settings(
236285
if "functions" in settings:
237286
tools = [{"type": "function", "function": f} for f in settings["functions"]]
238287
if "tools" in settings:
239-
tools = settings["tools"]
288+
tools = [
289+
{"type": "function", "function": t}
290+
if t.get("type") != "function"
291+
else t
292+
for t in settings["tools"]
293+
]
240294
return provider, model, tools, settings
241295

242296

@@ -492,11 +546,12 @@ async def _start_trace(self, run: Run) -> None:
492546
parent_id=parent_id,
493547
)
494548
step.start = utc_now()
495-
step.input, language = process_content(run.inputs)
496-
if language is not None:
497-
if step.metadata is None:
498-
step.metadata = {}
499-
step.metadata["language"] = language
549+
if step.metadata is None:
550+
step.metadata = {}
551+
if step_type != "llm":
552+
step.input, language = process_content(run.inputs)
553+
if language is not None:
554+
step.metadata["language"] = language
500555

501556
step.tags = run.tags
502557
self.steps[str(run.id)] = step
@@ -560,9 +615,6 @@ async def _on_run_update(self, run: Run) -> None:
560615
break
561616

562617
current_step.language = "json"
563-
current_step.output = json.dumps(
564-
message_completion, indent=4, ensure_ascii=False
565-
)
566618
else:
567619
completion_start = self.completion_generations[str(run.id)]
568620
completion = generation.get("text", "")
@@ -601,13 +653,14 @@ async def _on_run_update(self, run: Run) -> None:
601653
output = outputs.get(output_keys[0], outputs)
602654

603655
if current_step:
604-
current_step.output = (
605-
output[0]
606-
if isinstance(output, Sequence)
607-
and not isinstance(output, str)
608-
and len(output)
609-
else output
610-
)
656+
if current_step.type != "llm":
657+
current_step.output = (
658+
output[0]
659+
if isinstance(output, Sequence)
660+
and not isinstance(output, str)
661+
and len(output)
662+
else output
663+
)
611664
current_step.end = utc_now()
612665
await current_step.update()
613666

0 commit comments

Comments (0)