diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 9e8278e28..eecdca646 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -28,6 +28,7 @@ LangfuseSpan, LangfuseTool, ) +from langfuse.types import TraceContext from langfuse._utils import _get_timestamp from langfuse.langchain.utils import _extract_model_name from langfuse.logger import langfuse_logger @@ -92,7 +93,11 @@ class LangchainCallbackHandler(LangchainBaseCallbackHandler): def __init__( - self, *, public_key: Optional[str] = None, update_trace: bool = False + self, + *, + public_key: Optional[str] = None, + update_trace: bool = False, + trace_context: Optional[TraceContext] = None, ) -> None: """Initialize the LangchainCallbackHandler. @@ -120,6 +125,7 @@ def __init__( self.last_trace_id: Optional[str] = None self.update_trace = update_trace + self.trace_context = trace_context def on_llm_new_token( self, @@ -299,16 +305,31 @@ def on_chain_start( serialized, "chain", **kwargs ) - span = self._get_parent_observation(parent_run_id).start_observation( - name=span_name, - as_type=observation_type, - metadata=span_metadata, - input=inputs, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) + obs = self._get_parent_observation(parent_run_id) + if isinstance(obs, Langfuse): + span = obs.start_observation( + trace_context=self.trace_context, + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=inputs, + level=cast( + Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"] | None, + span_level, + ), + ) + else: + span = obs.start_observation( + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=inputs, + level=cast( + Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"] | None, + span_level, + ), + ) + self._attach_observation(run_id, span) if parent_run_id is None: