Skip to content

Commit 2f48a10

Browse files
feat(openai): add support for tagging by API key [backport #5757 to 1.13] (#5769)
This PR backports #5757 to 1.13. Support tagging traces, logs and metrics with the API key. This enables OpenAI admins and users to see usage per-API key. The API key is obfuscated to include only the last 4 characters which is what OpenAI displays in their UI. Note: with this change, we no longer trace streamed responses and instead include that in the initial openai request/response span. This was done since having the stream captured in the request span itself is likely more useful to users than having a disjointed trace. For example: Previous behavior: <img width="1141" alt="Screenshot 2023-05-04 at 10 02 39 PM" src="https://user-images.githubusercontent.com/35776586/236363999-a61737e7-76a5-4a20-9e66-bea34ba9ccf1.png"> Trace behavior after this change: <img width="1144" alt="Screenshot 2023-05-04 at 10 02 24 PM" src="https://user-images.githubusercontent.com/35776586/236363976-aa34a907-8ad5-40d9-9533-e3706f88b9ed.png"> The one risk is that if a user never consumes the stream generator, then this span will never be finished. This is a low risk, but should be addressed as a future step. ## Checklist - [x] Change(s) are motivated and described in the PR description. - [x] Testing strategy is described if automated tests are not included in the PR. - [x] Risk is outlined (performance impact, potential for breakage, maintainability, etc). - [x] Change is maintainable (easy to change, telemetry, documentation). - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/contributing.html#Release-Note-Guidelines) are followed. - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)). - [x] PR description includes explicit acknowledgement/acceptance of the performance implications of this PR as reported in the benchmarks PR comment. ## Reviewer Checklist - [x] Title is accurate. - [x] No unnecessary changes are introduced. - [x] Description motivates each change. 
- [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes unless absolutely necessary. - [x] Testing strategy adequately addresses listed risk(s). - [x] Change is maintainable (easy to change, telemetry, documentation). - [x] Release note makes sense to a user of the library. - [x] Reviewer has explicitly acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment. Co-authored-by: Kyle Verhoog <[email protected]>
1 parent b2ba73f commit 2f48a10

14 files changed

+490
-189
lines changed

ddtrace/contrib/openai/patch.py

Lines changed: 88 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
from typing import AsyncGenerator
66
from typing import Generator
7+
from typing import TYPE_CHECKING
78

89
from ddtrace import config
910
from ddtrace.constants import SPAN_MEASURED_KEY
@@ -22,6 +23,10 @@
2223
from ._logging import V2LogWriter
2324

2425

26+
if TYPE_CHECKING:
27+
from ddtrace import Span
28+
29+
2530
log = get_logger(__name__)
2631

2732

@@ -69,6 +74,29 @@ def is_pc_sampled_log(self, span):
6974
def start_log_writer(self):
7075
self._log_writer.start()
7176

77+
@property
78+
def _user_api_key(self):
79+
# type: () -> str
80+
"""Get a representation of the user API key for tagging."""
81+
# Match the API key representation that OpenAI uses in their UI.
82+
return "sk-...%s" % self._openai.api_key[-4:]
83+
84+
def set_base_span_tags(self, span):
85+
# type: (Span) -> None
86+
span.set_tag_str(COMPONENT, self._config.integration_name)
87+
span.set_tag_str("openai.user.api_key", self._user_api_key)
88+
89+
# Do these dynamically as openai users can set these at any point
90+
# not necessarily before patch() time.
91+
# organization_id is only returned by a few endpoints, grab it when we can.
92+
for attr in ("api_base", "api_version", "organization_id"):
93+
v = getattr(self._openai, attr, None)
94+
if v is not None:
95+
if attr == "organization_id":
96+
span.set_tag_str("openai.organization.id", v or "")
97+
else:
98+
span.set_tag_str(attr, v)
99+
72100
def trace(self, pin, endpoint, model):
73101
"""Start an OpenAI span.
74102
@@ -82,17 +110,9 @@ def trace(self, pin, endpoint, model):
82110
span = pin.tracer.trace("openai.request", resource=resource, service=trace_utils.int_service(pin, self._config))
83111
# Enable trace metrics for these spans so users can see per-service openai usage in APM.
84112
span.set_tag(SPAN_MEASURED_KEY)
85-
span.set_tag_str(COMPONENT, self._config.integration_name)
86-
# Do these dynamically as openai users can set these at any point
87-
# not necessarily before patch() time.
88-
# organization_id is only returned by a few endpoints, grab it when we can.
89-
for attr in ("api_base", "api_version", "organization_id"):
90-
v = getattr(self._openai, attr, None)
91-
if v is not None:
92-
if attr == "organization_id":
93-
span.set_tag_str("openai.organization.id", v or "")
94-
else:
95-
span.set_tag_str(attr, v)
113+
114+
self.set_base_span_tags(span)
115+
96116
span.set_tag_str("openai.endpoint", endpoint)
97117
if model:
98118
span.set_tag_str("openai.model", model)
@@ -101,12 +121,16 @@ def trace(self, pin, endpoint, model):
101121
def log(self, span, level, msg, attrs):
102122
if not self._config.logs_enabled:
103123
return
104-
tags = "env:%s,version:%s,openai.endpoint:%s,openai.model:%s,openai.organization.name:%s" % (
105-
(config.env or ""),
106-
(config.version or ""),
107-
(span.get_tag("openai.endpoint") or ""),
108-
(span.get_tag("openai.model") or ""),
109-
(span.get_tag("openai.organization.name") or ""),
124+
tags = (
125+
"env:%s,version:%s,openai.endpoint:%s,openai.model:%s,openai.organization.name:%s,openai.user.api_key:%s"
126+
% (
127+
(config.env or ""),
128+
(config.version or ""),
129+
(span.get_tag("openai.endpoint") or ""),
130+
(span.get_tag("openai.model") or ""),
131+
(span.get_tag("openai.organization.name") or ""),
132+
(span.get_tag("openai.user.api_key") or ""),
133+
)
110134
)
111135

112136
log = {
@@ -133,6 +157,7 @@ def _metrics_tags(self, span):
133157
"openai.endpoint:%s" % (span.get_tag("openai.endpoint") or ""),
134158
"openai.organization.id:%s" % (span.get_tag("openai.organization.id") or ""),
135159
"openai.organization.name:%s" % (span.get_tag("openai.organization.name") or ""),
160+
"openai.user.api_key:%s" % (span.get_tag("openai.user.api_key") or ""),
136161
"error:%d" % span.error,
137162
]
138163
err_type = span.get_tag("error.type")
@@ -286,8 +311,10 @@ def _traced_endpoint(endpoint_hook, integration, pin, args, kwargs):
286311
if error is None:
287312
return e.value
288313
finally:
289-
span.finish()
290-
integration.metric(span, "dist", "request.duration", span.duration_ns)
314+
# Streamed responses will be finished when the generator exits.
315+
if not kwargs.get("stream"):
316+
span.finish()
317+
integration.metric(span, "dist", "request.duration", span.duration_ns)
291318

292319

293320
def _patched_endpoint(openai, integration, patch_hook):
@@ -392,23 +419,21 @@ def _handle_response(self, pin, span, integration, resp):
392419
"""
393420

394421
def shared_gen():
395-
stream_span = pin.tracer.start_span("openai.stream", child_of=span, activate=True)
396422
try:
397423
num_prompt_tokens = span.get_metric("openai.response.usage.prompt_tokens") or 0
398424
num_completion_tokens = yield
399425

400-
stream_span.set_metric("openai.response.usage.completion_tokens", num_completion_tokens)
426+
span.set_metric("openai.response.usage.completion_tokens", num_completion_tokens)
401427
total_tokens = num_prompt_tokens + num_completion_tokens
402-
stream_span.set_metric("openai.response.usage.total_tokens", total_tokens)
428+
span.set_metric("openai.response.usage.total_tokens", total_tokens)
429+
integration.metric(span, "dist", "tokens.prompt", num_prompt_tokens, tags=["openai.estimated:true"])
403430
integration.metric(
404431
span, "dist", "tokens.completion", num_completion_tokens, tags=["openai.estimated:true"]
405432
)
406433
integration.metric(span, "dist", "tokens.total", total_tokens, tags=["openai.estimated:true"])
407434
finally:
408-
stream_span.finish()
409-
# ``span`` could be flushed by this point. This is a best effort to attach the metric
410-
span.set_metric("openai.response.usage.completion_tokens", num_completion_tokens)
411-
span.set_metric("openai.response.usage.total_tokens", total_tokens)
435+
span.finish()
436+
integration.metric(span, "dist", "request.duration", span.duration_ns)
412437

413438
# A chunk corresponds to a token:
414439
# https://community.openai.com/t/how-to-get-total-tokens-from-a-stream-of-completioncreaterequests/110700
@@ -490,7 +515,6 @@ def handle_request(self, pin, integration, span, args, kwargs):
490515
for p in prompt:
491516
num_prompt_tokens += _est_tokens(p)
492517
span.set_metric("openai.response.usage.prompt_tokens", num_prompt_tokens)
493-
integration.metric(span, "dist", "tokens.prompt", num_prompt_tokens, tags=["openai.estimated:true"])
494518

495519
self._record_request(span, kwargs)
496520

@@ -551,17 +575,15 @@ def handle_request(self, pin, integration, span, args, kwargs):
551575
if "stream" in kwargs and kwargs["stream"]:
552576
# streamed responses do not have a usage field, so we have to
553577
# estimate the number of tokens returned.
554-
num_message_tokens = 0
578+
est_num_message_tokens = 0
555579
for m in messages:
556-
num_message_tokens += _est_tokens(m.get("content", ""))
557-
span.set_metric("openai.response.usage.prompt_tokens", num_message_tokens)
558-
integration.metric(span, "dist", "tokens.prompt", num_message_tokens, tags=["openai.estimated:true"])
580+
est_num_message_tokens += _est_tokens(m.get("content", ""))
581+
span.set_metric("openai.response.usage.prompt_tokens", est_num_message_tokens)
559582

560583
self._record_request(span, kwargs)
561584

562585
resp, error = yield
563586

564-
choices = []
565587
if resp and not kwargs.get("stream"):
566588
choices = resp.get("choices", [])
567589
for choice in choices:
@@ -621,49 +643,49 @@ def handle_request(self, pin, integration, span, args, kwargs):
621643
def _patched_convert(openai, integration):
622644
def patched_convert(func, args, kwargs):
623645
"""Patch convert captures header information in the openai response"""
624-
pin = Pin._find(openai, args[0])
646+
pin = Pin.get_from(openai)
625647
if not pin or not pin.enabled():
626648
return func(*args, **kwargs)
627649

628650
span = pin.tracer.current_span()
629651
if not span:
630652
return func(*args, **kwargs)
631653

632-
for val in args:
633-
if not isinstance(val, openai.openai_response.OpenAIResponse):
634-
continue
654+
val = args[0]
655+
if not isinstance(val, openai.openai_response.OpenAIResponse):
656+
return func(*args, **kwargs)
635657

636-
# This function is called for each chunk in the stream.
637-
# To prevent needlessly setting the same tags for each chunk, short-circuit here.
638-
if span.get_tag("openai.organization.name") is not None:
639-
continue
658+
# This function is called for each chunk in the stream.
659+
# To prevent needlessly setting the same tags for each chunk, short-circuit here.
660+
if span.get_tag("openai.organization.name") is not None:
661+
return func(*args, **kwargs)
662+
663+
val = val._headers
664+
if val.get("openai-organization"):
665+
org_name = val.get("openai-organization")
666+
span.set_tag("openai.organization.name", org_name)
640667

641-
val = val._headers
642-
if val.get("openai-organization"):
643-
org_name = val.get("openai-organization")
644-
span.set_tag("openai.organization.name", org_name)
645-
646-
# Gauge total rate limit
647-
if val.get("x-ratelimit-limit-requests"):
648-
v = val.get("x-ratelimit-limit-requests")
649-
if v is not None:
650-
integration.metric(span, "gauge", "ratelimit.requests", v)
651-
if val.get("x-ratelimit-limit-tokens"):
652-
v = val.get("x-ratelimit-limit-tokens")
653-
if v is not None:
654-
integration.metric(span, "gauge", "ratelimit.tokens", v)
655-
656-
# Gauge and set span info for remaining requests and tokens
657-
if val.get("x-ratelimit-remaining-requests"):
658-
v = val.get("x-ratelimit-remaining-requests")
659-
if v is not None:
660-
integration.metric(span, "gauge", "ratelimit.remaining.requests", v)
661-
span.set_tag("openai.organization.ratelimit.requests.remaining", v)
662-
if val.get("x-ratelimit-remaining-tokens"):
663-
v = val.get("x-ratelimit-remaining-tokens")
664-
if v is not None:
665-
integration.metric(span, "gauge", "ratelimit.remaining.tokens", v)
666-
span.set_tag("openai.organization.ratelimit.tokens.remaining", v)
668+
# Gauge total rate limit
669+
if val.get("x-ratelimit-limit-requests"):
670+
v = val.get("x-ratelimit-limit-requests")
671+
if v is not None:
672+
integration.metric(span, "gauge", "ratelimit.requests", v)
673+
if val.get("x-ratelimit-limit-tokens"):
674+
v = val.get("x-ratelimit-limit-tokens")
675+
if v is not None:
676+
integration.metric(span, "gauge", "ratelimit.tokens", v)
677+
678+
# Gauge and set span info for remaining requests and tokens
679+
if val.get("x-ratelimit-remaining-requests"):
680+
v = val.get("x-ratelimit-remaining-requests")
681+
if v is not None:
682+
integration.metric(span, "gauge", "ratelimit.remaining.requests", v)
683+
span.set_tag("openai.organization.ratelimit.requests.remaining", v)
684+
if val.get("x-ratelimit-remaining-tokens"):
685+
v = val.get("x-ratelimit-remaining-tokens")
686+
if v is not None:
687+
integration.metric(span, "gauge", "ratelimit.remaining.tokens", v)
688+
span.set_tag("openai.organization.ratelimit.tokens.remaining", v)
667689
return func(*args, **kwargs)
668690

669691
return patched_convert

0 commit comments

Comments
 (0)