Skip to content

Commit 65cbf27

Browse files
authored
feat(text-embed): add support for Voyage embedding model (#648)
1 parent 780a9f8 commit 65cbf27

File tree

7 files changed

+154
-18
lines changed

7 files changed

+154
-18
lines changed

examples/code_embedding/main.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from dotenv import load_dotenv
22
from psycopg_pool import ConnectionPool
3+
from pgvector.psycopg import register_vector
4+
from typing import Any
35
import cocoindex
46
import os
7+
from numpy.typing import NDArray
8+
import numpy as np
59

610

711
@cocoindex.op.function()
@@ -13,10 +17,17 @@ def extract_extension(filename: str) -> str:
1317
@cocoindex.transform_flow()
1418
def code_to_embedding(
1519
text: cocoindex.DataSlice[str],
16-
) -> cocoindex.DataSlice[list[float]]:
20+
) -> cocoindex.DataSlice[NDArray[np.float32]]:
1721
"""
1822
Embed the text using a SentenceTransformer model.
1923
"""
24+
# You can also switch to Voyage embedding model:
25+
# return text.transform(
26+
# cocoindex.functions.EmbedText(
27+
# api_type=cocoindex.llm.LlmApiType.VOYAGE,
28+
# model="voyage-code-3",
29+
# )
30+
# )
2031
return text.transform(
2132
cocoindex.functions.SentenceTransformerEmbed(
2233
model="sentence-transformers/all-MiniLM-L6-v2"
@@ -71,7 +82,7 @@ def code_embedding_flow(
7182
)
7283

7384

74-
def search(pool: ConnectionPool, query: str, top_k: int = 5):
85+
def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]:
7586
# Get the table name, for the export target in the code_embedding_flow above.
7687
table_name = cocoindex.utils.get_target_default_name(
7788
code_embedding_flow, "code_embeddings"
@@ -80,10 +91,11 @@ def search(pool: ConnectionPool, query: str, top_k: int = 5):
8091
query_vector = code_to_embedding.eval(query)
8192
# Run the query and get the results.
8293
with pool.connection() as conn:
94+
register_vector(conn)
8395
with conn.cursor() as cur:
8496
cur.execute(
8597
f"""
86-
SELECT filename, code, embedding <=> %s::vector AS distance
98+
SELECT filename, code, embedding <=> %s AS distance
8799
FROM {table_name} ORDER BY distance LIMIT %s
88100
""",
89101
(query_vector, top_k),
@@ -94,7 +106,7 @@ def search(pool: ConnectionPool, query: str, top_k: int = 5):
94106
]
95107

96108

97-
def _main():
109+
def _main() -> None:
98110
# Make sure the flow is built and up-to-date.
99111
stats = code_embedding_flow.update()
100112
print("Updated index: ", stats)

python/cocoindex/flow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def _spec_kind(spec: Any) -> str:
9292

9393

9494
T = TypeVar("T")
95+
S = TypeVar("S")
9596

9697

9798
class _DataSliceState:
@@ -216,7 +217,7 @@ def transform(
216217
),
217218
)
218219

219-
def call(self, func: Callable[[DataSlice[T]], T], *args: Any, **kwargs: Any) -> T:
220+
def call(self, func: Callable[..., S], *args: Any, **kwargs: Any) -> S:
220221
"""
221222
Call a function with the data slice.
222223
"""

python/cocoindex/llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class LlmApiType(Enum):
1111
ANTHROPIC = "Anthropic"
1212
LITE_LLM = "LiteLlm"
1313
OPEN_ROUTER = "OpenRouter"
14+
VOYAGE = "Voyage"
1415

1516

1617
@dataclass

src/llm/gemini.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ use phf::phf_map;
88
use serde_json::Value;
99
use urlencoding::encode;
1010

11+
static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
12+
"gemini-embedding-exp-03-07" => 3072,
13+
"text-embedding-004" => 768,
14+
"embedding-001" => 768,
15+
};
16+
1117
pub struct Client {
1218
api_key: String,
1319
client: reqwest::Client,
@@ -127,12 +133,6 @@ impl LlmGenerationClient for Client {
127133
}
128134
}
129135

