Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/docs/core/flow_def.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ Types of the fields must be key types. See [Key Types](data_types#key-types) for

* `field_name`: the field to create the vector index on.
* `metric`: the similarity metric to use.
* `method` (optional): the index algorithm and optional tuning parameters. Leave unset to use the target default (HNSW for Postgres). Use `cocoindex.HnswVectorIndexMethod()` or `cocoindex.IvfFlatVectorIndexMethod()` to customize the method and its parameters.

#### Similarity Metrics

Expand Down
10 changes: 10 additions & 0 deletions docs/docs/examples/examples/simple_vector_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ doc_embeddings.export(
CocoIndex supports other vector databases as well, with 1-line switch.
<DocumentationButton url="https://cocoindex.io/docs/ops/targets" text="Targets" />

Need IVFFlat or custom HNSW parameters? Pass a method, for example:

```python
cocoindex.VectorIndexDef(
field_name="embedding",
metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
method=cocoindex.IvfFlatVectorIndexMethod(lists=200),
)
```

## Query the index

### Define a shared flow for both indexing and querying
Expand Down
10 changes: 9 additions & 1 deletion python/cocoindex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
from .flow import update_all_flows_async, setup_all_flows, drop_all_flows
from .lib import settings, init, start_server, stop
from .llm import LlmSpec, LlmApiType
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
from .index import (
VectorSimilarityMetric,
VectorIndexDef,
IndexOptions,
HnswVectorIndexMethod,
IvfFlatVectorIndexMethod,
)
from .setting import DatabaseConnectionSpec, Settings, ServerSettings
from .setting import get_app_namespace
from .query_handler import QueryHandlerResultFields, QueryInfo, QueryOutput
Expand Down Expand Up @@ -82,6 +88,8 @@
"VectorSimilarityMetric",
"VectorIndexDef",
"IndexOptions",
"HnswVectorIndexMethod",
"IvfFlatVectorIndexMethod",
# Settings
"DatabaseConnectionSpec",
"Settings",
Expand Down
23 changes: 22 additions & 1 deletion python/cocoindex/index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from dataclasses import dataclass
from typing import Sequence
from typing import Sequence, Union


class VectorSimilarityMetric(Enum):
Expand All @@ -9,6 +9,26 @@ class VectorSimilarityMetric(Enum):
INNER_PRODUCT = "InnerProduct"


@dataclass
class HnswVectorIndexMethod:
"""HNSW vector index parameters."""

kind: str = "Hnsw"
m: int | None = None
ef_construction: int | None = None


@dataclass
class IvfFlatVectorIndexMethod:
"""IVFFlat vector index parameters."""

kind: str = "IvfFlat"
lists: int | None = None


# Union of the supported index method specs; used as the type of `VectorIndexDef.method`.
VectorIndexMethod = Union[HnswVectorIndexMethod, IvfFlatVectorIndexMethod]


@dataclass
class VectorIndexDef:
"""
Expand All @@ -17,6 +37,7 @@ class VectorIndexDef:

field_name: str
metric: VectorSimilarityMetric
method: VectorIndexMethod | None = None


@dataclass
Expand Down
72 changes: 71 additions & 1 deletion src/base/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,15 +384,85 @@ impl fmt::Display for VectorSimilarityMetric {
}
}

/// Index algorithm, plus optional tuning parameters, for a vector index.
///
/// Serialized as an internally tagged object: the `kind` field carries the
/// variant name (`"Hnsw"` or `"IvfFlat"`). Parameters that are `None` are
/// omitted when serializing and default to `None` when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind")]
pub enum VectorIndexMethod {
/// HNSW index with optional `m` and `ef_construction` parameters.
Hnsw {
#[serde(default, skip_serializing_if = "Option::is_none")]
m: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
ef_construction: Option<u32>,
},
/// IVFFlat index with an optional `lists` parameter.
IvfFlat {
#[serde(default, skip_serializing_if = "Option::is_none")]
lists: Option<u32>,
},
}

impl VectorIndexMethod {
    /// Returns the variant name as a static string (`"Hnsw"` or `"IvfFlat"`).
    pub fn kind(&self) -> &'static str {
        match *self {
            Self::IvfFlat { .. } => "IvfFlat",
            Self::Hnsw { .. } => "Hnsw",
        }
    }

    /// Reports whether this is HNSW with no tuning parameter set.
    ///
    /// Any `IvfFlat` value, and any `Hnsw` value carrying at least one
    /// parameter, is considered non-default.
    pub fn is_default(&self) -> bool {
        match self {
            Self::Hnsw { m, ef_construction } => m.is_none() && ef_construction.is_none(),
            Self::IvfFlat { .. } => false,
        }
    }
}

impl fmt::Display for VectorIndexMethod {
    /// Renders the method name, followed by any set parameters in parentheses.
    ///
    /// Examples of the produced text: `Hnsw`, `Hnsw(m=16,ef_construction=64)`,
    /// `IvfFlat`, `IvfFlat(lists=200)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Hnsw { m, ef_construction } => {
                // Collect only the parameters that are set, keeping `m` first.
                let params: Vec<String> = m
                    .iter()
                    .map(|m| format!("m={m}"))
                    .chain(
                        ef_construction
                            .iter()
                            .map(|ef| format!("ef_construction={ef}")),
                    )
                    .collect();
                if params.is_empty() {
                    f.write_str("Hnsw")
                } else {
                    write!(f, "Hnsw({})", params.join(","))
                }
            }
            Self::IvfFlat { lists } => match lists {
                Some(lists) => write!(f, "IvfFlat(lists={lists})"),
                None => f.write_str("IvfFlat"),
            },
        }
    }
}

/// Definition of a vector index on a single field.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct VectorIndexDef {
/// Name of the field the index is built on.
pub field_name: FieldName,
/// Similarity metric used by the index.
pub metric: VectorSimilarityMetric,
/// Optional index method; omitted from the serialized form when `None`
/// and defaulted to `None` when missing on input.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub method: Option<VectorIndexMethod>,
}

impl fmt::Display for VectorIndexDef {
/// Formats as `field_name:metric`, appending `:method` only when a method
/// is set and `is_default()` returns false for it.
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// NOTE(review): the bare `write!` directly below appears to be the
// pre-change body kept by the diff view; the `match` that follows looks
// like the current logic — confirm against the actual file.
write!(f, "{}:{}", self.field_name, self.metric)
match &self.method {
None => write!(f, "{}:{}", self.field_name, self.metric),
// A default method prints the same text as no method at all.
Some(method) if method.is_default() => {
write!(f, "{}:{}", self.field_name, self.metric)
}
Some(method) => write!(f, "{}:{}:{}", self.field_name, self.metric, method),
}
}
}

Expand Down
40 changes: 30 additions & 10 deletions src/llm/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,30 @@ impl AiStudioClient {
}
}

/// Builds the JSON request body for an `embedContent` call.
///
/// The body always carries `model` and the text wrapped as
/// `content.parts[0].text`. `taskType` is added when `task_type` is given.
/// When `output_dimension` is given it is written to the top-level
/// `outputDimensionality` key and, for models whose name starts with
/// `gemini-embedding-`, additionally under `config.outputDimensionality`.
// NOTE(review): whether the API accepts the extra `config` object is not
// visible from this function — confirm against the API reference.
fn build_embed_payload(
    model: &str,
    text: &str,
    task_type: Option<&str>,
    output_dimension: Option<u32>,
) -> serde_json::Value {
    let mut body = serde_json::json!({
        "model": model,
        "content": { "parts": [{ "text": text }] },
    });

    if let Some(kind) = task_type {
        body["taskType"] = serde_json::Value::from(kind);
    }

    if let Some(dim) = output_dimension {
        body["outputDimensionality"] = serde_json::Value::from(dim);
        let wants_config = model.starts_with("gemini-embedding-");
        if wants_config {
            body["config"] = serde_json::json!({ "outputDimensionality": dim });
        }
    }

    body
}

#[async_trait]
impl LlmGenerationClient for AiStudioClient {
async fn generate<'req>(
Expand Down Expand Up @@ -174,16 +198,12 @@ impl LlmEmbeddingClient for AiStudioClient {
request: super::LlmEmbeddingRequest<'req>,
) -> Result<super::LlmEmbeddingResponse> {
let url = self.get_api_url(request.model, "embedContent");
let mut payload = serde_json::json!({
"model": request.model,
"content": { "parts": [{ "text": request.text }] },
});
if let Some(task_type) = request.task_type {
payload["taskType"] = serde_json::Value::String(task_type.into());
}
if let Some(output_dimension) = request.output_dimension {
payload["outputDimensionality"] = serde_json::Value::Number(output_dimension.into());
}
let payload = build_embed_payload(
request.model,
request.text.as_ref(),
request.task_type.as_deref(),
request.output_dimension,
);
let resp = retryable::run(
|| async {
self.client
Expand Down
38 changes: 34 additions & 4 deletions src/ops/targets/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,21 +461,51 @@ fn to_vector_similarity_metric_sql(metric: VectorSimilarityMetric) -> &'static s
}

/// Builds the `USING <method> (<field> <ops>) [WITH (...)]` fragment for a
/// vector index definition.
///
/// The method is `hnsw` when `index_spec.method` is `None` or `Hnsw`, and
/// `ivfflat` for `IvfFlat`. Only parameters that are set are emitted in the
/// `WITH` clause; with none set, the clause is omitted entirely.
// NOTE(review): inside the final `format!` the diff view appears to interleave
// the removed lines (`"USING hnsw ({} {})"` and the metric argument without a
// trailing comma) with their replacements — confirm against the actual file.
fn to_index_spec_sql(index_spec: &VectorIndexDef) -> Cow<'static, str> {
let (method, options) = match index_spec.method.as_ref() {
Some(spec::VectorIndexMethod::Hnsw { m, ef_construction }) => {
let mut opts = Vec::new();
if let Some(m) = m {
opts.push(format!("m = {}", m));
}
if let Some(ef) = ef_construction {
opts.push(format!("ef_construction = {}", ef));
}
("hnsw", opts)
}
Some(spec::VectorIndexMethod::IvfFlat { lists }) => (
"ivfflat",
lists
.map(|lists| vec![format!("lists = {}", lists)])
.unwrap_or_default(),
),
// No method specified: fall back to HNSW without options.
None => ("hnsw", Vec::new()),
};
let with_clause = if options.is_empty() {
String::new()
} else {
format!(" WITH ({})", options.join(", "))
};
format!(
"USING hnsw ({} {})",
"USING {method} ({} {}){}",
index_spec.field_name,
to_vector_similarity_metric_sql(index_spec.metric)
to_vector_similarity_metric_sql(index_spec.metric),
with_clause
)
.into()
}

/// Derives the index name as `<table>__<field>__<metric ops>`, with a
/// `__<method kind, lowercased>` suffix when a non-default method is set.
// NOTE(review): the suffix encodes only the method kind, not its parameters,
// so two definitions that differ only in parameter values map to the same
// name — confirm callers handle that.
// NOTE(review): the bare `format!(` line and the `)` line below appear to be
// removed diff lines shown next to their replacements (`let mut name = format!(`
// and `);`) — confirm against the actual file.
fn to_vector_index_name(table_name: &str, vector_index_def: &spec::VectorIndexDef) -> String {
format!(
let mut name = format!(
"{}__{}__{}",
table_name,
vector_index_def.field_name,
to_vector_similarity_metric_sql(vector_index_def.metric)
)
);
// Definitions without a method, or with the default one, get no suffix.
if let Some(method) = vector_index_def.method.as_ref().filter(|m| !m.is_default()) {
name.push_str("__");
name.push_str(&method.kind().to_ascii_lowercase());
}
name
}

fn describe_index_spec(index_name: &str, index_spec: &VectorIndexDef) -> String {
Expand Down
Loading