Skip to content

Commit c8d7730

Browse files
committed
Merge branch 'development' of github.com:Scale3-Labs/langtrace-python-sdk into add-weaviate-generate-support
2 parents 70fb1f0 + c5b8601 commit c8d7730

File tree

5 files changed: +135 −200819 lines

5 files changed: +135 −200819 lines (summary repeated by the page scrape)

src/langtrace_python_sdk/instrumentation/openai/patch.py

Lines changed: 72 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from langtrace_python_sdk.constants.instrumentation.openai import APIS
3131
from langtrace_python_sdk.utils.llm import calculate_prompt_tokens, estimate_tokens
32+
from openai._types import NOT_GIVEN
3233

3334

3435
def images_generate(original_method, version, tracer):
@@ -276,54 +277,83 @@ def traced_method(wrapped, instance, args, kwargs):
276277

277278

278279
class StreamWrapper:
279-
def __init__(
    self, stream, span, prompt_tokens, function_call=False, tool_calls=False
):
    """Wrap an OpenAI streaming response so the span's lifecycle follows the stream.

    Args:
        stream: the underlying (a)sync chunk iterator returned by the OpenAI client.
        span: the OpenTelemetry span recording this completion call.
        prompt_tokens: token count of the prompt, computed by the caller.
        function_call: True when the request used the legacy ``functions`` API.
        tool_calls: True when the request used ``tools``.
    """
    self.stream = stream
    self.span = span
    self.prompt_tokens = prompt_tokens
    self.function_call = function_call
    self.tool_calls = tool_calls
    self.result_content = []  # accumulated text content of every streamed chunk
    self.completion_tokens = 0  # running output-token estimate
    self._span_started = False  # guards against double STREAM_START / STREAM_END
    # Start eagerly: callers may iterate directly without entering the
    # context manager, and _start_span() is idempotent either way.
    self._start_span()
292+
293+
def _start_span(self):
    """Emit the STREAM_START event exactly once, however iteration begins."""
    if not self._span_started:
        self.span.add_event(Event.STREAM_START.value)
        self._span_started = True
297+
298+
def _end_span(self):
    """Record final token counts and the aggregated response, then end the span.

    Idempotent: acts only if the span was started and not yet ended, so it is
    safe to call from both ``__exit__`` and stream exhaustion without emitting
    STREAM_END twice or ending an already-ended span.
    """
    if self._span_started:
        self.span.add_event(Event.STREAM_END.value)
        self.span.set_attribute(
            "llm.token.counts",
            json.dumps(
                {
                    "input_tokens": self.prompt_tokens,
                    "output_tokens": self.completion_tokens,
                    "total_tokens": self.prompt_tokens + self.completion_tokens,
                }
            ),
        )
        self.span.set_attribute(
            "llm.responses",
            json.dumps(
                [
                    {
                        "role": "assistant",
                        "content": "".join(self.result_content),
                    }
                ]
            ),
        )
        self.span.set_status(StatusCode.OK)
        self.span.end()
        self._span_started = False
287325

288326
def __enter__(self):
    """Context-manager entry: ensure the span is started and return self."""
    self._start_span()
    return self
291329

