Skip to content

Commit 8125f53

Browse files
committed
Merge remote-tracking branch 'upstream/HEAD' into qdrant
2 parents d7f419b + f0dc6ee commit 8125f53

File tree

15 files changed

+792
-229
lines changed

15 files changed

+792
-229
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,4 @@ tokio-stream = "0.1.17"
103103
async-stream = "0.3.6"
104104
neo4rs = "0.8.0"
105105
bytes = "1.10.1"
106+
rand = "0.9.0"

examples/docs_to_kg/main.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,19 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D
5555
"Each relationship should be a tuple of (subject, predicate, object).")))
5656

5757
with chunk["relationships"]["relationships"].row() as relationship:
58+
relationship["subject_embedding"] = relationship["subject"].transform(
59+
cocoindex.functions.SentenceTransformerEmbed(
60+
model="sentence-transformers/all-MiniLM-L6-v2"))
61+
relationship["object_embedding"] = relationship["object"].transform(
62+
cocoindex.functions.SentenceTransformerEmbed(
63+
model="sentence-transformers/all-MiniLM-L6-v2"))
5864
relationships.collect(
5965
id=cocoindex.GeneratedField.UUID,
6066
subject=relationship["subject"],
61-
predicate=relationship["predicate"],
67+
subject_embedding=relationship["subject_embedding"],
6268
object=relationship["object"],
69+
object_embedding=relationship["object_embedding"],
70+
predicate=relationship["predicate"],
6371
)
6472

6573
relationships.export(
@@ -69,14 +77,34 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D
6977
rel_type="RELATIONSHIP",
7078
source=cocoindex.storages.Neo4jRelationshipEndSpec(
7179
label="Entity",
72-
fields=[cocoindex.storages.Neo4jFieldMapping(field_name="subject", node_field_name="value")]
80+
fields=[
81+
cocoindex.storages.Neo4jFieldMapping(
82+
field_name="subject", node_field_name="value"),
83+
cocoindex.storages.Neo4jFieldMapping(
84+
field_name="subject_embedding", node_field_name="embedding"),
85+
]
7386
),
7487
target=cocoindex.storages.Neo4jRelationshipEndSpec(
7588
label="Entity",
76-
fields=[cocoindex.storages.Neo4jFieldMapping(field_name="object", node_field_name="value")]
89+
fields=[
90+
cocoindex.storages.Neo4jFieldMapping(
91+
field_name="object", node_field_name="value"),
92+
cocoindex.storages.Neo4jFieldMapping(
93+
field_name="object_embedding", node_field_name="embedding"),
94+
]
7795
),
7896
nodes={
79-
"Entity": cocoindex.storages.Neo4jRelationshipNodeSpec(key_field_name="value"),
97+
"Entity": cocoindex.storages.Neo4jRelationshipNodeSpec(
98+
index_options=cocoindex.IndexOptions(
99+
primary_key_fields=["value"],
100+
vector_index_defs=[
101+
cocoindex.VectorIndexDef(
102+
field_name="embedding",
103+
metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
104+
),
105+
],
106+
),
107+
),
80108
},
81109
),
82110
primary_key_fields=["id"],

python/cocoindex/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .flow import EvaluateAndDumpOptions, GeneratedField
77
from .flow import update_all_flows, FlowLiveUpdater, FlowLiveUpdaterOptions
88
from .llm import LlmSpec, LlmApiType
9-
from .vector import VectorSimilarityMetric
9+
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
1010
from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry
1111
from .lib import *
1212
from ._engine import OpArgSchema

python/cocoindex/flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from dataclasses import dataclass
1616

1717
from . import _engine
18-
from . import vector
18+
from . import index
1919
from . import op
2020
from .convert import dump_engine_object
2121
from .typing import encode_enriched_type
@@ -268,7 +268,7 @@ def collect(self, **kwargs):
268268

