# fix(openai): ensure streamed spans with error are manually finished [backport 1.18] (#6910)
Backport f3beaf4 from #6891 to 1.18. Resolves #6769.

This PR fixes an unhandled case in our OpenAI integration affecting streamed `Chat/Completion` requests that error out and return an empty response. In this case, streamed spans were never finished: we defer finishing streamed response spans until the underlying generator is exhausted, and the empty-response path left no generator to exhaust. The fix adds manual span finishing for streamed response spans with an error, as we already do for non-streamed spans. The only risk is a traced request that returns a non-empty response despite an error; this is highly unlikely, since a faulty request should produce no response at all.

Note: this PR also moves the tagging of prompt token usage information from the `EndpointHook._record_request()` handler to `EndpointHook._record_response()`, since keeping it in the request handler recorded prompt token information even when the erroneous request triggered no actual prompt/completion operation in OpenAI.

## Checklist

- [x] Change(s) are motivated and described in the PR description.
- [x] Testing strategy is described if automated tests are not included in the PR.
- [x] Risk is outlined (performance impact, potential for breakage, maintainability, etc.).
- [x] Change is maintainable (easy to change, telemetry, documentation).
- [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed. If no release note is required, add label `changelog/no-changelog`.
- [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)).
- [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)).

## Reviewer Checklist

- [x] Title is accurate.
- [x] No unnecessary changes are introduced.
- [x] Description motivates each change.
- [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes unless absolutely necessary.
- [x] Testing strategy adequately addresses listed risk(s).
- [x] Change is maintainable (easy to change, telemetry, documentation).
- [x] Release note makes sense to a user of the library.
- [x] Reviewer has explicitly acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment.
- [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting).
- [x] If this PR touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`.
- [x] This PR doesn't touch any of that.

Co-authored-by: Yun Kim <[email protected]>
Commit 5caecf1 (1 parent: 54480da)
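To make the failure mode concrete, here is a minimal, self-contained sketch (illustrative stand-ins only: `Span`, `traced_endpoint`, and `failing_request` are not ddtrace's API) of why deferring `span.finish()` to a response generator leaks the span when a streamed request errors, and how the added `or err is not None` check closes it:

```python
class Span:
    """Stand-in for a tracer span (illustrative only)."""

    def __init__(self, name):
        self.name = name
        self.finished = False

    def finish(self):
        self.finished = True


def traced_endpoint(span, stream, request):
    """Mimics the shape of the patched _traced_endpoint error handling."""
    err = None
    try:
        try:
            request()  # the real integration yields here and receives (resp, err)
        except Exception as e:
            err = e
            raise
    finally:
        # Before the fix this read `if not stream:`, so a streamed request
        # that raised produced no generator to exhaust and the span leaked.
        if not stream or err is not None:
            span.finish()


def failing_request():
    raise RuntimeError("Incorrect API key provided")


span = Span("openai.request")
try:
    traced_endpoint(span, stream=True, request=failing_request)
except RuntimeError:
    pass
print(span.finished)  # True with the fix; False with the old condition
```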

6 files changed (+139, -19 lines)

### ddtrace/contrib/openai/_endpoint_hooks.py (14 additions, 13 deletions)

```diff
@@ -91,7 +91,6 @@ def shared_gen():
         try:
             num_prompt_tokens = span.get_metric("openai.response.usage.prompt_tokens") or 0
             num_completion_tokens = yield
-
             span.set_metric("openai.response.usage.completion_tokens", num_completion_tokens)
             total_tokens = num_prompt_tokens + num_completion_tokens
             span.set_metric("openai.response.usage.total_tokens", total_tokens)
@@ -180,8 +179,15 @@ def _record_request(self, pin, integration, span, args, kwargs):
         elif prompt:
             for idx, p in enumerate(prompt):
                 span.set_tag_str("openai.request.prompt.%d" % idx, integration.trunc(str(p)))
+        return
+
+    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
+        if not resp:
+            return self._handle_response(pin, span, integration, resp)
+        prompt = kwargs.get("prompt", "")
         if kwargs.get("stream"):
             num_prompt_tokens = 0
+            estimated = False
             if isinstance(prompt, str) or isinstance(prompt, list) and isinstance(prompt[0], int):
                 estimated, prompt_tokens = _compute_prompt_token_count(prompt, kwargs.get("model"))
                 num_prompt_tokens += prompt_tokens
@@ -191,10 +197,6 @@ def _record_request(self, pin, integration, span, args, kwargs):
                     num_prompt_tokens += prompt_tokens
             span.set_metric("openai.request.prompt_tokens_estimated", int(estimated))
             span.set_metric("openai.response.usage.prompt_tokens", num_prompt_tokens)
-        return
-
-    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
-        if not resp or kwargs.get("stream"):
             return self._handle_response(pin, span, integration, resp)
         if "choices" in resp:
             choices = resp["choices"]
@@ -212,7 +214,6 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
             span.set_tag_str("openai.response.choices.%d.text" % idx, integration.trunc(choice.get("text")))
         integration.record_usage(span, resp.get("usage"))
         if integration.is_pc_sampled_log(span):
-            prompt = kwargs.get("prompt", "")
             integration.log(
                 span,
                 "info" if error is None else "error",
@@ -254,19 +255,20 @@ def _record_request(self, pin, integration, span, args, kwargs):
             span.set_tag_str("openai.request.messages.%d.content" % idx, content)
             span.set_tag_str("openai.request.messages.%d.role" % idx, role)
             span.set_tag_str("openai.request.messages.%d.name" % idx, name)
+        return
+
+    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
+        if not resp:
+            return self._handle_response(pin, span, integration, resp)
+        messages = kwargs.get("messages")
         if kwargs.get("stream"):
-            # streamed responses do not have a usage field, so we have to
-            # estimate the number of tokens returned.
             est_num_message_tokens = 0
+            estimated = False
             for m in messages:
                 estimated, prompt_tokens = _compute_prompt_token_count(m.get("content", ""), kwargs.get("model"))
                 est_num_message_tokens += prompt_tokens
             span.set_metric("openai.request.prompt_tokens_estimated", int(estimated))
             span.set_metric("openai.response.usage.prompt_tokens", est_num_message_tokens)
-        return
-
-    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
-        if not resp or kwargs.get("stream"):
             return self._handle_response(pin, span, integration, resp)
         choices = resp.get("choices", [])
         span.set_metric("openai.response.choices_count", len(choices))
@@ -291,7 +293,6 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
             )
         integration.record_usage(span, resp.get("usage"))
         if integration.is_pc_sampled_log(span):
-            messages = kwargs.get("messages")
             integration.log(
                 span,
                 "info" if error is None else "error",
```

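Since streamed responses carry no `usage` field, the hooks above estimate prompt tokens up front via `_compute_prompt_token_count` and tag whether the count is exact or estimated. The helper below is a hypothetical sketch of that `(estimated, num_tokens)` shape, not ddtrace's implementation; the tiktoken lookup and the four-characters-per-token fallback are assumptions:

```python
from typing import Tuple


def compute_prompt_token_count(prompt: str, model: str) -> Tuple[bool, int]:
    """Hypothetical counterpart to _compute_prompt_token_count: returns
    (estimated, num_tokens) for a text prompt."""
    try:
        import tiktoken  # exact tokenizer, when available

        encoding = tiktoken.encoding_for_model(model)
        return False, len(encoding.encode(prompt))
    except (ImportError, KeyError):
        # Fallback heuristic (assumption): roughly 4 characters per token.
        return True, max(1, len(prompt) // 4)


estimated, num_tokens = compute_prompt_token_count("how does openai tokenize prompts?", "text-curie-001")
```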
### ddtrace/contrib/openai/patch.py (8 additions, 6 deletions)

```diff
@@ -342,6 +342,7 @@ def _patched_make_session(func, args, kwargs):
 def _traced_endpoint(endpoint_hook, integration, pin, args, kwargs):
     span = integration.trace(pin, endpoint_hook.OPERATION_ID)
     openai_api_key = _format_openai_api_key(kwargs.get("api_key"))
+    err = None
     if openai_api_key:
         # API key can either be set on the import or per request
         span.set_tag_str("openai.user.api_key", openai_api_key)
@@ -350,22 +351,23 @@ def _traced_endpoint(endpoint_hook, integration, pin, args, kwargs):
         hook = endpoint_hook().handle_request(pin, integration, span, args, kwargs)
         hook.send(None)
 
-        resp, error = yield
+        resp, err = yield
 
         # Record any error information
-        if error is not None:
+        if err is not None:
             span.set_exc_info(*sys.exc_info())
             integration.metric(span, "incr", "request.error", 1)
 
         # Pass the response and the error to the hook
         try:
-            hook.send((resp, error))
+            hook.send((resp, err))
         except StopIteration as e:
-            if error is None:
+            if err is None:
                 return e.value
     finally:
-        # Streamed responses will be finished when the generator exits.
-        if not kwargs.get("stream"):
+        # Streamed responses will be finished when the generator exits, so finish non-streamed spans here.
+        # Streamed responses with error will need to be finished manually as well.
+        if not kwargs.get("stream") or err is not None:
             span.finish()
             integration.metric(span, "dist", "request.duration", span.duration_ns)
 
```

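The `hook.send(None)` priming, the `resp, err = yield` handoff, and the `StopIteration` handling above follow the standard generator-as-coroutine protocol. Below is a minimal sketch of how a caller could drive such a generator; the `drive` helper and `toy_traced` example are illustrative, not ddtrace's actual call site:

```python
def drive(traced_gen, do_request):
    """Prime a tracing generator, run the request, and feed the outcome back."""
    next(traced_gen)  # advance to `resp, err = yield`
    resp, err = None, None
    try:
        resp = do_request()
    except Exception as e:
        err = e
    try:
        traced_gen.send((resp, err))  # resume the generator with the outcome
    except StopIteration as e:
        if err is None:
            return e.value  # the generator's `return e.value` surfaces here
    if err is not None:
        raise err


# A trivial tracing generator with the same shape as _traced_endpoint:
def toy_traced():
    resp, err = yield
    return resp if err is None else None


print(drive(toy_traced(), lambda: "ok"))  # -> ok
```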
### New file: release note (4 additions)

```yaml
---
fixes:
  - |
    openai: This fix resolves an issue where errors during streamed requests resulted in unfinished spans.
```
### New file: test cassette `completion_stream_wrong_api_key.yaml` (57 additions)

```yaml
interactions:
- request:
    body: '{"model": "text-curie-001", "prompt": "how does openai tokenize prompts?",
      "temperature": 0.8, "n": 1, "max_tokens": 150, "stream": true}'
    headers:
      Accept:
      - '*/*'
      Accept-Encoding:
      - gzip, deflate
      Connection:
      - keep-alive
      Content-Length:
      - '137'
      Content-Type:
      - application/json
      User-Agent:
      - OpenAI/v1 PythonBindings/0.27.2
      X-OpenAI-Client-User-Agent:
      - '{"bindings_version": "0.27.2", "httplib": "requests", "lang": "python", "lang_version":
        "3.10.5", "platform": "macOS-13.5.1-arm64-arm-64bit", "publisher": "openai",
        "uname": "Darwin 22.6.0 Darwin Kernel Version 22.6.0: Wed Jul 5 22:22:05
        PDT 2023; root:xnu-8796.141.3~6/RELEASE_ARM64_T6000 arm64"}'
    method: POST
    uri: https://api.openai.com/v1/completions
  response:
    body:
      string: "{\n  \"error\": {\n    \"message\": \"Incorrect API key provided:
        sk-wrong****-key. You can find your API key at https://platform.openai.com/account/api-keys.\",\n
        \   \"type\": \"invalid_request_error\",\n    \"param\": null,\n    \"code\":
        \"invalid_api_key\"\n  }\n}\n"
    headers:
      CF-Cache-Status:
      - DYNAMIC
      CF-RAY:
      - 80599159bdd94288-EWR
      Connection:
      - keep-alive
      Content-Length:
      - '266'
      Content-Type:
      - application/json; charset=utf-8
      Date:
      - Tue, 12 Sep 2023 16:36:09 GMT
      Server:
      - cloudflare
      alt-svc:
      - h3=":443"; ma=86400
      strict-transport-security:
      - max-age=15724800; includeSubDomains
      vary:
      - Origin
      x-request-id:
      - 912bc0d688b018590ad4644213b9c72f
    status:
      code: 401
      message: Unauthorized
version: 1
```

### tests/contrib/openai/test_openai.py (15 additions, 0 deletions)

```diff
@@ -1638,6 +1638,21 @@ def test_misuse(openai, snapshot_tracer):
         openai.Completion.create(input="wrong arg")
 
 
+@pytest.mark.snapshot(ignores=["meta.http.useragent", "meta.error.stack"])
+def test_span_finish_on_stream_error(openai, openai_vcr, snapshot_tracer):
+    with openai_vcr.use_cassette("completion_stream_wrong_api_key.yaml"):
+        with pytest.raises(openai.error.AuthenticationError):
+            openai.Completion.create(
+                api_key="sk-wrong-api-key",
+                model="text-curie-001",
+                prompt="how does openai tokenize prompts?",
+                temperature=0.8,
+                n=1,
+                max_tokens=150,
+                stream=True,
+            )
+
+
 def test_completion_stream(openai, openai_vcr, mock_metrics, mock_tracer):
     with openai_vcr.use_cassette("completion_streamed.yaml"):
         with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding:
```
### New file: trace snapshot (41 additions)

```json
[[
  {
    "name": "openai.request",
    "service": "",
    "resource": "createCompletion",
    "trace_id": 0,
    "span_id": 1,
    "parent_id": 0,
    "type": "",
    "error": 1,
    "meta": {
      "_dd.p.dm": "-0",
      "component": "openai",
      "error.message": "Incorrect API key provided: sk-wrong****-key. You can find your API key at https://platform.openai.com/account/api-keys.",
      "error.stack": "Traceback (most recent call last):\n openai.error.AuthenticationError: Incorrect API key provided: sk-wrong****-key. You can find your API key at https://platform.openai.com/account/api-keys.\n",
      "error.type": "openai.error.AuthenticationError",
      "language": "python",
      "openai.api_base": "https://api.openai.com/v1",
      "openai.api_type": "open_ai",
      "openai.request.endpoint": "/v1/completions",
      "openai.request.max_tokens": "150",
      "openai.request.method": "POST",
      "openai.request.model": "text-curie-001",
      "openai.request.n": "1",
      "openai.request.prompt": "how does openai tokenize prompts?",
      "openai.request.stream": "True",
      "openai.request.temperature": "0.8",
      "openai.user.api_key": "sk-...-key",
      "runtime-id": "0a0a92d644714949b7544ee81c6d1bf1"
    },
    "metrics": {
      "_dd.agent_psr": 1.0,
      "_dd.measured": 1,
      "_dd.top_level": 1,
      "_dd.tracer_kr": 1.0,
      "_sampling_priority_v1": 1,
      "process_id": 91222
    },
    "duration": 291271000,
    "start": 1694536282656608000
  }
]]
```