292330
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context-manager exit: flush attributes and end the span (idempotent).

    Returns None (falsy), so any in-flight exception propagates to the caller.
    """
    # NOTE(review): when exc_type is set, _end_span still records StatusCode.OK —
    # consider recording the error status on the span instead; confirm intent.
    self._end_span()
317332

318333
def __iter__(self):
    """Sync-iterator protocol: start the span (idempotent) and return self."""
    self._start_span()
    return self
320336

337+
def __aiter__(self):
    """Async-iterator protocol: start the span (idempotent) and return self."""
    self._start_span()
    return self
340+
341+
async def __anext__(self):
    """Async iteration step: forward the next chunk, closing the span at end.

    Raises:
        StopAsyncIteration: when the wrapped async stream is exhausted.
    """
    try:
        chunk = await self.stream.__anext__()
        self.process_chunk(chunk)
        return chunk
    except StopAsyncIteration:
        # BUG FIX: async iterators signal exhaustion with StopAsyncIteration,
        # not StopIteration. Catching the wrong type meant _end_span() never
        # ran for async streams, leaving the span open (no STREAM_END, no
        # token counts, span never ended).
        self._end_span()
        raise
349+
321350
def __next__(self):
    """Sync iteration step: forward the next chunk, closing the span at end.

    Raises:
        StopIteration: when the wrapped stream is exhausted.
    """
    try:
        chunk = next(self.stream)
        self.process_chunk(chunk)
        return chunk
    except StopIteration:
        # End the span here too, so plain for-loops (no `with` block) still
        # flush token counts and close the span when the stream finishes.
        self._end_span()
        raise
328358

329359
def process_chunk(self, chunk):
@@ -441,16 +471,16 @@ def traced_method(wrapped, instance, args, kwargs):
441471
attributes = LLMSpanAttributes(**span_attributes)
442472

443473
tools = []
444-
if kwargs.get("temperature") is not None:
474+
if kwargs.get("temperature") is not None and kwargs.get("temperature") != NOT_GIVEN:
445475
attributes.llm_temperature = kwargs.get("temperature")
446-
if kwargs.get("top_p") is not None:
476+
if kwargs.get("top_p") is not None and kwargs.get("top_p") != NOT_GIVEN:
447477
attributes.llm_top_p = kwargs.get("top_p")
448-
if kwargs.get("user") is not None:
478+
if kwargs.get("user") is not None and kwargs.get("user") != NOT_GIVEN:
449479
attributes.llm_user = kwargs.get("user")
450-
if kwargs.get("functions") is not None:
480+
if kwargs.get("functions") is not None and kwargs.get("functions") != NOT_GIVEN:
451481
for function in kwargs.get("functions"):
452482
tools.append(json.dumps({"type": "function", "function": function}))
453-
if kwargs.get("tools") is not None:
483+
if kwargs.get("tools") is not None and kwargs.get("tools") != NOT_GIVEN:
454484
tools.append(json.dumps(kwargs.get("tools")))
455485
if len(tools) > 0:
456486
attributes.llm_tools = json.dumps(tools)
@@ -469,7 +499,7 @@ def traced_method(wrapped, instance, args, kwargs):
469499
try:
470500
# Attempt to call the original method
471501
result = wrapped(*args, **kwargs)
472-
if kwargs.get("stream") is False or kwargs.get("stream") is None:
502+
if kwargs.get("stream") is False or kwargs.get("stream") is None or kwargs.get("stream") == NOT_GIVEN:
473503
span.set_attribute("llm.model", result.model)
474504
if hasattr(result, "choices") and result.choices is not None:
475505
responses = [
@@ -498,7 +528,7 @@ def traced_method(wrapped, instance, args, kwargs):
498528
span.set_attribute("llm.responses", json.dumps(responses))
499529
if (
500530
hasattr(result, "system_fingerprint")
501-
and result.system_fingerprint is not None
531+
and result.system_fingerprint is not None and result.system_fingerprint != NOT_GIVEN
502532
):
503533
span.set_attribute(
504534
"llm.system.fingerprint", result.system_fingerprint
@@ -525,7 +555,7 @@ def traced_method(wrapped, instance, args, kwargs):
525555
)
526556

527557
# iterate over kwargs.get("functions") and calculate the prompt tokens
528-
if kwargs.get("functions") is not None:
558+
if kwargs.get("functions") is not None and kwargs.get("functions") != NOT_GIVEN:
529559
for function in kwargs.get("functions"):
530560
prompt_tokens += calculate_prompt_tokens(
531561
json.dumps(function), kwargs.get("model")
@@ -611,16 +641,16 @@ async def traced_method(wrapped, instance, args, kwargs):
611641
attributes = LLMSpanAttributes(**span_attributes)
612642

613643
tools = []
614-
if kwargs.get("temperature") is not None:
644+
if kwargs.get("temperature") is not None and kwargs.get("temperature") != NOT_GIVEN:
615645
attributes.llm_temperature = kwargs.get("temperature")
616-
if kwargs.get("top_p") is not None:
646+
if kwargs.get("top_p") is not None and kwargs.get("top_p") != NOT_GIVEN:
617647
attributes.llm_top_p = kwargs.get("top_p")
618-
if kwargs.get("user") is not None:
648+
if kwargs.get("user") is not None and kwargs.get("user") != NOT_GIVEN:
619649
attributes.llm_user = kwargs.get("user")
620-
if kwargs.get("functions") is not None:
650+
if kwargs.get("functions") is not None and kwargs.get("functions") != NOT_GIVEN:
621651
for function in kwargs.get("functions"):
622652
tools.append(json.dumps({"type": "function", "function": function}))
623-
if kwargs.get("tools") is not None:
653+
if kwargs.get("tools") is not None and kwargs.get("tools") != NOT_GIVEN:
624654
tools.append(json.dumps(kwargs.get("tools")))
625655
if len(tools) > 0:
626656
attributes.llm_tools = json.dumps(tools)
@@ -637,7 +667,7 @@ async def traced_method(wrapped, instance, args, kwargs):
637667
try:
638668
# Attempt to call the original method
639669
result = await wrapped(*args, **kwargs)
640-
if kwargs.get("stream") is False or kwargs.get("stream") is None:
670+
if kwargs.get("stream") is False or kwargs.get("stream") is None or kwargs.get("stream") == NOT_GIVEN:
641671
span.set_attribute("llm.model", result.model)
642672
if hasattr(result, "choices") and result.choices is not None:
643673
responses = [
@@ -666,7 +696,7 @@ async def traced_method(wrapped, instance, args, kwargs):
666696
span.set_attribute("llm.responses", json.dumps(responses))
667697
if (
668698
hasattr(result, "system_fingerprint")
669-
and result.system_fingerprint is not None
699+
and result.system_fingerprint is not None and result.system_fingerprint != NOT_GIVEN
670700
):
671701
span.set_attribute(
672702
"llm.system.fingerprint", result.system_fingerprint
@@ -693,7 +723,7 @@ async def traced_method(wrapped, instance, args, kwargs):
693723
)
694724

695725
# iterate over kwargs.get("functions") and calculate the prompt tokens
696-
if kwargs.get("functions") is not None:
726+
if kwargs.get("functions") is not None and kwargs.get("functions") != NOT_GIVEN:
697727
for function in kwargs.get("functions"):
698728
prompt_tokens += calculate_prompt_tokens(
699729
json.dumps(function), kwargs.get("model")
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.1.22"
1+
__version__ = "2.1.24"

0 commit comments

Comments
 (0)