@@ -177,6 +177,20 @@ class OpenAiDefinition:
177177 sync = False ,
178178 min_version = "1.66.0" ,
179179 ),
180+ OpenAiDefinition (
181+ module = "openai.resources.embeddings" ,
182+ object = "Embeddings" ,
183+ method = "create" ,
184+ type = "embedding" ,
185+ sync = True ,
186+ ),
187+ OpenAiDefinition (
188+ module = "openai.resources.embeddings" ,
189+ object = "AsyncEmbeddings" ,
190+ method = "create" ,
191+ type = "embedding" ,
192+ sync = False ,
193+ ),
180194]
181195
182196
@@ -340,10 +354,13 @@ def _extract_chat_response(kwargs: Any) -> Any:
340354
341355
342356def _get_langfuse_data_from_kwargs (resource : OpenAiDefinition , kwargs : Any ) -> Any :
343- name = kwargs .get ("name" , "OpenAI-generation" )
357+ default_name = (
358+ "OpenAI-embedding" if resource .type == "embedding" else "OpenAI-generation"
359+ )
360+ name = kwargs .get ("name" , default_name )
344361
345362 if name is None :
346- name = "OpenAI-generation"
363+ name = default_name
347364
348365 if name is not None and not isinstance (name , str ):
349366 raise TypeError ("name must be a string" )
@@ -395,6 +412,8 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> A
395412 prompt = kwargs .get ("input" , None )
396413 elif resource .type == "chat" :
397414 prompt = _extract_chat_prompt (kwargs )
415+ elif resource .type == "embedding" :
416+ prompt = kwargs .get ("input" , None )
398417
399418 parsed_temperature = (
400419 kwargs .get ("temperature" , 1 )
@@ -440,23 +459,41 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> A
440459
441460 parsed_n = kwargs .get ("n" , 1 ) if not isinstance (kwargs .get ("n" , 1 ), NotGiven ) else 1
442461
443- modelParameters = {
444- "temperature" : parsed_temperature ,
445- "max_tokens" : parsed_max_tokens , # casing?
446- "top_p" : parsed_top_p ,
447- "frequency_penalty" : parsed_frequency_penalty ,
448- "presence_penalty" : parsed_presence_penalty ,
449- }
462+ if resource .type == "embedding" :
463+ parsed_dimensions = (
464+ kwargs .get ("dimensions" , None )
465+ if not isinstance (kwargs .get ("dimensions" , None ), NotGiven )
466+ else None
467+ )
468+ parsed_encoding_format = (
469+ kwargs .get ("encoding_format" , "float" )
470+ if not isinstance (kwargs .get ("encoding_format" , "float" ), NotGiven )
471+ else "float"
472+ )
450473
451- if parsed_max_completion_tokens is not None :
452- modelParameters .pop ("max_tokens" , None )
453- modelParameters ["max_completion_tokens" ] = parsed_max_completion_tokens
474+ modelParameters = {}
475+ if parsed_dimensions is not None :
476+ modelParameters ["dimensions" ] = parsed_dimensions
477+ if parsed_encoding_format != "float" :
478+ modelParameters ["encoding_format" ] = parsed_encoding_format
479+ else :
480+ modelParameters = {
481+ "temperature" : parsed_temperature ,
482+ "max_tokens" : parsed_max_tokens ,
483+ "top_p" : parsed_top_p ,
484+ "frequency_penalty" : parsed_frequency_penalty ,
485+ "presence_penalty" : parsed_presence_penalty ,
486+ }
454487
455- if parsed_n is not None and parsed_n > 1 :
456- modelParameters ["n" ] = parsed_n
488+ if parsed_max_completion_tokens is not None :
489+ modelParameters .pop ("max_tokens" , None )
490+ modelParameters ["max_completion_tokens" ] = parsed_max_completion_tokens
457491
458- if parsed_seed is not None :
459- modelParameters ["seed" ] = parsed_seed
492+ if parsed_n is not None and parsed_n > 1 :
493+ modelParameters ["n" ] = parsed_n
494+
495+ if parsed_seed is not None :
496+ modelParameters ["seed" ] = parsed_seed
460497
461498 langfuse_prompt = kwargs .get ("langfuse_prompt" , None )
462499
@@ -521,6 +558,14 @@ def _parse_usage(usage: Optional[Any] = None) -> Any:
521558 k : v for k , v in tokens_details_dict .items () if v is not None
522559 }
523560
561+ if (
562+ len (usage_dict ) == 2
563+ and "prompt_tokens" in usage_dict
564+ and "total_tokens" in usage_dict
565+ ):
566+ # handle embedding usage
567+ return {"input" : usage_dict ["prompt_tokens" ]}
568+
524569 return usage_dict
525570
526571
@@ -646,7 +691,7 @@ def _extract_streamed_openai_response(resource: Any, chunks: Any) -> Any:
646691 curr [- 1 ]["arguments" ] = ""
647692
648693 curr [- 1 ]["arguments" ] += getattr (
649- tool_call_chunk , "arguments" , None
694+ tool_call_chunk , "arguments" , ""
650695 )
651696
652697 if resource .type == "completion" :
@@ -729,6 +774,20 @@ def _get_langfuse_data_from_default_response(
729774 else choice .get ("message" , None )
730775 )
731776
777+ elif resource .type == "embedding" :
778+ data = response .get ("data" , [])
779+ if len (data ) > 0 :
780+ first_embedding = data [0 ]
781+ embedding_vector = (
782+ first_embedding .embedding
783+ if hasattr (first_embedding , "embedding" )
784+ else first_embedding .get ("embedding" , [])
785+ )
786+ completion = {
787+ "dimensions" : len (embedding_vector ) if embedding_vector else 0 ,
788+ "count" : len (data ),
789+ }
790+
732791 usage = _parse_usage (response .get ("usage" , None ))
733792
734793 return (model , completion , usage )
@@ -757,8 +816,12 @@ def _wrap(
757816 langfuse_data = _get_langfuse_data_from_kwargs (open_ai_resource , langfuse_args )
758817 langfuse_client = get_client (public_key = langfuse_args ["langfuse_public_key" ])
759818
819+ observation_type = (
820+ "embedding" if open_ai_resource .type == "embedding" else "generation"
821+ )
822+
760823 generation = langfuse_client .start_observation (
761- as_type = "generation" ,
824+ as_type = observation_type , # type: ignore
762825 name = langfuse_data ["name" ],
763826 input = langfuse_data .get ("input" , None ),
764827 metadata = langfuse_data .get ("metadata" , None ),
@@ -824,8 +887,12 @@ async def _wrap_async(
824887 langfuse_data = _get_langfuse_data_from_kwargs (open_ai_resource , langfuse_args )
825888 langfuse_client = get_client (public_key = langfuse_args ["langfuse_public_key" ])
826889
890+ observation_type = (
891+ "embedding" if open_ai_resource .type == "embedding" else "generation"
892+ )
893+
827894 generation = langfuse_client .start_observation (
828- as_type = "generation" ,
895+ as_type = observation_type , # type: ignore
829896 name = langfuse_data ["name" ],
830897 input = langfuse_data .get ("input" , None ),
831898 metadata = langfuse_data .get ("metadata" , None ),
0 commit comments