Commit e425627

add support for ollama
1 parent 3f568de commit e425627
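
For orientation, a minimal usage sketch of what this commit enables. It is not part of the commit; it assumes the ollama Python client and the SDK's langtrace.init entry point (the init flag shown is an assumption for local testing):

# Hypothetical end-to-end sketch: with the Ollama instrumentation active,
# each call below should produce an "ollama.chat" client span.
from langtrace_python_sdk import langtrace
import ollama

langtrace.init(write_spans_to_console=True)  # assumed flag: print spans locally

# Non-streaming: the span is ended as soon as the call returns.
ollama.chat(
    model="llama3",
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
)

# Streaming: the span must outlive this call and stay open until the
# generator is exhausted -- the motivation for the StreamWrapper change below.
for chunk in ollama.chat(
    model="llama3",
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    stream=True,
):
    print(chunk["message"]["content"], end="", flush=True)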

2 files changed: +54 -112 lines changed
Lines changed: 43 additions & 110 deletions
@@ -1,6 +1,7 @@
 from langtrace_python_sdk.constants.instrumentation.ollama import APIS
 from langtrace_python_sdk.utils import set_span_attribute
 from langtrace_python_sdk.utils.llm import (
+    StreamWrapper,
     get_extra_attributes,
     get_langtrace_attributes,
     get_llm_request_attributes,
@@ -16,9 +17,10 @@
 import json
 from opentelemetry.trace.status import Status, StatusCode
 from langtrace.trace_attributes import SpanAttributes
+from opentelemetry.trace import Tracer
 
 
-def generic_patch(operation_name, version, tracer):
+def generic_patch(operation_name, version, tracer: Tracer):
     def traced_method(wrapped, instance, args, kwargs):
         api = APIS[operation_name]
         service_provider = SERVICE_PROVIDERS["OLLAMA"]
@@ -35,36 +37,29 @@ def traced_method(wrapped, instance, args, kwargs):
         }
 
         attributes = LLMSpanAttributes(**span_attributes)
-        with tracer.start_as_current_span(
-            name=get_span_name(f'ollama.{api["METHOD"]}'), kind=SpanKind.CLIENT
-        ) as span:
-            _set_input_attributes(span, kwargs, attributes)
-
-            try:
-                result = wrapped(*args, **kwargs)
-                if result:
-                    if span.is_recording():
-
-                        if kwargs.get("stream"):
-                            return _handle_streaming_response(
-                                span, result, api["METHOD"]
-                            )
 
-                        _set_response_attributes(span, result)
-                        span.set_status(Status(StatusCode.OK))
+        span = tracer.start_span(
+            name=get_span_name(f'ollama.{api["METHOD"]}'), kind=SpanKind.CLIENT
+        )
+        _set_input_attributes(span, kwargs, attributes)
 
-                span.end()
-                return result
+        try:
+            result = wrapped(*args, **kwargs)
+            if kwargs.get("stream"):
+                return StreamWrapper(result, span)
+            else:
+                _set_response_attributes(span, result)
+                return result
 
-            except Exception as err:
-                # Record the exception in the span
-                span.record_exception(err)
+        except Exception as err:
+            # Record the exception in the span
+            span.record_exception(err)
 
-                # Set the span status to indicate an error
-                span.set_status(Status(StatusCode.ERROR, str(err)))
+            # Set the span status to indicate an error
+            span.set_status(Status(StatusCode.ERROR, str(err)))
 
-                # Reraise the exception to ensure it's not swallowed
-                raise
+            # Reraise the exception to ensure it's not swallowed
+            raise
 
     return traced_method
 
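Note on the pattern above: start_as_current_span is a context manager that ends the span when the function returns, which is too early when stream=True, because the caller receives a generator that is consumed later. Starting the span manually and handing it to StreamWrapper lets the wrapper end the span only once the stream is exhausted. A stripped-down sketch of that hand-off (a simplified stand-in, not the SDK's actual StreamWrapper):

# Simplified illustration of the span hand-off; "span" is an
# opentelemetry.trace.Span obtained from tracer.start_span().
class MiniStreamWrapper:
    def __init__(self, stream, span):
        self.stream = stream
        self.span = span

    def __iter__(self):
        try:
            for chunk in self.stream:
                yield chunk  # the caller consumes chunks at its own pace
        finally:
            # End the span only after the caller finishes (or abandons) the
            # stream; a context manager would have ended it on return.
            self.span.end()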
@@ -82,30 +77,28 @@ async def traced_method(wrapped, instance, args, kwargs):
             **get_extra_attributes(),
         }
         attributes = LLMSpanAttributes(**span_attributes)
-        with tracer.start_as_current_span(api["METHOD"], kind=SpanKind.CLIENT) as span:
-            _set_input_attributes(span, kwargs, attributes)
-            try:
-                result = await wrapped(*args, **kwargs)
-                if result:
-                    if span.is_recording():
-                        if kwargs.get("stream"):
-                            return _ahandle_streaming_response(
-                                span, result, api["METHOD"]
-                            )
-
-                _set_response_attributes(span, result)
-                span.set_status(Status(StatusCode.OK))
-                span.end()
-                return result
-            except Exception as err:
-                # Record the exception in the span
-                span.record_exception(err)
-
-                # Set the span status to indicate an error
-                span.set_status(Status(StatusCode.ERROR, str(err)))
-
-                # Reraise the exception to ensure it's not swallowed
-                raise
+        span = tracer.start_span(
+            name=get_span_name(f'ollama.{api["METHOD"]}'), kind=SpanKind.CLIENT
+        )
+
+        _set_input_attributes(span, kwargs, attributes)
+        try:
+            result = await wrapped(*args, **kwargs)
+            if kwargs.get("stream"):
+                return StreamWrapper(result, span)
+            else:
+                _set_response_attributes(span, result)
+                span.end()
+                return result
+        except Exception as err:
+            # Record the exception in the span
+            span.record_exception(err)
+
+            # Set the span status to indicate an error
+            span.set_status(Status(StatusCode.ERROR, str(err)))
+
+            # Reraise the exception to ensure it's not swallowed
+            raise
 
     return traced_method
 
@@ -162,63 +155,3 @@ def _set_input_attributes(span, kwargs, attributes):
             SpanAttributes.LLM_PRESENCE_PENALTY,
             options.get("presence_penalty"),
         )
-
-
-def _handle_streaming_response(span, response, api):
-    accumulated_tokens = None
-    if api == "chat":
-        accumulated_tokens = {"message": {"content": "", "role": ""}}
-    if api == "completion" or api == "generate":
-        accumulated_tokens = {"response": ""}
-    span.add_event(Event.STREAM_START.value)
-    try:
-        for chunk in response:
-            content = None
-            if api == "chat":
-                content = chunk["message"]["content"]
-                accumulated_tokens["message"]["content"] += chunk["message"]["content"]
-                accumulated_tokens["message"]["role"] = chunk["message"]["role"]
-            if api == "generate":
-                content = chunk["response"]
-                accumulated_tokens["response"] += chunk["response"]
-
-            set_event_completion_chunk(span, content)
-
-            _set_response_attributes(span, chunk | accumulated_tokens)
-    finally:
-        # Finalize span after processing all chunks
-        span.add_event(Event.STREAM_END.value)
-        span.set_status(StatusCode.OK)
-        span.end()
-
-    return response
-
-
-async def _ahandle_streaming_response(span, response, api):
-    accumulated_tokens = None
-    if api == "chat":
-        accumulated_tokens = {"message": {"content": "", "role": ""}}
-    if api == "completion" or api == "generate":
-        accumulated_tokens = {"response": ""}
-
-    span.add_event(Event.STREAM_START.value)
-    try:
-        async for chunk in response:
-            content = None
-            if api == "chat":
-                content = chunk["message"]["content"]
-                accumulated_tokens["message"]["content"] += chunk["message"]["content"]
-                accumulated_tokens["message"]["role"] = chunk["message"]["role"]
-            if api == "generate":
-                content = chunk["response"]
-                accumulated_tokens["response"] += chunk["response"]
-
-            set_event_completion_chunk(span, content)
-            _set_response_attributes(span, chunk | accumulated_tokens)
-    finally:
-        # Finalize span after processing all chunks
-        span.add_event(Event.STREAM_END.value)
-        span.set_status(StatusCode.OK)
-        span.end()
-
-    return response
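
With this deletion, the Ollama-specific streaming handlers are gone entirely: chunk accumulation, stream events, and span finalization now go through the shared StreamWrapper in src/langtrace_python_sdk/utils/llm.py, which the next file's changes extend to understand Ollama's dict-shaped chunks.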

src/langtrace_python_sdk/utils/llm.py

Lines changed: 11 additions & 2 deletions
@@ -237,7 +237,6 @@ class StreamWrapper:
     def __init__(
         self, stream, span, prompt_tokens=0, function_call=False, tool_calls=False
     ):
-
         self.stream = stream
         self.span = span
         self.prompt_tokens = prompt_tokens
@@ -284,7 +283,6 @@ def cleanup(self):
                     }
                 ],
             )
-
             self.span.set_status(StatusCode.OK)
             self.span.end()
             self._span_started = False
@@ -377,6 +375,10 @@ def build_streaming_response(self, chunk):
         if hasattr(chunk, "delta") and chunk.delta is not None:
             content = [chunk.delta.text] if hasattr(chunk.delta, "text") else []
 
+        if isinstance(chunk, dict):
+            if "message" in chunk:
+                if "content" in chunk["message"]:
+                    content = [chunk["message"]["content"]]
         if content:
             self.result_content.append(content[0])
 
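The new dict branch is needed because Ollama's streaming chunks are plain dicts rather than typed objects with a .delta attribute. For reference, a sketch of the chunk shape as assumed from the Ollama API, not taken from this commit:

# Approximate shape of an Ollama streaming chat chunk (illustrative assumption):
chunk = {
    "model": "llama3",
    "message": {"role": "assistant", "content": "Because "},
    "done": False,
}

# The final chunk (done=True) additionally carries the token counts that the
# next hunk reads:
final_chunk = {
    "model": "llama3",
    "message": {"role": "assistant", "content": ""},
    "done": True,
    "prompt_eval_count": 26,  # prompt tokens
    "eval_count": 298,        # completion tokens
}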
@@ -401,6 +403,13 @@ def set_usage_attributes(self, chunk):
             self.completion_tokens = chunk.usage_metadata.candidates_token_count
             self.prompt_tokens = chunk.usage_metadata.prompt_token_count
 
+        # Ollama
+        if isinstance(chunk, dict):
+            if "prompt_eval_count" in chunk:
+                self.prompt_tokens = chunk["prompt_eval_count"]
+            if "eval_count" in chunk:
+                self.completion_tokens = chunk["eval_count"]
+
     def process_chunk(self, chunk):
         self.set_response_model(chunk=chunk)
         self.build_streaming_response(chunk=chunk)
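A quick way to exercise the new token accounting in isolation (a hypothetical snippet, not a test from this commit; it relies on the constructor signature (stream, span, ...) shown above):

# Hypothetical sanity check of the Ollama branch in set_usage_attributes.
from unittest.mock import MagicMock

from langtrace_python_sdk.utils.llm import StreamWrapper

wrapper = StreamWrapper(stream=iter([]), span=MagicMock())
wrapper.set_usage_attributes({"prompt_eval_count": 26, "eval_count": 298})
assert wrapper.prompt_tokens == 26
assert wrapper.completion_tokens == 298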