
Commit f9e090f

fix(openai): ensure streamed spans with error are manually finished [backport 1.19] (#6911)
Backport f3beaf4 from #6891 to 1.19. Resolves #6769.

This PR fixes an unhandled case in our OpenAI integration for streamed `Chat/Completion` requests that error and produce an empty response. In this case, streamed spans were never finished (we avoid finishing streamed-response spans until the underlying generator is exhausted) because the empty-response path was not handled. The fix adds manual span finishing for streamed-response spans with an error, as we already do for non-streamed spans (the new finishing rule is sketched after the checklists below). The only risk is a traced request that produces a non-empty response together with an error; this is highly unlikely, since a faulty request should not produce any response at all.

Note: this PR also moves tagging of prompt token usage information into the `EndpointHook.process_response()` handler instead of `EndpointHook.process_request()`, since keeping it in the latter caused prompt token information to be recorded even when no actual prompt/completion operation happened because of the erroneous request.

## Checklist

- [x] Change(s) are motivated and described in the PR description.
- [x] Testing strategy is described if automated tests are not included in the PR.
- [x] Risk is outlined (performance impact, potential for breakage, maintainability, etc.).
- [x] Change is maintainable (easy to change, telemetry, documentation).
- [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed. If no release note is required, add label `changelog/no-changelog`.
- [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)).
- [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)).

## Reviewer Checklist

- [x] Title is accurate.
- [x] No unnecessary changes are introduced.
- [x] Description motivates each change.
- [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes unless absolutely necessary.
- [x] Testing strategy adequately addresses listed risk(s).
- [x] Change is maintainable (easy to change, telemetry, documentation).
- [x] Release note makes sense to a user of the library.
- [x] Reviewer has explicitly acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment.
- [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting).
- [x] If this PR touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`.
- [x] This PR doesn't touch any of that.

Co-authored-by: Yun Kim <[email protected]>
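For a quick picture of the behavior change, here is a minimal sketch of the new finishing rule. The standalone helper `finish_span_if_needed` is illustrative only; the actual change is made inline in `_traced_endpoint` in `ddtrace/contrib/openai/patch.py`, shown in the diff below.

```python
def finish_span_if_needed(span, kwargs, err, integration):
    """Illustrative sketch: finish the request span in the handler itself when
    it will not be finished later by the streaming generator."""
    # Non-streamed spans are finished here, as before. Streamed spans are
    # normally finished once the response generator is exhausted, but an
    # errored streamed request produces no generator to exhaust, so it must
    # be finished manually as well.
    if not kwargs.get("stream") or err is not None:
        span.finish()
        integration.metric(span, "dist", "request.duration", span.duration_ns)
```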
1 parent 907289b commit f9e090f

File tree: 6 files changed, +139 −19 lines


ddtrace/contrib/openai/_endpoint_hooks.py

Lines changed: 14 additions & 13 deletions
```diff
@@ -91,7 +91,6 @@ def shared_gen():
             try:
                 num_prompt_tokens = span.get_metric("openai.response.usage.prompt_tokens") or 0
                 num_completion_tokens = yield
-
                 span.set_metric("openai.response.usage.completion_tokens", num_completion_tokens)
                 total_tokens = num_prompt_tokens + num_completion_tokens
                 span.set_metric("openai.response.usage.total_tokens", total_tokens)
@@ -180,8 +179,15 @@ def _record_request(self, pin, integration, span, args, kwargs):
         elif prompt:
             for idx, p in enumerate(prompt):
                 span.set_tag_str("openai.request.prompt.%d" % idx, integration.trunc(str(p)))
+        return
+
+    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
+        if not resp:
+            return self._handle_response(pin, span, integration, resp)
+        prompt = kwargs.get("prompt", "")
         if kwargs.get("stream"):
             num_prompt_tokens = 0
+            estimated = False
             if isinstance(prompt, str) or isinstance(prompt, list) and isinstance(prompt[0], int):
                 estimated, prompt_tokens = _compute_prompt_token_count(prompt, kwargs.get("model"))
                 num_prompt_tokens += prompt_tokens
@@ -191,10 +197,6 @@ def _record_request(self, pin, integration, span, args, kwargs):
                 num_prompt_tokens += prompt_tokens
             span.set_metric("openai.request.prompt_tokens_estimated", int(estimated))
             span.set_metric("openai.response.usage.prompt_tokens", num_prompt_tokens)
-        return
-
-    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
-        if not resp or kwargs.get("stream"):
             return self._handle_response(pin, span, integration, resp)
         if "choices" in resp:
             choices = resp["choices"]
@@ -212,7 +214,6 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
             span.set_tag_str("openai.response.choices.%d.text" % idx, integration.trunc(choice.get("text")))
         integration.record_usage(span, resp.get("usage"))
         if integration.is_pc_sampled_log(span):
-            prompt = kwargs.get("prompt", "")
             integration.log(
                 span,
                 "info" if error is None else "error",
@@ -254,19 +255,20 @@ def _record_request(self, pin, integration, span, args, kwargs):
             span.set_tag_str("openai.request.messages.%d.content" % idx, content)
             span.set_tag_str("openai.request.messages.%d.role" % idx, role)
             span.set_tag_str("openai.request.messages.%d.name" % idx, name)
+        return
+
+    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
+        if not resp:
+            return self._handle_response(pin, span, integration, resp)
+        messages = kwargs.get("messages")
         if kwargs.get("stream"):
-            # streamed responses do not have a usage field, so we have to
-            # estimate the number of tokens returned.
             est_num_message_tokens = 0
+            estimated = False
             for m in messages:
                 estimated, prompt_tokens = _compute_prompt_token_count(m.get("content", ""), kwargs.get("model"))
                 est_num_message_tokens += prompt_tokens
             span.set_metric("openai.request.prompt_tokens_estimated", int(estimated))
             span.set_metric("openai.response.usage.prompt_tokens", est_num_message_tokens)
-        return
-
-    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
-        if not resp or kwargs.get("stream"):
             return self._handle_response(pin, span, integration, resp)
         choices = resp.get("choices", [])
         span.set_metric("openai.response.choices_count", len(choices))
@@ -291,7 +293,6 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
             )
         integration.record_usage(span, resp.get("usage"))
         if integration.is_pc_sampled_log(span):
-            messages = kwargs.get("messages")
             integration.log(
                 span,
                 "info" if error is None else "error",
```

ddtrace/contrib/openai/patch.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -349,6 +349,7 @@ def _patched_make_session(func, args, kwargs):
 def _traced_endpoint(endpoint_hook, integration, pin, args, kwargs):
     span = integration.trace(pin, endpoint_hook.OPERATION_ID)
     openai_api_key = _format_openai_api_key(kwargs.get("api_key"))
+    err = None
     if openai_api_key:
         # API key can either be set on the import or per request
         span.set_tag_str("openai.user.api_key", openai_api_key)
@@ -357,22 +358,23 @@ def _traced_endpoint(endpoint_hook, integration, pin, args, kwargs):
     hook = endpoint_hook().handle_request(pin, integration, span, args, kwargs)
     hook.send(None)

-    resp, error = yield
+    resp, err = yield

     # Record any error information
-    if error is not None:
+    if err is not None:
         span.set_exc_info(*sys.exc_info())
         integration.metric(span, "incr", "request.error", 1)

     # Pass the response and the error to the hook
     try:
-        hook.send((resp, error))
+        hook.send((resp, err))
     except StopIteration as e:
-        if error is None:
+        if err is None:
             return e.value
     finally:
-        # Streamed responses will be finished when the generator exits.
-        if not kwargs.get("stream"):
+        # Streamed responses will be finished when the generator exits, so finish non-streamed spans here.
+        # Streamed responses with error will need to be finished manually as well.
+        if not kwargs.get("stream") or err is not None:
             span.finish()
             integration.metric(span, "dist", "request.duration", span.duration_ns)
```
Lines changed: 4 additions & 0 deletions
```diff
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    openai: This fix resolves an issue where errors during streamed requests resulted in unfinished spans.
```
Lines changed: 57 additions & 0 deletions
```diff
@@ -0,0 +1,57 @@
+interactions:
+- request:
+    body: '{"model": "text-curie-001", "prompt": "how does openai tokenize prompts?",
+      "temperature": 0.8, "n": 1, "max_tokens": 150, "stream": true}'
+    headers:
+      Accept:
+      - '*/*'
+      Accept-Encoding:
+      - gzip, deflate
+      Connection:
+      - keep-alive
+      Content-Length:
+      - '137'
+      Content-Type:
+      - application/json
+      User-Agent:
+      - OpenAI/v1 PythonBindings/0.27.2
+      X-OpenAI-Client-User-Agent:
+      - '{"bindings_version": "0.27.2", "httplib": "requests", "lang": "python", "lang_version":
+        "3.10.5", "platform": "macOS-13.5.1-arm64-arm-64bit", "publisher": "openai",
+        "uname": "Darwin 22.6.0 Darwin Kernel Version 22.6.0: Wed Jul 5 22:22:05
+        PDT 2023; root:xnu-8796.141.3~6/RELEASE_ARM64_T6000 arm64"}'
+    method: POST
+    uri: https://api.openai.com/v1/completions
+  response:
+    body:
+      string: "{\n  \"error\": {\n    \"message\": \"Incorrect API key provided:
+        sk-wrong****-key. You can find your API key at https://platform.openai.com/account/api-keys.\",\n
+        \   \"type\": \"invalid_request_error\",\n    \"param\": null,\n    \"code\":
+        \"invalid_api_key\"\n  }\n}\n"
+    headers:
+      CF-Cache-Status:
+      - DYNAMIC
+      CF-RAY:
+      - 80599159bdd94288-EWR
+      Connection:
+      - keep-alive
+      Content-Length:
+      - '266'
+      Content-Type:
+      - application/json; charset=utf-8
+      Date:
+      - Tue, 12 Sep 2023 16:36:09 GMT
+      Server:
+      - cloudflare
+      alt-svc:
+      - h3=":443"; ma=86400
+      strict-transport-security:
+      - max-age=15724800; includeSubDomains
+      vary:
+      - Origin
+      x-request-id:
+      - 912bc0d688b018590ad4644213b9c72f
+    status:
+      code: 401
+      message: Unauthorized
+version: 1
```

tests/contrib/openai/test_openai.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -1645,6 +1645,21 @@ def test_misuse(openai, snapshot_tracer):
         openai.Completion.create(input="wrong arg")


+@pytest.mark.snapshot(ignores=["meta.http.useragent", "meta.error.stack"])
+def test_span_finish_on_stream_error(openai, openai_vcr, snapshot_tracer):
+    with openai_vcr.use_cassette("completion_stream_wrong_api_key.yaml"):
+        with pytest.raises(openai.error.AuthenticationError):
+            openai.Completion.create(
+                api_key="sk-wrong-api-key",
+                model="text-curie-001",
+                prompt="how does openai tokenize prompts?",
+                temperature=0.8,
+                n=1,
+                max_tokens=150,
+                stream=True,
+            )
+
+
 def test_completion_stream(openai, openai_vcr, mock_metrics, mock_tracer):
     with openai_vcr.use_cassette("completion_streamed.yaml"):
         with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding:
```
Lines changed: 41 additions & 0 deletions
```diff
@@ -0,0 +1,41 @@
+[[
+  {
+    "name": "openai.request",
+    "service": "",
+    "resource": "createCompletion",
+    "trace_id": 0,
+    "span_id": 1,
+    "parent_id": 0,
+    "type": "",
+    "error": 1,
+    "meta": {
+      "_dd.p.dm": "-0",
+      "component": "openai",
+      "error.message": "Incorrect API key provided: sk-wrong****-key. You can find your API key at https://platform.openai.com/account/api-keys.",
+      "error.stack": "Traceback (most recent call last):\n openai.error.AuthenticationError: Incorrect API key provided: sk-wrong****-key. You can find your API key at https://platform.openai.com/account/api-keys.\n",
+      "error.type": "openai.error.AuthenticationError",
+      "language": "python",
+      "openai.api_base": "https://api.openai.com/v1",
+      "openai.api_type": "open_ai",
+      "openai.request.endpoint": "/v1/completions",
+      "openai.request.max_tokens": "150",
+      "openai.request.method": "POST",
+      "openai.request.model": "text-curie-001",
+      "openai.request.n": "1",
+      "openai.request.prompt": "how does openai tokenize prompts?",
+      "openai.request.stream": "True",
+      "openai.request.temperature": "0.8",
+      "openai.user.api_key": "sk-...-key",
+      "runtime-id": "0a0a92d644714949b7544ee81c6d1bf1"
+    },
+    "metrics": {
+      "_dd.measured": 1,
+      "_dd.top_level": 1,
+      "_dd.tracer_kr": 1.0,
+      "_sample_rate": 1.0,
+      "_sampling_priority_v1": 1,
+      "process_id": 91222
+    },
+    "duration": 291271000,
+    "start": 1694536282656608000
+  }]]
```
