
Commit 89e3be0

feat(embed-text): add EmbedText for OpenAI and Gemini (#645)
1 parent 80313a3 commit 89e3be0

12 files changed, +313 -41 lines

examples/text_embedding/main.py

Lines changed: 8 additions & 0 deletions

@@ -1,6 +1,7 @@
 from dotenv import load_dotenv
 from psycopg_pool import ConnectionPool
 from pgvector.psycopg import register_vector
+from typing import Any
 import cocoindex
 import os
 from numpy.typing import NDArray
@@ -15,6 +16,13 @@ def text_to_embedding(
     Embed the text using a SentenceTransformer model.
     This is a shared logic between indexing and querying, so extract it as a function.
     """
+    # You can also switch to remote embedding model:
+    # return text.transform(
+    #     cocoindex.functions.EmbedText(
+    #         api_type=cocoindex.llm.LlmApiType.OPENAI,
+    #         model="text-embedding-3-small",
+    #     )
+    # )
     return text.transform(
         cocoindex.functions.SentenceTransformerEmbed(
             model="sentence-transformers/all-MiniLM-L6-v2"

python/cocoindex/functions.py

Lines changed: 10 additions & 0 deletions

@@ -32,6 +32,16 @@ class SplitRecursively(op.FunctionSpec):
     custom_languages: list[CustomLanguageSpec] = dataclasses.field(default_factory=list)


+class EmbedText(op.FunctionSpec):
+    """Embed a text into a vector space."""
+
+    api_type: llm.LlmApiType
+    model: str
+    address: str | None = None
+    output_dimension: int | None = None
+    task_type: str | None = None
+
+
 class ExtractByLlm(op.FunctionSpec):
     """Extract information from a text using a LLM."""

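
A construction sketch for the new spec (illustrative only, not from this commit; field names follow the dataclass above, and the fallback behavior is an assumption based on get_default_embedding_dimension in src/llm/mod.rs below):

    import cocoindex

    spec = cocoindex.functions.EmbedText(
        api_type=cocoindex.llm.LlmApiType.OPENAI,
        model="text-embedding-3-small",  # default dimension 1536 (see src/llm/openai.rs)
        address=None,           # both backends reject a custom API address
        output_dimension=512,   # OpenAI text-embedding-3-* models accept reduced dimensions
        task_type=None,         # only meaningful for Gemini, sent as taskType
    )

When output_dimension is omitted, the per-model defaults registered in the Rust backends presumably apply.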

src/llm/anthropic.rs

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 use crate::llm::{
-    LlmClient, LlmGenerateRequest, LlmGenerateResponse, OutputFormat, ToJsonSchemaOptions,
+    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat, ToJsonSchemaOptions,
 };
 use anyhow::{Context, Result, bail};
 use async_trait::async_trait;
@@ -31,7 +31,7 @@ impl Client {
 }

 #[async_trait]
-impl LlmClient for Client {
+impl LlmGenerationClient for Client {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,

src/llm/gemini.rs

Lines changed: 80 additions & 14 deletions

@@ -1,9 +1,10 @@
-use crate::api_bail;
+use crate::prelude::*;
+
 use crate::llm::{
-    LlmClient, LlmGenerateRequest, LlmGenerateResponse, OutputFormat, ToJsonSchemaOptions,
+    LlmEmbeddingClient, LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
+    ToJsonSchemaOptions,
 };
-use anyhow::{Context, Result, bail};
-use async_trait::async_trait;
+use phf::phf_map;
 use serde_json::Value;
 use urlencoding::encode;

@@ -13,7 +14,7 @@ pub struct Client {
 }

 impl Client {
-    pub async fn new(address: Option<String>) -> Result<Self> {
+    pub fn new(address: Option<String>) -> Result<Self> {
         if address.is_some() {
             api_bail!("Gemini doesn't support custom API address");
         }
@@ -46,8 +47,19 @@ fn remove_additional_properties(value: &mut Value) {
     }
 }

+impl Client {
+    fn get_api_url(&self, model: &str, api_name: &str) -> String {
+        format!(
+            "https://generativelanguage.googleapis.com/v1beta/models/{}:{}?key={}",
+            encode(model),
+            api_name,
+            encode(&self.api_key)
+        )
+    }
+}
+
 #[async_trait]
-impl LlmClient for Client {
+impl LlmGenerationClient for Client {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -76,21 +88,21 @@ impl LlmClient for Client {
             });
         }

-        let api_key = &self.api_key;
-        let url = format!(
-            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
-            encode(request.model),
-            encode(api_key)
-        );
-
+        let url = self.get_api_url(request.model, "generateContent");
         let resp = self
             .client
             .post(&url)
             .json(&payload)
             .send()
             .await
             .context("HTTP error")?;
-
+        if !resp.status().is_success() {
+            bail!(
+                "Gemini API error: {:?}\n{}\n",
+                resp.status(),
+                resp.text().await?
+            );
+        }
         let resp_json: Value = resp.json().await.context("Invalid JSON")?;

         if let Some(error) = resp_json.get("error") {
@@ -114,3 +126,57 @@ impl LlmClient for Client {
         }
     }
 }
+
+static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
+    "gemini-embedding-exp-03-07" => 3072,
+    "text-embedding-004" => 768,
+    "embedding-001" => 768,
+};
+
+#[derive(Deserialize)]
+struct ContentEmbedding {
+    values: Vec<f32>,
+}
+#[derive(Deserialize)]
+struct EmbedContentResponse {
+    embedding: ContentEmbedding,
+}
+
+#[async_trait]
+impl LlmEmbeddingClient for Client {
+    async fn embed_text<'req>(
+        &self,
+        request: super::LlmEmbeddingRequest<'req>,
+    ) -> Result<super::LlmEmbeddingResponse> {
+        let url = self.get_api_url(request.model, "embedContent");
+        let mut payload = serde_json::json!({
+            "model": request.model,
+            "content": { "parts": [{ "text": request.text }] },
+        });
+        if let Some(task_type) = request.task_type {
+            payload["taskType"] = serde_json::Value::String(task_type.into());
+        }
+        let resp = self
+            .client
+            .post(&url)
+            .json(&payload)
+            .send()
+            .await
+            .context("HTTP error")?;
+        if !resp.status().is_success() {
+            bail!(
+                "Gemini API error: {:?}\n{}\n",
+                resp.status(),
+                resp.text().await?
+            );
+        }
+        let embedding_resp: EmbedContentResponse = resp.json().await.context("Invalid JSON")?;
+        Ok(super::LlmEmbeddingResponse {
+            embedding: embedding_resp.embedding.values,
+        })
+    }
+
+    fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
+        DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied()
+    }
+}
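
To make the wire format concrete, here is a rough Python equivalent of embed_text above. This is a sketch for illustration only; it assumes GEMINI_API_KEY is set and mirrors the v1beta embedContent endpoint, payload, and response shape used by the Rust client:

    import os
    import requests
    from urllib.parse import quote

    def embed_text(model: str, text: str, task_type: str | None = None) -> list[float]:
        # Same URL scheme as Client::get_api_url with api_name = "embedContent".
        url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            f"{quote(model)}:embedContent?key={quote(os.environ['GEMINI_API_KEY'])}"
        )
        payload = {"model": model, "content": {"parts": [{"text": text}]}}
        if task_type is not None:
            payload["taskType"] = task_type  # mirrors the optional taskType field
        resp = requests.post(url, json=payload)
        resp.raise_for_status()  # counterpart of the bail! on a non-success status
        # Response shape matches EmbedContentResponse: {"embedding": {"values": [...]}}
        return resp.json()["embedding"]["values"]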

src/llm/mod.rs

Lines changed: 56 additions & 16 deletions

@@ -1,13 +1,10 @@
-use std::borrow::Cow;
-
-use anyhow::Result;
-use async_trait::async_trait;
-use schemars::schema::SchemaObject;
-use serde::{Deserialize, Serialize};
+use crate::prelude::*;

 use crate::base::json_schema::ToJsonSchemaOptions;
+use schemars::schema::SchemaObject;
+use std::borrow::Cow;

-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
 pub enum LlmApiType {
     Ollama,
     OpenAi,
@@ -46,7 +43,7 @@ pub struct LlmGenerateResponse {
 }

 #[async_trait]
-pub trait LlmClient: Send + Sync {
+pub trait LlmGenerationClient: Send + Sync {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -55,6 +52,28 @@ pub trait LlmClient: Send + Sync {
     fn json_schema_options(&self) -> ToJsonSchemaOptions;
 }

+#[derive(Debug)]
+pub struct LlmEmbeddingRequest<'a> {
+    pub model: &'a str,
+    pub text: Cow<'a, str>,
+    pub output_dimension: u32,
+    pub task_type: Option<Cow<'a, str>>,
+}
+
+pub struct LlmEmbeddingResponse {
+    pub embedding: Vec<f32>,
+}
+
+#[async_trait]
+pub trait LlmEmbeddingClient: Send + Sync {
+    async fn embed_text<'req>(
+        &self,
+        request: LlmEmbeddingRequest<'req>,
+    ) -> Result<LlmEmbeddingResponse>;
+
+    fn get_default_embedding_dimension(&self, model: &str) -> Option<u32>;
+}
+
 mod anthropic;
 mod gemini;
 mod litellm;
@@ -65,20 +84,41 @@ mod openrouter;
 pub async fn new_llm_generation_client(
     api_type: LlmApiType,
     address: Option<String>,
-) -> Result<Box<dyn LlmClient>> {
+) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match api_type {
-        LlmApiType::Ollama => Box::new(ollama::Client::new(address).await?) as Box<dyn LlmClient>,
-        LlmApiType::OpenAi => Box::new(openai::Client::new(address).await?) as Box<dyn LlmClient>,
-        LlmApiType::Gemini => Box::new(gemini::Client::new(address).await?) as Box<dyn LlmClient>,
+        LlmApiType::Ollama => {
+            Box::new(ollama::Client::new(address).await?) as Box<dyn LlmGenerationClient>
+        }
+        LlmApiType::OpenAi => {
+            Box::new(openai::Client::new(address)?) as Box<dyn LlmGenerationClient>
+        }
+        LlmApiType::Gemini => {
+            Box::new(gemini::Client::new(address)?) as Box<dyn LlmGenerationClient>
+        }
         LlmApiType::Anthropic => {
-            Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmClient>
+            Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmGenerationClient>
         }
         LlmApiType::LiteLlm => {
-            Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmClient>
+            Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmGenerationClient>
+        }
+        LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(address).await?)
+            as Box<dyn LlmGenerationClient>,
+    };
+    Ok(client)
+}
+
+pub fn new_llm_embedding_client(
+    api_type: LlmApiType,
+    address: Option<String>,
+) -> Result<Box<dyn LlmEmbeddingClient>> {
+    let client = match api_type {
+        LlmApiType::Gemini => {
+            Box::new(gemini::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
         }
-        LlmApiType::OpenRouter => {
-            Box::new(openrouter::Client::new_openrouter(address).await?) as Box<dyn LlmClient>
+        LlmApiType::OpenAi => {
+            Box::new(openai::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
         }
+        _ => api_bail!("Embedding is not supported for API type {:?}", api_type),
     };
     Ok(client)
 }
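
Note that new_llm_embedding_client supports only Gemini and OpenAI for now; any other LlmApiType falls through to the api_bail! arm. On the Python side this means a spec like the following sketch (hypothetical enum member and model name) would be rejected once the flow reaches the embedding backend:

    # Expected to fail: Ollama has a generation client but no embedding client yet.
    cocoindex.functions.EmbedText(
        api_type=cocoindex.llm.LlmApiType.OLLAMA,  # assumed enum member name
        model="nomic-embed-text",                  # hypothetical model name
    )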

src/llm/ollama.rs

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-use super::LlmClient;
+use super::LlmGenerationClient;
 use anyhow::Result;
 use async_trait::async_trait;
 use schemars::schema::SchemaObject;
@@ -45,7 +45,7 @@ impl Client {
 }

 #[async_trait]
-impl LlmClient for Client {
+impl LlmGenerationClient for Client {
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,

src/llm/openai.rs

Lines changed: 43 additions & 5 deletions

@@ -1,18 +1,19 @@
 use crate::api_bail;

-use super::LlmClient;
+use super::{LlmEmbeddingClient, LlmGenerationClient};
 use anyhow::Result;
 use async_openai::{
     Client as OpenAIClient,
     config::OpenAIConfig,
     types::{
         ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
         ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
-        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, ResponseFormat,
-        ResponseFormatJsonSchema,
+        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
+        CreateEmbeddingRequest, EmbeddingInput, ResponseFormat, ResponseFormatJsonSchema,
     },
 };
 use async_trait::async_trait;
+use phf::phf_map;

 pub struct Client {
     client: async_openai::Client<OpenAIConfig>,
@@ -23,7 +24,7 @@ impl Client {
         Self { client }
     }

-    pub async fn new(address: Option<String>) -> Result<Self> {
+    pub fn new(address: Option<String>) -> Result<Self> {
         if let Some(address) = address {
             api_bail!("OpenAI doesn't support custom API address: {address}");
         }
@@ -39,7 +40,7 @@ impl Client {
 }

 #[async_trait]
-impl LlmClient for Client {
+impl LlmGenerationClient for Client {
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,
@@ -109,3 +110,40 @@ impl LlmClient for Client {
         }
     }
 }
+
+static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
+    "text-embedding-3-small" => 1536,
+    "text-embedding-3-large" => 3072,
+    "text-embedding-ada-002" => 1536,
+};
+
+#[async_trait]
+impl LlmEmbeddingClient for Client {
+    async fn embed_text<'req>(
+        &self,
+        request: super::LlmEmbeddingRequest<'req>,
+    ) -> Result<super::LlmEmbeddingResponse> {
+        let response = self
+            .client
+            .embeddings()
+            .create(CreateEmbeddingRequest {
+                model: request.model.to_string(),
+                input: EmbeddingInput::String(request.text.to_string()),
+                dimensions: Some(request.output_dimension),
+                ..Default::default()
+            })
+            .await?;
+        Ok(super::LlmEmbeddingResponse {
+            embedding: response
+                .data
+                .into_iter()
+                .next()
+                .ok_or_else(|| anyhow::anyhow!("No embedding returned from OpenAI"))?
+                .embedding,
+        })
+    }
+
+    fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
+        DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied()
+    }
+}
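
For comparison, a rough Python counterpart of embed_text above, using the official openai package (a sketch only; it assumes OPENAI_API_KEY is set, and dimensions corresponds to request.output_dimension in the Rust code):

    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment

    def embed_text(model: str, text: str, output_dimension: int) -> list[float]:
        response = client.embeddings.create(
            model=model,
            input=text,
            dimensions=output_dimension,  # supported by text-embedding-3-* models
        )
        # Like the Rust code, take the first (and only) embedding in the batch.
        return response.data[0].embedding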
