|
56 | 56 |
|
57 | 57 |
|
58 | 58 | class LangchainCallbackHandler(LangchainBaseCallbackHandler): |
59 | | - def __init__(self, *, public_key: Optional[str] = None) -> None: |
| 59 | + def __init__( |
| 60 | + self, *, public_key: Optional[str] = None, update_trace: bool = False |
| 61 | + ) -> None: |
| 62 | + """Initialize the LangchainCallbackHandler. |
| 63 | +
|
| 64 | + Args: |
| 65 | + public_key: Optional Langfuse public key. If not provided, will use the default client configuration. |
| 66 | + update_trace: Whether to update the Langfuse trace with the chains input / output / metadata / name. Defaults to False. |
| 67 | + """ |
60 | 68 | self.client = get_client(public_key=public_key) |
61 | 69 |
|
62 | 70 | self.runs: Dict[UUID, Union[LangfuseSpan, LangfuseGeneration]] = {} |
63 | 71 | self.prompt_to_parent_run_map: Dict[UUID, Any] = {} |
64 | 72 | self.updated_completion_start_time_memo: Set[UUID] = set() |
65 | 73 |
|
66 | 74 | self.last_trace_id: Optional[str] = None |
| 75 | + self.update_trace = update_trace |
67 | 76 |
|
68 | 77 | def on_llm_new_token( |
69 | 78 | self, |
@@ -207,7 +216,19 @@ def on_chain_start( |
207 | 216 | ), |
208 | 217 | ) |
209 | 218 | span.update_trace( |
210 | | - **self._parse_langfuse_trace_attributes_from_metadata(metadata) |
| 219 | + **( |
| 220 | + cast( |
| 221 | + Any, |
| 222 | + { |
| 223 | + "input": inputs, |
| 224 | + "name": span_name, |
| 225 | + "metadata": span_metadata, |
| 226 | + }, |
| 227 | + ) |
| 228 | + if self.update_trace |
| 229 | + else {} |
| 230 | + ), |
| 231 | + **self._parse_langfuse_trace_attributes_from_metadata(metadata), |
211 | 232 | ) |
212 | 233 | self.runs[run_id] = span |
213 | 234 | else: |
@@ -322,14 +343,21 @@ def on_chain_end( |
322 | 343 | if run_id not in self.runs: |
323 | 344 | raise Exception("run not found") |
324 | 345 |
|
325 | | - self.runs[run_id].update( |
| 346 | + span = self.runs[run_id] |
| 347 | + span.update( |
326 | 348 | output=outputs, |
327 | 349 | input=kwargs.get("inputs"), |
328 | | - ).end() |
| 350 | + ) |
| 351 | + |
| 352 | + if parent_run_id is None and self.update_trace: |
| 353 | + span.update_trace(output=outputs, input=kwargs.get("inputs")) |
| 354 | + |
| 355 | + span.end() |
329 | 356 |
|
330 | 357 | del self.runs[run_id] |
331 | 358 |
|
332 | 359 | self._deregister_langfuse_prompt(run_id) |
| 360 | + |
333 | 361 | except Exception as e: |
334 | 362 | langfuse_logger.exception(e) |
335 | 363 |
|
|
0 commit comments