Skip to content

Commit f809baf

Browse files
authored
feat(openai): add openai embeddings api support (#1345) (#1372)
1 parent 37468a1 commit f809baf

File tree

2 files changed

+176
-19
lines changed

2 files changed

+176
-19
lines changed

langfuse/openai.py

Lines changed: 86 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,20 @@ class OpenAiDefinition:
177177
sync=False,
178178
min_version="1.66.0",
179179
),
180+
OpenAiDefinition(
181+
module="openai.resources.embeddings",
182+
object="Embeddings",
183+
method="create",
184+
type="embedding",
185+
sync=True,
186+
),
187+
OpenAiDefinition(
188+
module="openai.resources.embeddings",
189+
object="AsyncEmbeddings",
190+
method="create",
191+
type="embedding",
192+
sync=False,
193+
),
180194
]
181195

182196

@@ -340,10 +354,13 @@ def _extract_chat_response(kwargs: Any) -> Any:
340354

341355

342356
def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> Any:
343-
name = kwargs.get("name", "OpenAI-generation")
357+
default_name = (
358+
"OpenAI-embedding" if resource.type == "embedding" else "OpenAI-generation"
359+
)
360+
name = kwargs.get("name", default_name)
344361

345362
if name is None:
346-
name = "OpenAI-generation"
363+
name = default_name
347364

348365
if name is not None and not isinstance(name, str):
349366
raise TypeError("name must be a string")
@@ -395,6 +412,8 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> A
395412
prompt = kwargs.get("input", None)
396413
elif resource.type == "chat":
397414
prompt = _extract_chat_prompt(kwargs)
415+
elif resource.type == "embedding":
416+
prompt = kwargs.get("input", None)
398417

399418
parsed_temperature = (
400419
kwargs.get("temperature", 1)
@@ -440,23 +459,41 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> A
440459

441460
parsed_n = kwargs.get("n", 1) if not isinstance(kwargs.get("n", 1), NotGiven) else 1
442461

443-
modelParameters = {
444-
"temperature": parsed_temperature,
445-
"max_tokens": parsed_max_tokens, # casing?
446-
"top_p": parsed_top_p,
447-
"frequency_penalty": parsed_frequency_penalty,
448-
"presence_penalty": parsed_presence_penalty,
449-
}
462+
if resource.type == "embedding":
463+
parsed_dimensions = (
464+
kwargs.get("dimensions", None)
465+
if not isinstance(kwargs.get("dimensions", None), NotGiven)
466+
else None
467+
)
468+
parsed_encoding_format = (
469+
kwargs.get("encoding_format", "float")
470+
if not isinstance(kwargs.get("encoding_format", "float"), NotGiven)
471+
else "float"
472+
)
450473

451-
if parsed_max_completion_tokens is not None:
452-
modelParameters.pop("max_tokens", None)
453-
modelParameters["max_completion_tokens"] = parsed_max_completion_tokens
474+
modelParameters = {}
475+
if parsed_dimensions is not None:
476+
modelParameters["dimensions"] = parsed_dimensions
477+
if parsed_encoding_format != "float":
478+
modelParameters["encoding_format"] = parsed_encoding_format
479+
else:
480+
modelParameters = {
481+
"temperature": parsed_temperature,
482+
"max_tokens": parsed_max_tokens,
483+
"top_p": parsed_top_p,
484+
"frequency_penalty": parsed_frequency_penalty,
485+
"presence_penalty": parsed_presence_penalty,
486+
}
454487

455-
if parsed_n is not None and parsed_n > 1:
456-
modelParameters["n"] = parsed_n
488+
if parsed_max_completion_tokens is not None:
489+
modelParameters.pop("max_tokens", None)
490+
modelParameters["max_completion_tokens"] = parsed_max_completion_tokens
457491

458-
if parsed_seed is not None:
459-
modelParameters["seed"] = parsed_seed
492+
if parsed_n is not None and parsed_n > 1:
493+
modelParameters["n"] = parsed_n
494+
495+
if parsed_seed is not None:
496+
modelParameters["seed"] = parsed_seed
460497

461498
langfuse_prompt = kwargs.get("langfuse_prompt", None)
462499

@@ -521,6 +558,14 @@ def _parse_usage(usage: Optional[Any] = None) -> Any:
521558
k: v for k, v in tokens_details_dict.items() if v is not None
522559
}
523560

561+
if (
562+
len(usage_dict) == 2
563+
and "prompt_tokens" in usage_dict
564+
and "total_tokens" in usage_dict
565+
):
566+
# handle embedding usage
567+
return {"input": usage_dict["prompt_tokens"]}
568+
524569
return usage_dict
525570

526571

@@ -646,7 +691,7 @@ def _extract_streamed_openai_response(resource: Any, chunks: Any) -> Any:
646691
curr[-1]["arguments"] = ""
647692

648693
curr[-1]["arguments"] += getattr(
649-
tool_call_chunk, "arguments", None
694+
tool_call_chunk, "arguments", ""
650695
)
651696

652697
if resource.type == "completion":
@@ -729,6 +774,20 @@ def _get_langfuse_data_from_default_response(
729774
else choice.get("message", None)
730775
)
731776

777+
elif resource.type == "embedding":
778+
data = response.get("data", [])
779+
if len(data) > 0:
780+
first_embedding = data[0]
781+
embedding_vector = (
782+
first_embedding.embedding
783+
if hasattr(first_embedding, "embedding")
784+
else first_embedding.get("embedding", [])
785+
)
786+
completion = {
787+
"dimensions": len(embedding_vector) if embedding_vector else 0,
788+
"count": len(data),
789+
}
790+
732791
usage = _parse_usage(response.get("usage", None))
733792

