Skip to content

Commit 1b0ac88

Browse files
authored
feat(vertex): support Vertex AI for embedding (#740)
* feat(vertex): support Vertex AI for embedding

* style: format fix
1 parent 9740124 commit 1b0ac88

File tree

3 files changed

+101
-34
lines changed

3 files changed

+101
-34
lines changed

src/llm/gemini.rs

Lines changed: 86 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,23 @@ use crate::llm::{
66
};
77
use base64::prelude::*;
88
use google_cloud_aiplatform_v1 as vertexai;
9-
use phf::phf_map;
109
use serde_json::Value;
1110
use urlencoding::encode;
1211

/// Returns the default output dimension for a known Gemini/Vertex embedding
/// model family, or `None` when the model name is not recognized.
///
/// Matching is prefix-based and case-insensitive so that versioned model names
/// (e.g. "text-embedding-004", "gemini-embedding-exp-03-07") resolve without
/// maintaining an exhaustive per-version table.
fn get_embedding_dimension(model: &str) -> Option<u32> {
    let model = model.to_ascii_lowercase();
    // Gemini-family embedding models emit 3072-dimensional vectors.
    if model.starts_with("gemini-embedding-") {
        return Some(3072);
    }
    // All remaining known families share a 768-dimensional output.
    const DIM_768_PREFIXES: [&str; 3] = [
        "text-embedding-",
        "embedding-",
        "text-multilingual-embedding-",
    ];
    DIM_768_PREFIXES
        .iter()
        .any(|prefix| model.starts_with(prefix))
        .then_some(768)
}
1826

1927
pub struct AiStudioClient {
2028
api_key: String,
@@ -192,7 +200,7 @@ impl LlmEmbeddingClient for AiStudioClient {
192200
}
193201

194202
fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
195-
DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied()
203+
get_embedding_dimension(model)
196204
}
197205
}
198206

@@ -202,12 +210,30 @@ pub struct VertexAiClient {
202210
}
203211

204212
impl VertexAiClient {
205-
pub async fn new(config: super::VertexAiConfig) -> Result<Self> {
213+
pub async fn new(
214+
address: Option<String>,
215+
api_config: Option<super::LlmApiConfig>,
216+
) -> Result<Self> {
217+
if address.is_some() {
218+
api_bail!("VertexAi API address is not supported for VertexAi API type");
219+
}
220+
let Some(super::LlmApiConfig::VertexAi(config)) = api_config else {
221+
api_bail!("VertexAi API config is required for VertexAi API type");
222+
};
206223
let client = vertexai::client::PredictionService::builder()
207224
.build()
208225
.await?;
209226
Ok(Self { client, config })
210227
}
228+
229+
fn get_model_path(&self, model: &str) -> String {
230+
format!(
231+
"projects/{}/locations/{}/publishers/google/models/{}",
232+
self.config.project,
233+
self.config.region.as_deref().unwrap_or("global"),
234+
model
235+
)
236+
}
211237
}
212238

213239
#[async_trait]
@@ -254,20 +280,10 @@ impl LlmGenerationClient for VertexAiClient {
254280
);
255281
}
256282

257-
// projects/{project_id}/locations/global/publishers/google/models/{MODEL}
258-
259-
let model = format!(
260-
"projects/{}/locations/{}/publishers/google/models/{}",
261-
self.config.project,
262-
self.config.region.as_deref().unwrap_or("global"),
263-
request.model
264-
);
265-
266-
// Build the request
267283
let mut req = self
268284
.client
269285
.generate_content()
270-
.set_model(model)
286+
.set_model(self.get_model_path(request.model))
271287
.set_contents(contents);
272288
if let Some(sys) = system_instruction {
273289
req = req.set_system_instruction(sys);
@@ -301,3 +317,54 @@ impl LlmGenerationClient for VertexAiClient {
301317
}
302318
}
303319
}
320+
321+
#[async_trait]
322+
impl LlmEmbeddingClient for VertexAiClient {
323+
async fn embed_text<'req>(
324+
&self,
325+
request: super::LlmEmbeddingRequest<'req>,
326+
) -> Result<super::LlmEmbeddingResponse> {
327+
// Create the instances for the request
328+
let mut instance = serde_json::json!({
329+
"content": request.text
330+
});
331+
// Add task type if specified
332+
if let Some(task_type) = &request.task_type {
333+
instance["task_type"] = serde_json::Value::String(task_type.to_string());
334+
}
335+
336+
let instances = vec![instance];
337+
338+
// Prepare the request parameters
339+
let mut parameters = serde_json::json!({});
340+
if let Some(output_dimension) = request.output_dimension {
341+
parameters["outputDimensionality"] = serde_json::Value::Number(output_dimension.into());
342+
}
343+
344+
// Build the prediction request using the raw predict builder
345+
let response = self
346+
.client
347+
.predict()
348+
.set_endpoint(self.get_model_path(request.model))
349+
.set_instances(instances)
350+
.set_parameters(parameters)
351+
.send()
352+
.await?;
353+
354+
// Extract the embedding from the response
355+
let embeddings = response
356+
.predictions
357+
.into_iter()
358+
.next()
359+
.and_then(|mut e| e.get_mut("embeddings").map(|v| v.take()))
360+
.ok_or_else(|| anyhow::anyhow!("No embeddings in response"))?;
361+
let embedding: ContentEmbedding = serde_json::from_value(embeddings)?;
362+
Ok(super::LlmEmbeddingResponse {
363+
embedding: embedding.values,
364+
})
365+
}
366+
367+
fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
368+
get_embedding_dimension(model)
369+
}
370+
}

