Skip to content

Commit da1eb86

Browse files
authored
feat(langchain): add update_trace argument (#1302)
1 parent 117f6ee commit da1eb86

File tree

1 file changed

+32
-4
lines changed

1 file changed

+32
-4
lines changed

langfuse/langchain/CallbackHandler.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,23 @@
5656

5757

5858
class LangchainCallbackHandler(LangchainBaseCallbackHandler):
59-
def __init__(self, *, public_key: Optional[str] = None) -> None:
59+
def __init__(
60+
self, *, public_key: Optional[str] = None, update_trace: bool = False
61+
) -> None:
62+
"""Initialize the LangchainCallbackHandler.
63+
64+
Args:
65+
public_key: Optional Langfuse public key. If not provided, will use the default client configuration.
66+
update_trace: Whether to update the Langfuse trace with the chains input / output / metadata / name. Defaults to False.
67+
"""
6068
self.client = get_client(public_key=public_key)
6169

6270
self.runs: Dict[UUID, Union[LangfuseSpan, LangfuseGeneration]] = {}
6371
self.prompt_to_parent_run_map: Dict[UUID, Any] = {}
6472
self.updated_completion_start_time_memo: Set[UUID] = set()
6573

6674
self.last_trace_id: Optional[str] = None
75+
self.update_trace = update_trace
6776

6877
def on_llm_new_token(
6978
self,
@@ -207,7 +216,19 @@ def on_chain_start(
207216
),
208217
)
209218
span.update_trace(
210-
**self._parse_langfuse_trace_attributes_from_metadata(metadata)
219+
**(
220+
cast(
221+
Any,
222+
{
223+
"input": inputs,
224+
"name": span_name,
225+
"metadata": span_metadata,
226+
},
227+
)
228+
if self.update_trace
229+
else {}
230+
),
231+
**self._parse_langfuse_trace_attributes_from_metadata(metadata),
211232
)
212233
self.runs[run_id] = span
213234
else:
@@ -322,14 +343,21 @@ def on_chain_end(
322343
if run_id not in self.runs:
323344
raise Exception("run not found")
324345

325-
self.runs[run_id].update(
346+
span = self.runs[run_id]
347+
span.update(
326348
output=outputs,
327349
input=kwargs.get("inputs"),
328-
).end()
350+
)
351+
352+
if parent_run_id is None and self.update_trace:
353+
span.update_trace(output=outputs, input=kwargs.get("inputs"))
354+
355+
span.end()
329356

330357
del self.runs[run_id]
331358

332359
self._deregister_langfuse_prompt(run_id)
360+
333361
except Exception as e:
334362
langfuse_logger.exception(e)
335363

0 commit comments

Comments
 (0)