734793
return (model, completion, usage)
@@ -757,8 +816,12 @@ def _wrap(
757816
langfuse_data = _get_langfuse_data_from_kwargs(open_ai_resource, langfuse_args)
758817
langfuse_client = get_client(public_key=langfuse_args["langfuse_public_key"])
759818

819+
observation_type = (
820+
"embedding" if open_ai_resource.type == "embedding" else "generation"
821+
)
822+
760823
generation = langfuse_client.start_observation(
761-
as_type="generation",
824+
as_type=observation_type, # type: ignore
762825
name=langfuse_data["name"],
763826
input=langfuse_data.get("input", None),
764827
metadata=langfuse_data.get("metadata", None),
@@ -824,8 +887,12 @@ async def _wrap_async(
824887
langfuse_data = _get_langfuse_data_from_kwargs(open_ai_resource, langfuse_args)
825888
langfuse_client = get_client(public_key=langfuse_args["langfuse_public_key"])
826889

890+
observation_type = (
891+
"embedding" if open_ai_resource.type == "embedding" else "generation"
892+
)
893+
827894
generation = langfuse_client.start_observation(
828-
as_type="generation",
895+
as_type=observation_type, # type: ignore
829896
name=langfuse_data["name"],
830897
input=langfuse_data.get("input", None),
831898
metadata=langfuse_data.get("metadata", None),

tests/test_openai.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,3 +1514,93 @@ def test_response_api_reasoning(openai):
15141514
assert generationData.usage.total is not None
15151515
assert generationData.output is not None
15161516
assert generationData.metadata is not None
1517+
1518+
1519+
def test_openai_embeddings(openai):
1520+
embedding_name = create_uuid()
1521+
openai.OpenAI().embeddings.create(
1522+
name=embedding_name,
1523+
model="text-embedding-ada-002",
1524+
input="The quick brown fox jumps over the lazy dog",
1525+
metadata={"test_key": "test_value"},
1526+
)
1527+
1528+
langfuse.flush()
1529+
sleep(1)
1530+
1531+
embedding = get_api().observations.get_many(name=embedding_name, type="EMBEDDING")
1532+
1533+
assert len(embedding.data) != 0
1534+
embedding_data = embedding.data[0]
1535+
assert embedding_data.name == embedding_name
1536+
assert embedding_data.metadata["test_key"] == "test_value"
1537+
assert embedding_data.input == "The quick brown fox jumps over the lazy dog"
1538+
assert embedding_data.type == "EMBEDDING"
1539+
assert "text-embedding-ada-002" in embedding_data.model
1540+
assert embedding_data.start_time is not None
1541+
assert embedding_data.end_time is not None
1542+
assert embedding_data.start_time < embedding_data.end_time
1543+
assert embedding_data.usage.input is not None
1544+
assert embedding_data.usage.total is not None
1545+
assert embedding_data.output is not None
1546+
assert "dimensions" in embedding_data.output
1547+
assert "count" in embedding_data.output
1548+
assert embedding_data.output["count"] == 1
1549+
1550+
1551+
def test_openai_embeddings_multiple_inputs(openai):
1552+
embedding_name = create_uuid()
1553+
inputs = ["The quick brown fox", "jumps over the lazy dog", "Hello world"]
1554+
1555+
openai.OpenAI().embeddings.create(
1556+
name=embedding_name,
1557+
model="text-embedding-ada-002",
1558+
input=inputs,
1559+
metadata={"batch_size": len(inputs)},
1560+
)
1561+
1562+
langfuse.flush()
1563+
sleep(1)
1564+
1565+
embedding = get_api().observations.get_many(name=embedding_name, type="EMBEDDING")
1566+
1567+
assert len(embedding.data) != 0
1568+
embedding_data = embedding.data[0]
1569+
assert embedding_data.name == embedding_name
1570+
assert embedding_data.input == inputs
1571+
assert embedding_data.type == "EMBEDDING"
1572+
assert "text-embedding-ada-002" in embedding_data.model
1573+
assert embedding_data.usage.input is not None
1574+
assert embedding_data.usage.total is not None
1575+
assert embedding_data.output["count"] == len(inputs)
1576+
1577+
1578+
@pytest.mark.asyncio
1579+
async def test_async_openai_embeddings(openai):
1580+
client = openai.AsyncOpenAI()
1581+
embedding_name = create_uuid()
1582+
print(embedding_name)
1583+
1584+
result = await client.embeddings.create(
1585+
name=embedding_name,
1586+
model="text-embedding-ada-002",
1587+
input="Async embedding test",
1588+
metadata={"async": True},
1589+
)
1590+
1591+
print("result:", result.usage)
1592+
1593+
langfuse.flush()
1594+
sleep(1)
1595+
1596+
embedding = get_api().observations.get_many(name=embedding_name, type="EMBEDDING")
1597+
1598+
assert len(embedding.data) != 0
1599+
embedding_data = embedding.data[0]
1600+
assert embedding_data.name == embedding_name
1601+
assert embedding_data.input == "Async embedding test"
1602+
assert embedding_data.type == "EMBEDDING"
1603+
assert "text-embedding-ada-002" in embedding_data.model
1604+
assert embedding_data.metadata["async"] is True
1605+
assert embedding_data.usage.input is not None
1606+
assert embedding_data.usage.total is not None

0 commit comments

Comments (0)