diff --git a/Cargo.lock b/Cargo.lock
index f40aa7189..5a6739921 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -902,6 +902,31 @@ dependencies = [
  "generic-array",
 ]
 
+[[package]]
+name = "bon"
+version = "3.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f61138465baf186c63e8d9b6b613b508cd832cba4ce93cf37ce5f096f91ac1a6"
+dependencies = [
+ "bon-macros",
+ "rustversion",
+]
+
+[[package]]
+name = "bon-macros"
+version = "3.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "40d1dad34aa19bf02295382f08d9bc40651585bd497266831d40ee6296fb49ca"
+dependencies = [
+ "darling",
+ "ident_case",
+ "prettyplease",
+ "proc-macro2",
+ "quote",
+ "rustversion",
+ "syn 2.0.104",
+]
+
 [[package]]
 name = "bstr"
 version = "1.12.0"
@@ -1066,6 +1091,7 @@ dependencies = [
  "env_logger",
  "futures",
  "globset",
+ "google-cloud-aiplatform-v1",
  "google-drive3",
  "hex",
  "http-body-util",
@@ -1912,6 +1938,222 @@ dependencies = [
  "yup-oauth2 11.0.0",
 ]
 
+[[package]]
+name = "google-cloud-aiplatform-v1"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "571ef9524283a5973d4fc790ec4aa1f6d8c141c71e4b372e95ccb32ebb00b87f"
+dependencies = [
+ "async-trait",
+ "bytes",
+ "google-cloud-api",
+ "google-cloud-gax",
+ "google-cloud-gax-internal",
+ "google-cloud-iam-v1",
+ "google-cloud-location",
+ "google-cloud-longrunning",
+ "google-cloud-lro",
+ "google-cloud-rpc",
+ "google-cloud-type",
+ "google-cloud-wkt",
+ "lazy_static",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "serde_with",
+ "tracing",
+]
+
+[[package]]
+name = "google-cloud-api"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b110420309adc187bf99c53bf698520bbf707fa56b76d1d47c8de9b00c7529d2"
+dependencies = [
+ "bytes",
+ "google-cloud-wkt",
+ "serde",
+ "serde_json",
+ "serde_with",
+]
+
+[[package]]
+name = "google-cloud-auth"
+version = "0.21.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a65fb515e1e726bc58b925fc876e8f02b626ab574b15c8f3fa53cb08177a7815"
+dependencies = [
+ "async-trait",
+ "base64 0.22.1",
+ "bon",
+ "google-cloud-gax",
+ "http 1.3.1",
+ "reqwest",
+ "rustls 0.23.29",
+ "rustls-pemfile 2.2.0",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.12",
+ "time",
+ "tokio",
+]
+
+[[package]]
+name = "google-cloud-gax"
+version = "0.23.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d88f66b6fccca3449ad9f0366ec3d1b74a004fa1100d98353a201dabc3203bd2"
+dependencies = [
+ "base64 0.22.1",
+ "bytes",
+ "futures",
+ "google-cloud-rpc",
+ "google-cloud-wkt",
+ "http 1.3.1",
+ "pin-project",
+ "rand 0.9.1",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.12",
+ "tokio",
+]
+
+[[package]]
+name = "google-cloud-gax-internal"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8291536f9608037835925d9d8e54dab3a74abaca59d42dbcb8dc6cfc5e35066a"
+dependencies = [
+ "bytes",
+ "google-cloud-auth",
+ "google-cloud-gax",
+ "google-cloud-rpc",
+ "http 1.3.1",
+ "http-body-util",
+ "percent-encoding",
+ "reqwest",
+ "rustc_version",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.12",
+ "tokio",
+]
+
+[[package]]
+name = "google-cloud-iam-v1"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e5d75747af3509772c2526cf3328e4542203318e0351f29280c5e439dfb00702"
+dependencies = [
+ "async-trait",
+ "bytes",
+ "google-cloud-gax",
+ "google-cloud-gax-internal",
+ "google-cloud-type",
+ "google-cloud-wkt",
+ "lazy_static",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "serde_with",
+ "tracing",
+]
+
+[[package]]
+name = "google-cloud-location"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2e50a61d6b6b35faebeecf91bc8c4fc5f63a362d877d2568a01fce6131d2ba24"
+dependencies = [
+ "async-trait",
+ "bytes",
+ "google-cloud-gax",
+ "google-cloud-gax-internal",
+ "google-cloud-wkt",
+ "lazy_static",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "serde_with",
+ "tracing",
+]
+
+[[package]]
+name = "google-cloud-longrunning"
+version = "0.25.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cc29656e6916ffbd835d54eafa163aab45232b7dde648a265efb123c078cec7d"
+dependencies = [
+ "async-trait",
+ "bytes",
+ "google-cloud-gax",
+ "google-cloud-gax-internal",
+ "google-cloud-rpc",
+ "google-cloud-wkt",
+ "lazy_static",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "serde_with",
+ "tracing",
+]
+
+[[package]]
+name = "google-cloud-lro"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5325b6cf7bac79e9ec87ae6806d619d8fbf4b345a1223ede93fa87d99b9df41a"
+dependencies = [
+ "google-cloud-gax",
+ "google-cloud-longrunning",
+ "google-cloud-rpc",
+ "google-cloud-wkt",
+ "serde",
+ "tokio",
+]
+
+[[package]]
+name = "google-cloud-rpc"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f1b4b7867ab6e94e56944116f2b1bdcdbac522eea794743a9884d84db7728470"
+dependencies = [
+ "bytes",
+ "google-cloud-wkt",
+ "serde",
+ "serde_json",
+ "serde_with",
+]
+
+[[package]]
+name = "google-cloud-type"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "683af2d05d271b6d0a13bcba997037885d8369f850706473576d5e257fac3f62"
+dependencies = [
+ "bytes",
+ "google-cloud-wkt",
+ "serde",
+ "serde_json",
+ "serde_with",
+]
+
+[[package]]
+name = "google-cloud-wkt"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d4b1dbeb30ab8bde2423081b76f9c302d82fea1b99f8ca750f223d32065d3c6c"
+dependencies = [
+ "base64 0.22.1",
+ "bytes",
+ "serde",
+ "serde_json",
+ "serde_with",
+ "thiserror 2.0.12",
+ "time",
+ "url",
+]
+
 [[package]]
 name = "google-drive3"
 version = "6.0.0+20240618"
diff --git a/Cargo.toml b/Cargo.toml
index 4f439b1ab..803c378a6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -117,3 +117,4 @@ aws-sdk-sqs = "1.67.0"
 numpy = "0.25.0"
 infer = "0.19.0"
 serde_with = { version = "3.13.0", features = ["base64"] }
+google-cloud-aiplatform-v1 = "0.4.0"
diff --git a/README.md b/README.md
index 8535c3ffe..3b52c2f46 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ Ultra performant data transformation framework for AI, with core engine written
-CocoIndex makes it super easy to transform data with AI workloads, and keep source data and target in sync effortlessly.
+CocoIndex makes it super easy to transform data with AI workloads, and keep source data and target in sync effortlessly.
@@ -39,7 +39,7 @@ CocoIndex makes it super easy to transform data with AI workloads, and keep sour
-Either creating embedding, building knowledge graphs, or any data transformations - beyond traditional SQL.
+Either creating embedding, building knowledge graphs, or any data transformations - beyond traditional SQL.
 
 ## Exceptional velocity
 Just declare transformation in dataflow with ~100 lines of python
@@ -65,7 +65,7 @@ CocoIndex follows the idea of [Dataflow](https://en.wikipedia.org/wiki/Dataflow_
 **Particularly**, developers don't explicitly mutate data by creating, updating and deleting. They just need to define transformation/formula for a set of source data.
 
 ## Build like LEGO
-Native builtins for different source, targets and transformations. Standardize interface, make it 1-line code switch between different components.
+Native builtins for different source, targets and transformations. Standardize interface, make it 1-line code switch between different components.
 
 CocoIndex Features
diff --git a/python/cocoindex/__init__.py b/python/cocoindex/__init__.py
index f2730424e..41cbb3320 100644
--- a/python/cocoindex/__init__.py
+++ b/python/cocoindex/__init__.py
@@ -33,6 +33,7 @@
     # Submodules
     "_engine",
     "functions",
+    "llm",
     "sources",
     "targets",
     "storages",
diff --git a/python/cocoindex/functions.py b/python/cocoindex/functions.py
index 2b9ae1802..877c83e82 100644
--- a/python/cocoindex/functions.py
+++ b/python/cocoindex/functions.py
@@ -45,6 +45,7 @@ class EmbedText(op.FunctionSpec):
     address: str | None = None
     output_dimension: int | None = None
     task_type: str | None = None
+    api_config: llm.VertexAiConfig | None = None
 
 
 class ExtractByLlm(op.FunctionSpec):
diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py
index 6a77e93e8..28ffca5fe 100644
--- a/python/cocoindex/llm.py
+++ b/python/cocoindex/llm.py
@@ -8,6 +8,7 @@ class LlmApiType(Enum):
     OPENAI = "OpenAi"
     OLLAMA = "Ollama"
     GEMINI = "Gemini"
+    VERTEX_AI = "VertexAi"
     ANTHROPIC = "Anthropic"
     LITE_LLM = "LiteLlm"
     OPEN_ROUTER = "OpenRouter"
@@ -15,6 +16,16 @@ class LlmApiType(Enum):
     VLLM = "Vllm"
 
 
+@dataclass
+class VertexAiConfig:
+    """A specification for a Vertex AI LLM."""
+
+    kind = "VertexAi"
+
+    project: str
+    region: str | None = None
+
+
 @dataclass
 class LlmSpec:
     """A specification for a LLM."""
@@ -22,3 +33,4 @@ class LlmSpec:
     api_type: LlmApiType
     model: str
     address: str | None = None
+    api_config: VertexAiConfig | None = None
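A quick orientation on how these Python pieces compose, since the diff only shows fragments: a minimal sketch of selecting Vertex AI from the SDK side. The project, region, and model strings below are placeholders, not values taken from this PR.

```python
from cocoindex import llm

# Placeholder project/region/model; use your own GCP project and any
# Gemini model served through Vertex AI.
spec = llm.LlmSpec(
    api_type=llm.LlmApiType.VERTEX_AI,
    model="gemini-2.0-flash",
    api_config=llm.VertexAiConfig(project="my-gcp-project", region="us-central1"),
)
```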
{ "embedding-001" => 768, }; -pub struct Client { +pub struct AiStudioClient { api_key: String, client: reqwest::Client, } -impl Client { +impl AiStudioClient { pub fn new(address: Option) -> Result { if address.is_some() { api_bail!("Gemini doesn't support custom API address"); @@ -54,7 +55,7 @@ fn remove_additional_properties(value: &mut Value) { } } -impl Client { +impl AiStudioClient { fn get_api_url(&self, model: &str, api_name: &str) -> String { format!( "https://generativelanguage.googleapis.com/v1beta/models/{}:{}?key={}", @@ -66,7 +67,7 @@ impl Client { } #[async_trait] -impl LlmGenerationClient for Client { +impl LlmGenerationClient for AiStudioClient { async fn generate<'req>( &self, request: LlmGenerateRequest<'req>, @@ -159,7 +160,7 @@ struct EmbedContentResponse { } #[async_trait] -impl LlmEmbeddingClient for Client { +impl LlmEmbeddingClient for AiStudioClient { async fn embed_text<'req>( &self, request: super::LlmEmbeddingRequest<'req>, @@ -194,3 +195,109 @@ impl LlmEmbeddingClient for Client { DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied() } } + +pub struct VertexAiClient { + client: vertexai::client::PredictionService, + config: super::VertexAiConfig, +} + +impl VertexAiClient { + pub async fn new(config: super::VertexAiConfig) -> Result { + let client = vertexai::client::PredictionService::builder() + .build() + .await?; + Ok(Self { client, config }) + } +} + +#[async_trait] +impl LlmGenerationClient for VertexAiClient { + async fn generate<'req>( + &self, + request: super::LlmGenerateRequest<'req>, + ) -> Result { + use vertexai::model::{Blob, Content, GenerationConfig, Part, Schema, part::Data}; + + // Compose parts + let mut parts = Vec::new(); + // Add text part + parts.push(Part::new().set_text(request.user_prompt.to_string())); + // Add image part if present + if let Some(image_bytes) = request.image { + let mime_type = detect_image_mime_type(image_bytes.as_ref())?; + parts.push( + Part::new().set_inline_data( + Blob::new() + .set_data(image_bytes.into_owned()) + .set_mime_type(mime_type.to_string()), + ), + ); + } + // Compose content + let mut contents = Vec::new(); + contents.push(Content::new().set_role("user".to_string()).set_parts(parts)); + // Compose system instruction if present + let system_instruction = request.system_prompt.as_ref().map(|sys| { + Content::new() + .set_role("system".to_string()) + .set_parts(vec![Part::new().set_text(sys.to_string())]) + }); + + // Compose generation config + let mut generation_config = None; + if let Some(OutputFormat::JsonSchema { schema, .. 
}) = &request.output_format { + let schema_json = serde_json::to_value(schema)?; + generation_config = Some( + GenerationConfig::new() + .set_response_mime_type("application/json".to_string()) + .set_response_schema(serde_json::from_value::(schema_json)?), + ); + } + + // projects/{project_id}/locations/global/publishers/google/models/{MODEL} + + let model = format!( + "projects/{}/locations/{}/publishers/google/models/{}", + self.config.project, + self.config.region.as_deref().unwrap_or("global"), + request.model + ); + + // Build the request + let mut req = self + .client + .generate_content() + .set_model(model) + .set_contents(contents); + if let Some(sys) = system_instruction { + req = req.set_system_instruction(sys); + } + if let Some(config) = generation_config { + req = req.set_generation_config(config); + } + + // Call the API + let resp = req.send().await?; + // Extract text from response + let Some(Data::Text(text)) = resp + .candidates + .into_iter() + .next() + .and_then(|c| c.content) + .and_then(|content| content.parts.into_iter().next()) + .and_then(|part| part.data) + else { + bail!("No text in response"); + }; + Ok(super::LlmGenerateResponse { text }) + } + + fn json_schema_options(&self) -> ToJsonSchemaOptions { + ToJsonSchemaOptions { + fields_always_required: false, + supports_format: false, + extract_descriptions: false, + top_level_must_be_object: true, + } + } +} diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 68589e343..a89f9e679 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -17,6 +17,19 @@ pub enum LlmApiType { OpenRouter, Voyage, Vllm, + VertexAi, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VertexAiConfig { + pub project: String, + pub region: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "kind")] +pub enum LlmApiConfig { + VertexAi(VertexAiConfig), } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -24,6 +37,7 @@ pub struct LlmSpec { pub api_type: LlmApiType, pub address: Option, pub model: String, + pub api_config: Option, } #[derive(Debug)] @@ -86,12 +100,14 @@ mod litellm; mod ollama; mod openai; mod openrouter; +mod vertex_ai; mod vllm; mod voyage; pub async fn new_llm_generation_client( api_type: LlmApiType, address: Option, + api_config: Option, ) -> Result> { let client = match api_type { LlmApiType::Ollama => { @@ -101,7 +117,17 @@ pub async fn new_llm_generation_client( Box::new(openai::Client::new(address)?) as Box } LlmApiType::Gemini => { - Box::new(gemini::Client::new(address)?) as Box + Box::new(gemini::AiStudioClient::new(address)?) as Box + } + LlmApiType::VertexAi => { + if address.is_some() { + api_bail!("VertexAi API address is not supported for VertexAi API type"); + } + let Some(LlmApiConfig::VertexAi(config)) = api_config else { + api_bail!("VertexAi API config is required for VertexAi API type"); + }; + let config = config.clone(); + Box::new(gemini::VertexAiClient::new(config).await?) as Box } LlmApiType::Anthropic => { Box::new(anthropic::Client::new(address).await?) as Box @@ -127,7 +153,7 @@ pub fn new_llm_embedding_client( ) -> Result> { let client = match api_type { LlmApiType::Gemini => { - Box::new(gemini::Client::new(address)?) as Box + Box::new(gemini::AiStudioClient::new(address)?) as Box } LlmApiType::OpenAi => { Box::new(openai::Client::new(address)?) 
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 68589e343..a89f9e679 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -17,6 +17,19 @@ pub enum LlmApiType {
     OpenRouter,
     Voyage,
     Vllm,
+    VertexAi,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct VertexAiConfig {
+    pub project: String,
+    pub region: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "kind")]
+pub enum LlmApiConfig {
+    VertexAi(VertexAiConfig),
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -24,6 +37,7 @@ pub struct LlmSpec {
     pub api_type: LlmApiType,
     pub address: Option<String>,
     pub model: String,
+    pub api_config: Option<LlmApiConfig>,
 }
 
 #[derive(Debug)]
@@ -86,12 +100,14 @@ mod litellm;
 mod ollama;
 mod openai;
 mod openrouter;
+mod vertex_ai;
 mod vllm;
 mod voyage;
 
 pub async fn new_llm_generation_client(
     api_type: LlmApiType,
     address: Option<String>,
+    api_config: Option<LlmApiConfig>,
 ) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match api_type {
         LlmApiType::Ollama => {
@@ -101,7 +117,17 @@ pub async fn new_llm_generation_client(
             Box::new(openai::Client::new(address)?) as Box<dyn LlmGenerationClient>
         }
         LlmApiType::Gemini => {
-            Box::new(gemini::Client::new(address)?) as Box<dyn LlmGenerationClient>
+            Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmGenerationClient>
+        }
+        LlmApiType::VertexAi => {
+            if address.is_some() {
+                api_bail!("Custom API address is not supported for the VertexAi API type");
+            }
+            let Some(LlmApiConfig::VertexAi(config)) = api_config else {
+                api_bail!("VertexAi API config is required for the VertexAi API type");
+            };
+            let config = config.clone();
+            Box::new(gemini::VertexAiClient::new(config).await?) as Box<dyn LlmGenerationClient>
         }
         LlmApiType::Anthropic => {
             Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmGenerationClient>
@@ -127,7 +153,7 @@ pub fn new_llm_embedding_client(
 ) -> Result<Box<dyn LlmEmbeddingClient>> {
     let client = match api_type {
         LlmApiType::Gemini => {
-            Box::new(gemini::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
+            Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmEmbeddingClient>
         }
         LlmApiType::OpenAi => {
             Box::new(openai::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
@@ -139,7 +165,8 @@ pub fn new_llm_embedding_client(
         | LlmApiType::OpenRouter
         | LlmApiType::LiteLlm
         | LlmApiType::Vllm
-        | LlmApiType::Anthropic => {
+        | LlmApiType::Anthropic
+        | LlmApiType::VertexAi => {
             api_bail!("Embedding is not supported for API type {:?}", api_type)
         }
     };
diff --git a/src/llm/vertex_ai.rs b/src/llm/vertex_ai.rs
new file mode 100644
index 000000000..8b1378917
--- /dev/null
+++ b/src/llm/vertex_ai.rs
@@ -0,0 +1 @@
+
diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs
index bbd95811d..6e174268c 100644
--- a/src/ops/functions/extract_by_llm.rs
+++ b/src/ops/functions/extract_by_llm.rs
@@ -52,8 +52,12 @@ Unless explicitly instructed otherwise, output only the JSON. DO NOT include exp
 
 impl Executor {
     async fn new(spec: Spec, args: Args) -> Result<Self> {
-        let client =
-            new_llm_generation_client(spec.llm_spec.api_type, spec.llm_spec.address).await?;
+        let client = new_llm_generation_client(
+            spec.llm_spec.api_type,
+            spec.llm_spec.address,
+            spec.llm_spec.api_config,
+        )
+        .await?;
         let schema_output = build_json_schema(spec.output_type, client.json_schema_options())?;
         Ok(Self {
             args,
@@ -190,6 +194,7 @@ mod tests {
             api_type: crate::llm::LlmApiType::OpenAi,
             model: "gpt-4o".to_string(),
             address: None,
+            api_config: None,
         },
         output_type: output_type_spec,
         instruction: Some("Extract the name and value from the text. The name is a string, the value is an integer.".to_string()),
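Finally, a note on how the Python and Rust halves line up: `LlmApiConfig` is deserialized with `#[serde(tag = "kind")]`, which is why the Python `VertexAiConfig` carries `kind` as a class-level attribute rather than a dataclass field. A sketch of the expected wire shape; the dict-merging below is illustrative, not code from this PR:

```python
import dataclasses

from cocoindex import llm

config = llm.VertexAiConfig(project="my-gcp-project")
# `kind` has no type annotation, so dataclasses.asdict() omits it; the engine
# expects it alongside the fields as the serde tag.
payload = {"kind": config.kind, **dataclasses.asdict(config)}
assert payload == {"kind": "VertexAi", "project": "my-gcp-project", "region": None}
```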