130-
static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
131-
"gemini-embedding-exp-03-07" => 3072,
132-
"text-embedding-004" => 768,
133-
"embedding-001" => 768,
134-
};
135-
136136
#[derive(Deserialize)]
137137
struct ContentEmbedding {
138138
values: Vec<f32>,

src/llm/mod.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub enum LlmApiType {
1212
Anthropic,
1313
LiteLlm,
1414
OpenRouter,
15+
Voyage,
1516
}
1617

1718
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -80,6 +81,7 @@ mod litellm;
8081
mod ollama;
8182
mod openai;
8283
mod openrouter;
84+
mod voyage;
8385

8486
pub async fn new_llm_generation_client(
8587
api_type: LlmApiType,
@@ -103,6 +105,9 @@ pub async fn new_llm_generation_client(
103105
}
104106
LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(address).await?)
105107
as Box<dyn LlmGenerationClient>,
108+
LlmApiType::Voyage => {
109+
api_bail!("Voyage is not supported for generation")
110+
}
106111
};
107112
Ok(client)
108113
}
@@ -118,7 +123,15 @@ pub fn new_llm_embedding_client(
118123
LlmApiType::OpenAi => {
119124
Box::new(openai::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
120125
}
121-
_ => api_bail!("Embedding is not supported for API type {:?}", api_type),
126+
LlmApiType::Voyage => {
127+
Box::new(voyage::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
128+
}
129+
LlmApiType::Ollama
130+
| LlmApiType::OpenRouter
131+
| LlmApiType::LiteLlm
132+
| LlmApiType::Anthropic => {
133+
api_bail!("Embedding is not supported for API type {:?}", api_type)
134+
}
122135
};
123136
Ok(client)
124137
}

src/llm/openai.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ use async_openai::{
1515
use async_trait::async_trait;
1616
use phf::phf_map;
1717

18+
static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
19+
"text-embedding-3-small" => 1536,
20+
"text-embedding-3-large" => 3072,
21+
"text-embedding-ada-002" => 1536,
22+
};
23+
1824
pub struct Client {
1925
client: async_openai::Client<OpenAIConfig>,
2026
}
@@ -111,12 +117,6 @@ impl LlmGenerationClient for Client {
111117
}
112118
}
113119

114-
static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
115-
"text-embedding-3-small" => 1536,
116-
"text-embedding-3-large" => 3072,
117-
"text-embedding-ada-002" => 1536,
118-
};
119-
120120
#[async_trait]
121121
impl LlmEmbeddingClient for Client {
122122
async fn embed_text<'req>(

src/llm/voyage.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
use crate::prelude::*;
2+
3+
use crate::llm::{LlmEmbeddingClient, LlmEmbeddingRequest, LlmEmbeddingResponse};
4+
use phf::phf_map;
5+
6+
static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
7+
// Current models
8+
"voyage-3-large" => 1024,
9+
"voyage-3.5" => 1024,
10+
"voyage-3.5-lite" => 1024,
11+
"voyage-code-3" => 1024,
12+
"voyage-finance-2" => 1024,
13+
"voyage-law-2" => 1024,
14+
"voyage-code-2" => 1536,
15+
16+
// Legacy models
17+
"voyage-3" => 1024,
18+
"voyage-3-lite" => 512,
19+
"voyage-multilingual-2" => 1024,
20+
"voyage-large-2-instruct" => 1024,
21+
"voyage-large-2" => 1536,
22+
"voyage-2" => 1024,
23+
"voyage-lite-02-instruct" => 1024,
24+
"voyage-02" => 1024,
25+
"voyage-01" => 1024,
26+
"voyage-lite-01" => 1024,
27+
"voyage-lite-01-instruct" => 1024,
28+
};
29+
30+
pub struct Client {
31+
api_key: String,
32+
client: reqwest::Client,
33+
}
34+
35+
impl Client {
36+
pub fn new(address: Option<String>) -> Result<Self> {
37+
if address.is_some() {
38+
api_bail!("Voyage AI doesn't support custom API address");
39+
}
40+
let api_key = match std::env::var("VOYAGE_API_KEY") {
41+
Ok(val) => val,
42+
Err(_) => api_bail!("VOYAGE_API_KEY environment variable must be set"),
43+
};
44+
Ok(Self {
45+
api_key,
46+
client: reqwest::Client::new(),
47+
})
48+
}
49+
}
50+
51+
#[derive(Deserialize)]
52+
struct EmbeddingData {
53+
embedding: Vec<f32>,
54+
}
55+
56+
#[derive(Deserialize)]
57+
struct EmbedResponse {
58+
data: Vec<EmbeddingData>,
59+
}
60+
61+
#[async_trait]
62+
impl LlmEmbeddingClient for Client {
63+
async fn embed_text<'req>(
64+
&self,
65+
request: LlmEmbeddingRequest<'req>,
66+
) -> Result<LlmEmbeddingResponse> {
67+
let url = "https://api.voyageai.com/v1/embeddings";
68+
69+
let mut payload = serde_json::json!({
70+
"input": request.text,
71+
"model": request.model,
72+
});
73+
74+
if let Some(task_type) = request.task_type {
75+
payload["input_type"] = serde_json::Value::String(task_type.into());
76+
}
77+
78+
let resp = self
79+
.client
80+
.post(url)
81+
.header("Authorization", format!("Bearer {}", self.api_key))
82+
.json(&payload)
83+
.send()
84+
.await
85+
.context("HTTP error")?;
86+
87+
if !resp.status().is_success() {
88+
bail!(
89+
"Voyage AI API error: {:?}\n{}\n",
90+
resp.status(),
91+
resp.text().await?
92+
);
93+
}
94+
95+
let embedding_resp: EmbedResponse = resp.json().await.context("Invalid JSON")?;
96+
97+
if embedding_resp.data.is_empty() {
98+
bail!("No embedding data in response");
99+
}
100+
101+
Ok(LlmEmbeddingResponse {
102+
embedding: embedding_resp.data[0].embedding.clone(),
103+
})
104+
}
105+
106+
fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
107+
DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied()
108+
}
109+
}

0 commit comments

Comments
 (0)