Skip to content

Commit 11e7e3f

Browse files
committed
refactor groq
1 parent 40d0b93 commit 11e7e3f

File tree

1 file changed

+30
-108
lines changed
  • src/langtrace_python_sdk/instrumentation/groq/patch.py

1 file changed

+30
-108
lines changed

src/langtrace_python_sdk/instrumentation/groq/patch.py

Lines changed: 30 additions & 108 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,14 @@
2323
from opentelemetry.trace import SpanKind
2424
from opentelemetry.trace.status import Status, StatusCode
2525

26+
from langtrace_python_sdk.utils.llm import (
27+
get_base_url,
28+
get_extra_attributes,
29+
get_llm_request_attributes,
30+
get_llm_url,
31+
get_langtrace_attributes,
32+
set_usage_attributes,
33+
)
2634
from langtrace_python_sdk.constants.instrumentation.common import (
2735
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY,
2836
SERVICE_PROVIDERS,
@@ -39,20 +47,13 @@ def chat_completions_create(original_method, version, tracer):
3947
"""Wrap the `create` method of the `ChatCompletion` class to trace it."""
4048

4149
def traced_method(wrapped, instance, args, kwargs):
42-
base_url = (
43-
str(instance._client._base_url)
44-
if hasattr(instance, "_client") and hasattr(instance._client, "_base_url")
45-
else ""
46-
)
4750
service_provider = SERVICE_PROVIDERS["GROQ"]
4851
# If base url contains perplexity or azure, set the service provider accordingly
49-
if "perplexity" in base_url:
52+
if "perplexity" in get_base_url(instance):
5053
service_provider = SERVICE_PROVIDERS["PPLX"]
51-
elif "azure" in base_url:
54+
elif "azure" in get_base_url(instance):
5255
service_provider = SERVICE_PROVIDERS["AZURE"]
5356

54-
extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY)
55-
5657
# handle tool calls in the kwargs
5758
llm_prompts = []
5859
for item in kwargs.get("messages", []):
@@ -82,19 +83,11 @@ def traced_method(wrapped, instance, args, kwargs):
8283
llm_prompts.append(item)
8384

8485
span_attributes = {
85-
SpanAttributes.LANGTRACE_SDK_NAME.value: LANGTRACE_SDK_NAME,
86-
SpanAttributes.LANGTRACE_SERVICE_NAME.value: service_provider,
87-
SpanAttributes.LANGTRACE_SERVICE_TYPE.value: "llm",
88-
SpanAttributes.LANGTRACE_SERVICE_VERSION.value: version,
89-
SpanAttributes.LANGTRACE_VERSION.value: v(LANGTRACE_SDK_NAME),
90-
SpanAttributes.LLM_URL.value: base_url,
86+
**get_langtrace_attributes(version, service_provider),
87+
**get_llm_request_attributes(kwargs, prompts=llm_prompts),
88+
**get_llm_url(instance),
9189
SpanAttributes.LLM_PATH.value: APIS["CHAT_COMPLETION"]["ENDPOINT"],
92-
SpanAttributes.LLM_PROMPTS.value: json.dumps(llm_prompts),
93-
SpanAttributes.LLM_IS_STREAMING.value: kwargs.get("stream"),
94-
SpanAttributes.LLM_REQUEST_TEMPERATURE.value: kwargs.get("temperature"),
95-
SpanAttributes.LLM_REQUEST_TOP_P.value: kwargs.get("top_p"),
96-
SpanAttributes.LLM_USER.value: kwargs.get("user"),
97-
**(extra_attributes if extra_attributes is not None else {}),
90+
**get_extra_attributes(),
9891
}
9992

10093
attributes = LLMSpanAttributes(**span_attributes)
@@ -110,10 +103,10 @@ def traced_method(wrapped, instance, args, kwargs):
110103

111104
# TODO(Karthik): Gotta figure out how to handle streaming with context
112105
# with tracer.start_as_current_span(APIS["CHAT_COMPLETION"]["METHOD"],
113-
# kind=SpanKind.CLIENT) as span:
106+
# kind=SpanKind.CLIENT.value) as span:
114107
span = tracer.start_span(
115108
APIS["CHAT_COMPLETION"]["METHOD"],
116-
kind=SpanKind.CLIENT,
109+
kind=SpanKind.CLIENT.value,
117110
context=set_span_in_context(trace.get_current_span()),
118111
)
119112
for field, value in attributes.model_dump(by_alias=True).items():
@@ -171,23 +164,7 @@ def traced_method(wrapped, instance, args, kwargs):
171164
# Get the usage
172165
if hasattr(result, "usage") and result.usage is not None:
173166
usage = result.usage
174-
if usage is not None:
175-
set_span_attribute(
176-
span,
177-
SpanAttributes.LLM_USAGE_PROMPT_TOKENS.value,
178-
result.usage.prompt_tokens,
179-
)
180-
181-
set_span_attribute(
182-
span,
183-
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS.value,
184-
usage.completion_tokens,
185-
)
186-
set_span_attribute(
187-
span,
188-
SpanAttributes.LLM_USAGE_TOTAL_TOKENS.value,
189-
usage.total_tokens,
190-
)
167+
set_usage_attributes(span, dict(usage))
191168