269269
def export(self, name: str, target_spec: op.StorageSpec, /, *,
270270
primary_key_fields: Sequence[str] | None = None,
271-
vector_index: Sequence[tuple[str, vector.VectorSimilarityMetric]] = (),
271+
vector_index: Sequence[tuple[str, index.VectorSimilarityMetric]] = (),
272272
setup_by_user: bool = False):
273273
"""
274274
Export the collected data to the specified target.

python/cocoindex/index.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from enum import Enum
2+
from dataclasses import dataclass
3+
4+
class VectorSimilarityMetric(Enum):
5+
COSINE_SIMILARITY = "CosineSimilarity"
6+
L2_DISTANCE = "L2Distance"
7+
INNER_PRODUCT = "InnerProduct"
8+
9+
@dataclass
10+
class VectorIndexDef:
11+
"""
12+
Define a vector index on a field.
13+
"""
14+
field_name: str
15+
metric: VectorSimilarityMetric
16+
17+
@dataclass
18+
class IndexOptions:
19+
"""
20+
Options for an index.
21+
"""
22+
primary_key_fields: list[str] | None = None
23+
vector_index_defs: list[VectorIndexDef] | None = None

python/cocoindex/query.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from threading import Lock
44

55
from . import flow as fl
6-
from . import vector
6+
from . import index
77
from . import _engine
88

99
_handlers_lock = Lock()
@@ -14,7 +14,7 @@ class SimpleSemanticsQueryInfo:
1414
"""
1515
Additional information about the query.
1616
"""
17-
similarity_metric: vector.VectorSimilarityMetric
17+
similarity_metric: index.VectorSimilarityMetric
1818
query_vector: list[float]
1919
vector_field_name: str
2020

@@ -39,7 +39,7 @@ def __init__(
3939
flow: fl.Flow,
4040
target_name: str,
4141
query_transform_flow: Callable[..., fl.DataSlice],
42-
default_similarity_metric: vector.VectorSimilarityMetric = vector.VectorSimilarityMetric.COSINE_SIMILARITY) -> None:
42+
default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY) -> None:
4343

4444
engine_handler = None
4545
lock = Lock()
@@ -66,7 +66,7 @@ def internal_handler(self) -> _engine.SimpleSemanticsQueryHandler:
6666
return self._lazy_query_handler()
6767

6868
def search(self, query: str, limit: int, vector_field_name: str | None = None,
69-
similarity_matric: vector.VectorSimilarityMetric | None = None) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
69+
similarity_matric: index.VectorSimilarityMetric | None = None) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
7070
"""
7171
Search the index with the given query, limit, vector field name, and similarity metric.
7272
"""
@@ -76,7 +76,7 @@ def search(self, query: str, limit: int, vector_field_name: str | None = None,
7676
fields = [field['name'] for field in internal_results['fields']]
7777
results = [QueryResult(data=dict(zip(fields, result['data'])), score=result['score']) for result in internal_results['results']]
7878
info = SimpleSemanticsQueryInfo(
79-
similarity_metric=vector.VectorSimilarityMetric(internal_info['similarity_metric']),
79+
similarity_metric=index.VectorSimilarityMetric(internal_info['similarity_metric']),
8080
query_vector=internal_info['query_vector'],
8181
vector_field_name=internal_info['vector_field_name']
8282
)

python/cocoindex/storages.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass
33

44
from . import op
5+
from . import index
56
from .auth_registry import AuthEntryReference
67

78
class Postgres(op.StorageSpec):
@@ -42,7 +43,7 @@ class Neo4jRelationshipEndSpec:
4243
class Neo4jRelationshipNodeSpec:
4344
"""Spec for a Neo4j node type."""
4445
key_field_name: str | None = None
45-
46+
index_options: index.IndexOptions | None = None
4647
class Neo4jRelationship(op.StorageSpec):
4748
"""Graph storage powered by Neo4j."""
4849

python/cocoindex/vector.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/base/spec.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,23 @@ pub struct CollectOpSpec {
204204
pub auto_uuid_field: Option<FieldName>,
205205
}
206206

207-
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
207+
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
208208
pub enum VectorSimilarityMetric {
209209
CosineSimilarity,
210210
L2Distance,
211211
InnerProduct,
212212
}
213213

214+
impl std::fmt::Display for VectorSimilarityMetric {
215+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216+
match self {
217+
VectorSimilarityMetric::CosineSimilarity => write!(f, "Cosine"),
218+
VectorSimilarityMetric::L2Distance => write!(f, "L2"),
219+
VectorSimilarityMetric::InnerProduct => write!(f, "InnerProduct"),
220+
}
221+
}
222+
}
223+
214224
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
215225
pub struct VectorIndexDef {
216226
pub field_name: FieldName,

src/base/value.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,21 @@ impl std::fmt::Display for KeyValue {
174174
}
175175

176176
impl KeyValue {
177+
pub fn fields_iter<'a>(
178+
&'a self,
179+
num_fields: usize,
180+
) -> Result<impl Iterator<Item = &'a KeyValue>> {
181+
let slice = if num_fields == 1 {
182+
std::slice::from_ref(self)
183+
} else {
184+
match self {
185+
KeyValue::Struct(v) => v,
186+
_ => api_bail!("Invalid key value type"),
187+
}
188+
};
189+
Ok(slice.iter())
190+
}
191+
177192
fn parts_from_str(
178193
values_iter: &mut impl Iterator<Item = String>,
179194
schema: &ValueType,
@@ -533,6 +548,23 @@ impl From<KeyValue> for Value {
533548
}
534549
}
535550

551+
impl From<&KeyValue> for Value {
552+
fn from(value: &KeyValue) -> Self {
553+
match value {
554+
KeyValue::Bytes(v) => Value::Basic(BasicValue::Bytes(v.clone())),
555+
KeyValue::Str(v) => Value::Basic(BasicValue::Str(v.clone())),
556+
KeyValue::Bool(v) => Value::Basic(BasicValue::Bool(*v)),
557+
KeyValue::Int64(v) => Value::Basic(BasicValue::Int64(*v)),
558+
KeyValue::Range(v) => Value::Basic(BasicValue::Range(*v)),
559+
KeyValue::Uuid(v) => Value::Basic(BasicValue::Uuid(*v)),
560+
KeyValue::Date(v) => Value::Basic(BasicValue::Date(*v)),
561+
KeyValue::Struct(v) => Value::Struct(FieldValues {
562+
fields: v.iter().map(Value::from).collect(),
563+
}),
564+
}
565+
}
566+
}
567+
536568
impl From<FieldValues> for Value {
537569
fn from(value: FieldValues) -> Self {
538570
Value::Struct(value)

0 commit comments

Comments
 (0)