Skip to content

Commit 4689ad9

Browse files
committed
feat: allow configuring Postgres vector index method and options
1 parent b9dfc7e commit 4689ad9

File tree

6 files changed

+149
-8
lines changed

6 files changed

+149
-8
lines changed

docs/docs/core/flow_def.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ Types of the fields must be key types. See [Key Types](data_types#key-types) for
313313

314314
* `field_name`: the field to create vector index.
315315
* `metric`: the similarity metric to use.
316+
* `method`: the index algorithm and optional tuning parameters. Defaults to HNSW. Use `cocoindex.HnswVectorIndexMethod()` or `cocoindex.IvfFlatVectorIndexMethod()` to customize the method and its parameters.
316317

317318
#### Similarity Metrics
318319

docs/docs/examples/examples/simple_vector_index.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@ doc_embeddings.export(
105105
CocoIndex supports other vector databases as well, with 1-line switch.
106106
<DocumentationButton url="https://cocoindex.io/docs/ops/targets" text="Targets" />
107107

108+
Need IVFFlat or custom HNSW parameters? Pass a method, for example:
109+
110+
```python
111+
cocoindex.VectorIndexDef(
112+
field_name="embedding",
113+
metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
114+
method=cocoindex.IvfFlatVectorIndexMethod(lists=200),
115+
)
116+
```
117+
108118
## Query the index
109119

110120
### Define a shared flow for both indexing and querying

python/cocoindex/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
from .flow import update_all_flows_async, setup_all_flows, drop_all_flows
2222
from .lib import settings, init, start_server, stop
2323
from .llm import LlmSpec, LlmApiType
24-
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
24+
from .index import (
25+
VectorSimilarityMetric,
26+
VectorIndexDef,
27+
IndexOptions,
28+
HnswVectorIndexMethod,
29+
IvfFlatVectorIndexMethod,
30+
)
2531
from .setting import DatabaseConnectionSpec, Settings, ServerSettings
2632
from .setting import get_app_namespace
2733
from .query_handler import QueryHandlerResultFields, QueryInfo, QueryOutput
@@ -82,6 +88,8 @@
8288
"VectorSimilarityMetric",
8389
"VectorIndexDef",
8490
"IndexOptions",
91+
"HnswVectorIndexMethod",
92+
"IvfFlatVectorIndexMethod",
8593
# Settings
8694
"DatabaseConnectionSpec",
8795
"Settings",

python/cocoindex/index.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from enum import Enum
2-
from dataclasses import dataclass
3-
from typing import Sequence
2+
from dataclasses import dataclass, field
3+
from typing import Sequence, Union
44

55

66
class VectorSimilarityMetric(Enum):
@@ -9,6 +9,26 @@ class VectorSimilarityMetric(Enum):
99
INNER_PRODUCT = "InnerProduct"
1010

1111

12+
@dataclass
13+
class HnswVectorIndexMethod:
14+
"""HNSW vector index parameters."""
15+
16+
type: str = field(init=False, default="hnsw")
17+
m: int | None = None
18+
ef_construction: int | None = None
19+
20+
21+
@dataclass
22+
class IvfFlatVectorIndexMethod:
23+
"""IVFFlat vector index parameters."""
24+
25+
type: str = field(init=False, default="ivfflat")
26+
lists: int | None = None
27+
28+
29+
VectorIndexMethod = Union[HnswVectorIndexMethod, IvfFlatVectorIndexMethod]
30+
31+
1232
@dataclass
1333
class VectorIndexDef:
1434
"""
@@ -17,6 +37,7 @@ class VectorIndexDef:
1737

1838
field_name: str
1939
metric: VectorSimilarityMetric
40+
method: VectorIndexMethod = field(default_factory=HnswVectorIndexMethod)
2041

2142

2243
@dataclass

src/base/spec.rs

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,15 +384,86 @@ impl fmt::Display for VectorSimilarityMetric {
384384
}
385385
}
386386

387+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
388+
#[serde(tag = "type", rename_all = "snake_case")]
389+
pub enum VectorIndexMethod {
390+
Hnsw {
391+
#[serde(default, skip_serializing_if = "Option::is_none")]
392+
m: Option<u32>,
393+
#[serde(default, skip_serializing_if = "Option::is_none")]
394+
ef_construction: Option<u32>,
395+
},
396+
IvfFlat {
397+
#[serde(default, skip_serializing_if = "Option::is_none")]
398+
lists: Option<u32>,
399+
},
400+
}
401+
402+
impl Default for VectorIndexMethod {
403+
fn default() -> Self {
404+
Self::Hnsw {
405+
m: None,
406+
ef_construction: None,
407+
}
408+
}
409+
}
410+
411+
impl VectorIndexMethod {
412+
pub fn kind(&self) -> &'static str {
413+
match self {
414+
Self::Hnsw { .. } => "hnsw",
415+
Self::IvfFlat { .. } => "ivfflat",
416+
}
417+
}
418+
419+
pub fn is_default(&self) -> bool {
420+
matches!(self, Self::Hnsw { m: None, ef_construction: None })
421+
}
422+
}
423+
424+
impl fmt::Display for VectorIndexMethod {
425+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426+
match self {
427+
Self::Hnsw { m, ef_construction } => {
428+
let mut parts = Vec::new();
429+
if let Some(m) = m {
430+
parts.push(format!("m={}", m));
431+
}
432+
if let Some(ef) = ef_construction {
433+
parts.push(format!("ef_construction={}", ef));
434+
}
435+
if parts.is_empty() {
436+
write!(f, "hnsw")
437+
} else {
438+
write!(f, "hnsw({})", parts.join(","))
439+
}
440+
}
441+
Self::IvfFlat { lists } => {
442+
if let Some(lists) = lists {
443+
write!(f, "ivfflat(lists={lists})")
444+
} else {
445+
write!(f, "ivfflat")
446+
}
447+
}
448+
}
449+
}
450+
}
451+
387452
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
388453
pub struct VectorIndexDef {
389454
pub field_name: FieldName,
390455
pub metric: VectorSimilarityMetric,
456+
#[serde(default)]
457+
pub method: VectorIndexMethod,
391458
}
392459

393460
impl fmt::Display for VectorIndexDef {
394461
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395-
write!(f, "{}:{}", self.field_name, self.metric)
462+
if self.method.is_default() {
463+
write!(f, "{}:{}", self.field_name, self.metric)
464+
} else {
465+
write!(f, "{}:{}:{}", self.field_name, self.metric, self.method)
466+
}
396467
}
397468
}
398469

src/ops/targets/postgres.rs

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,21 +461,51 @@ fn to_vector_similarity_metric_sql(metric: VectorSimilarityMetric) -> &'static s
461461
}
462462

463463
fn to_index_spec_sql(index_spec: &VectorIndexDef) -> Cow<'static, str> {
464+
let method = match &index_spec.method {
465+
spec::VectorIndexMethod::Hnsw { .. } => "hnsw",
466+
spec::VectorIndexMethod::IvfFlat { .. } => "ivfflat",
467+
};
468+
let options = match &index_spec.method {
469+
spec::VectorIndexMethod::Hnsw { m, ef_construction } => {
470+
let mut opts = Vec::new();
471+
if let Some(m) = m {
472+
opts.push(format!("m = {}", m));
473+
}
474+
if let Some(ef) = ef_construction {
475+
opts.push(format!("ef_construction = {}", ef));
476+
}
477+
opts
478+
}
479+
spec::VectorIndexMethod::IvfFlat { lists } => lists
480+
.map(|lists| vec![format!("lists = {}", lists)])
481+
.unwrap_or_default(),
482+
};
483+
let with_clause = if options.is_empty() {
484+
String::new()
485+
} else {
486+
format!(" WITH ({})", options.join(", "))
487+
};
464488
format!(
465-
"USING hnsw ({} {})",
489+
"USING {method} ({} {}){}",
466490
index_spec.field_name,
467-
to_vector_similarity_metric_sql(index_spec.metric)
491+
to_vector_similarity_metric_sql(index_spec.metric),
492+
with_clause
468493
)
469494
.into()
470495
}
471496

472497
fn to_vector_index_name(table_name: &str, vector_index_def: &spec::VectorIndexDef) -> String {
473-
format!(
498+
let mut name = format!(
474499
"{}__{}__{}",
475500
table_name,
476501
vector_index_def.field_name,
477502
to_vector_similarity_metric_sql(vector_index_def.metric)
478-
)
503+
);
504+
if !vector_index_def.method.is_default() {
505+
name.push_str("__");
506+
name.push_str(vector_index_def.method.kind());
507+
}
508+
name
479509
}
480510

481511
fn describe_index_spec(index_name: &str, index_spec: &VectorIndexDef) -> String {

0 commit comments

Comments
 (0)