192169
span.set_status(StatusCode.OK)
193170
span.end()
@@ -289,22 +266,9 @@ def handle_streaming_response(
289266
finally:
290267
# Finalize span after processing all chunks
291268
span.add_event(Event.STREAM_END.value)
292-
set_span_attribute(
293-
span,
294-
SpanAttributes.LLM_USAGE_PROMPT_TOKENS.value,
295-
prompt_tokens,
296-
)
297-
298-
set_span_attribute(
269+
set_usage_attributes(
299270
span,
300-
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS.value,
301-
completion_tokens,
302-
)
303-
304-
set_span_attribute(
305-
span,
306-
SpanAttributes.LLM_USAGE_TOTAL_TOKENS.value,
307-
prompt_tokens + completion_tokens,
271+
{"input_tokens": prompt_tokens, "output_tokens": completion_tokens},
308272
)
309273

310274
set_span_attribute(
@@ -324,20 +288,13 @@ def async_chat_completions_create(original_method, version, tracer):
324288
"""Wrap the `create` method of the `ChatCompletion` class to trace it."""
325289

326290
async def traced_method(wrapped, instance, args, kwargs):
327-
base_url = (
328-
str(instance._client._base_url)
329-
if hasattr(instance, "_client") and hasattr(instance._client, "_base_url")
330-
else ""
331-
)
332291
service_provider = SERVICE_PROVIDERS["GROQ"]
333292
# If base url contains perplexity or azure, set the service provider accordingly
334-
if "perplexity" in base_url:
293+
if "perplexity" in get_base_url(instance):
335294
service_provider = SERVICE_PROVIDERS["PPLX"]
336-
elif "azure" in base_url:
295+
elif "azure" in get_base_url(instance):
337296
service_provider = SERVICE_PROVIDERS["AZURE"]
338297

339-
extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY)
340-
341298
# handle tool calls in the kwargs
342299
llm_prompts = []
343300
for item in kwargs.get("messages", []):
@@ -367,19 +324,11 @@ async def traced_method(wrapped, instance, args, kwargs):
367324
llm_prompts.append(item)
368325

369326
span_attributes = {
370-
SpanAttributes.LANGTRACE_SDK_NAME.value: LANGTRACE_SDK_NAME,
371-
SpanAttributes.LANGTRACE_SERVICE_NAME.value: service_provider,
372-
SpanAttributes.LANGTRACE_SERVICE_TYPE.value: "llm",
373-
SpanAttributes.LANGTRACE_SERVICE_VERSION.value: version,
374-
SpanAttributes.LANGTRACE_VERSION.value: v(LANGTRACE_SDK_NAME),
375-
SpanAttributes.LLM_URL.value: base_url,
327+
**get_langtrace_attributes(version, service_provider),
328+
**get_llm_request_attributes(kwargs, prompts=llm_prompts),
329+
**get_llm_url(instance),
376330
SpanAttributes.LLM_PATH.value: APIS["CHAT_COMPLETION"]["ENDPOINT"],
377-
SpanAttributes.LLM_PROMPTS.value: json.dumps(llm_prompts),
378-
SpanAttributes.LLM_IS_STREAMING.value: kwargs.get("stream"),
379-
SpanAttributes.LLM_REQUEST_TEMPERATURE.value: kwargs.get("temperature"),
380-
SpanAttributes.LLM_REQUEST_TOP_P.value: kwargs.get("top_p"),
381-
SpanAttributes.LLM_USER.value: kwargs.get("user"),
382-
**(extra_attributes if extra_attributes is not None else {}),
331+
**get_extra_attributes(),
383332
}
384333

385334
attributes = LLMSpanAttributes(**span_attributes)
@@ -396,9 +345,9 @@ async def traced_method(wrapped, instance, args, kwargs):
396345

397346
# TODO(Karthik): Gotta figure out how to handle streaming with context
398347
# with tracer.start_as_current_span(APIS["CHAT_COMPLETION"]["METHOD"],
399-
# kind=SpanKind.CLIENT) as span:
348+
# kind=SpanKind.CLIENT.value) as span:
400349
span = tracer.start_span(
401-
APIS["CHAT_COMPLETION"]["METHOD"], kind=SpanKind.CLIENT
350+
APIS["CHAT_COMPLETION"]["METHOD"], kind=SpanKind.CLIENT.value
402351
)
403352
for field, value in attributes.model_dump(by_alias=True).items():
404353
set_span_attribute(span, field, value)
@@ -456,22 +405,8 @@ async def traced_method(wrapped, instance, args, kwargs):
456405
if hasattr(result, "usage") and result.usage is not None:
457406
usage = result.usage
458407
if usage is not None:
459-
set_span_attribute(
460-
span,
461-
SpanAttributes.LLM_USAGE_PROMPT_TOKENS.value,
462-
result.usage.prompt_tokens,
463-
)
408+
set_usage_attributes(span, dict(usage))
464409

465-
set_span_attribute(
466-
span,
467-
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS.value,
468-
usage.completion_tokens,
469-
)
470-
set_span_attribute(
471-
span,
472-
SpanAttributes.LLM_USAGE_TOTAL_TOKENS.value,
473-
usage.total_tokens,
474-
)
475410
span.set_status(StatusCode.OK)
476411
span.end()
477412
return result
@@ -576,23 +511,10 @@ async def ahandle_streaming_response(
576511
# Finalize span after processing all chunks
577512
span.add_event(Event.STREAM_END.value)
578513

579-
set_span_attribute(
514+
set_usage_attributes(
580515
span,
581-
SpanAttributes.LLM_USAGE_PROMPT_TOKENS.value,
582-
prompt_tokens,
516+
{"input_tokens": prompt_tokens, "output_tokens": completion_tokens},
583517
)
584-
set_span_attribute(
585-
span,
586-
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS.value,
587-
completion_tokens,
588-
)
589-
590-
set_span_attribute(
591-
span,
592-
SpanAttributes.LLM_USAGE_TOTAL_TOKENS.value,
593-
prompt_tokens + completion_tokens,
594-
)
595-
596518
set_span_attribute(
597519
span,
598520
SpanAttributes.LLM_COMPLETIONS.value,

0 commit comments

Comments (0)