Skip to content

Commit f6b740e

Browse files
authored
fix(openai): ensure embeddings input gets pc sampled correctly [backport to 1.14] (#6074)
(cherry picked from commit 93d700e)

## Checklist
- [x] Change(s) are motivated and described in the PR description.
- [x] Testing strategy is described if automated tests are not included in the PR.
- [x] Risk is outlined (performance impact, potential for breakage, maintainability, etc).
- [x] Change is maintainable (easy to change, telemetry, documentation).
- [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/contributing.html#Release-Note-Guidelines) are followed.
- [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)).

## Reviewer Checklist
- [x] Title is accurate.
- [x] No unnecessary changes are introduced.
- [x] Description motivates each change.
- [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes unless absolutely necessary.
- [x] Testing strategy adequately addresses listed risk(s).
- [x] Change is maintainable (easy to change, telemetry, documentation).
- [x] Release note makes sense to a user of the library.
- [x] Reviewer has explicitly acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment.
1 parent 4eb4e27 commit f6b740e

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

ddtrace/contrib/openai/patch.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -623,27 +623,26 @@ def handle_request(self, pin, integration, span, args, kwargs):
623623
return self._handle_response(pin, span, integration, resp)
624624

625625

626-
class _EmbeddingHook(_EndpointHook):
626+
class _EmbeddingHook(_BaseCompletionHook):
627+
_request_tag_attrs = ["model", "user"]
628+
627629
def handle_request(self, pin, integration, span, args, kwargs):
628-
for kw_attr in ["model", "input", "user"]:
629-
if kw_attr in kwargs:
630-
if kw_attr == "input" and integration.is_pc_sampled_span(span):
631-
if isinstance(kwargs["input"], list):
632-
for idx, inp in enumerate(kwargs["input"]):
633-
span.set_tag_str("openai.request.input.%d" % idx, integration.trunc(str(inp)))
634-
else:
635-
span.set_tag("openai.request.%s" % kw_attr, kwargs[kw_attr])
636-
else:
637-
span.set_tag("openai.request.%s" % kw_attr, kwargs[kw_attr])
630+
embedding_input = kwargs.get("input", "")
631+
if integration.is_pc_sampled_span(span):
632+
if isinstance(embedding_input, list):
633+
for idx, inp in enumerate(embedding_input):
634+
span.set_tag_str("openai.request.input.%d" % idx, integration.trunc(str(inp)))
635+
else:
636+
span.set_tag("openai.request.input", embedding_input)
637+
638+
self._record_request(span, kwargs)
638639

639640
resp, error = yield
640641

641642
if resp:
642643
if "data" in resp:
643644
span.set_tag("openai.response.data.num-embeddings", len(resp["data"]))
644645
span.set_tag("openai.response.data.embedding-length", len(resp["data"][0]["embedding"]))
645-
if "object" in kwargs:
646-
span.set_tag("openai.response.%s" % kw_attr, kwargs[kw_attr])
647646
integration.record_usage(span, resp.get("usage"))
648647
return resp
649648

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
fixes:
3+
- |
4+
openai: This fix resolves an issue where embeddings inputs were always tagged regardless of the
5+
configured prompt-completion sample rate.

tests/contrib/openai/test_openai.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,27 @@ def test_completion_truncation(openai, openai_vcr, mock_tracer):
951951
assert len(completion.replace("...", "")) == limit
952952

953953

954+
@pytest.mark.parametrize(
955+
"ddtrace_config_openai",
956+
[
957+
dict(
958+
_api_key="<not-real-but-it's-something>",
959+
span_prompt_completion_sample_rate=0,
960+
)
961+
],
962+
)
963+
def test_embedding_unsampled_prompt_completion(openai, openai_vcr, ddtrace_config_openai, mock_logs, mock_tracer):
964+
if not hasattr(openai, "Embedding"):
965+
pytest.skip("embedding not supported for this version of openai")
966+
with openai_vcr.use_cassette("embedding.yaml"):
967+
openai.Embedding.create(input="hello world", model="text-embedding-ada-002")
968+
logs = mock_logs.enqueue.call_count
969+
traces = mock_tracer.pop_traces()
970+
assert len(traces) == 1
971+
assert traces[0][0].get_tag("openai.request.input") is None
972+
assert logs == 0
973+
974+
954975
@pytest.mark.parametrize(
955976
"ddtrace_config_openai",
956977
[

0 commit comments

Comments
 (0)