src/llm/mod.rs

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,8 @@ pub async fn new_llm_generation_client(
119119
LlmApiType::Gemini => {
120120
Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmGenerationClient>
121121
}
122-
LlmApiType::VertexAi => {
123-
if address.is_some() {
124-
api_bail!("VertexAi API address is not supported for VertexAi API type");
125-
}
126-
let Some(LlmApiConfig::VertexAi(config)) = api_config else {
127-
api_bail!("VertexAi API config is required for VertexAi API type");
128-
};
129-
let config = config.clone();
130-
Box::new(gemini::VertexAiClient::new(config).await?) as Box<dyn LlmGenerationClient>
131-
}
122+
LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
123+
as Box<dyn LlmGenerationClient>,
132124
LlmApiType::Anthropic => {
133125
Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmGenerationClient>
134126
}
@@ -147,9 +139,10 @@ pub async fn new_llm_generation_client(
147139
Ok(client)
148140
}
149141

150-
pub fn new_llm_embedding_client(
142+
pub async fn new_llm_embedding_client(
151143
api_type: LlmApiType,
152144
address: Option<String>,
145+
api_config: Option<LlmApiConfig>,
153146
) -> Result<Box<dyn LlmEmbeddingClient>> {
154147
let client = match api_type {
155148
LlmApiType::Gemini => {
@@ -161,12 +154,13 @@ pub fn new_llm_embedding_client(
161154
LlmApiType::Voyage => {
162155
Box::new(voyage::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
163156
}
157+
LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
158+
as Box<dyn LlmEmbeddingClient>,
164159
LlmApiType::Ollama
165160
| LlmApiType::OpenRouter
166161
| LlmApiType::LiteLlm
167162
| LlmApiType::Vllm
168-
| LlmApiType::Anthropic
169-
| LlmApiType::VertexAi => {
163+
| LlmApiType::Anthropic => {
170164
api_bail!("Embedding is not supported for API type {:?}", api_type)
171165
}
172166
};

src/ops/functions/embed_text.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use crate::{
2-
llm::{LlmApiType, LlmEmbeddingClient, LlmEmbeddingRequest, new_llm_embedding_client},
2+
llm::{
3+
LlmApiConfig, LlmApiType, LlmEmbeddingClient, LlmEmbeddingRequest, new_llm_embedding_client,
4+
},
35
ops::sdk::*,
46
};
57

@@ -8,6 +10,7 @@ struct Spec {
810
api_type: LlmApiType,
911
model: String,
1012
address: Option<String>,
13+
api_config: Option<LlmApiConfig>,
1114
output_dimension: Option<u32>,
1215
task_type: Option<String>,
1316
}
@@ -67,7 +70,9 @@ impl SimpleFunctionFactoryBase for Factory {
6770
_context: &FlowInstanceContext,
6871
) -> Result<(Self::ResolvedArgs, EnrichedValueType)> {
6972
let text = args_resolver.next_arg("text")?;
70-
let client = new_llm_embedding_client(spec.api_type, spec.address.clone())?;
73+
let client =
74+
new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone())
75+
.await?;
7176
let output_dimension = match spec.output_dimension {
7277
Some(output_dimension) => output_dimension,
7378
None => {
@@ -108,6 +113,7 @@ mod tests {
108113
api_type: LlmApiType::OpenAi,
109114
model: "text-embedding-ada-002".to_string(),
110115
address: None,
116+
api_config: None,
111117
output_dimension: None,
112118
task_type: None,
113119
};

0 commit comments

Comments (0)