diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 9e8278e28..429069ea5 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -299,16 +299,30 @@ def on_chain_start( serialized, "chain", **kwargs ) - span = self._get_parent_observation(parent_run_id).start_observation( - name=span_name, - as_type=observation_type, - metadata=span_metadata, - input=inputs, - level=cast( - Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], - span_level, - ), - ) + if metadata and (trace_id := metadata.get("langfuse_trace_id")) and parent_run_id is None: + span = self.client.start_observation( + trace_context={"trace_id": trace_id}, + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=inputs, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), + ) + + else: + span = self._get_parent_observation(parent_run_id).start_observation( + name=span_name, + as_type=observation_type, + metadata=span_metadata, + input=inputs, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), + ) self._attach_observation(run_id, span) if parent_run_id is None: