Skip to content

Commit a1818af

Browse files
authored
Correct similarity score for cosine / inner product. (#44)
pgvector is using cosine *distance* and *negative* inner product, so we need a conversion.
1 parent dab8a6f commit a1818af

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/ops/storages/postgres.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ impl QueryTarget for Executor {
369369
let query_str = format!(
370370
"SELECT {} {} $1 AS {SCORE_FIELD_NAME}, {} FROM {} ORDER BY {SCORE_FIELD_NAME} LIMIT $2",
371371
ValidIdentifier::try_from(query.vector_field_name)?,
372-
to_similarity_operator(query.similarity_metric),
372+
to_distance_operator(query.similarity_metric),
373373
self.all_fields_comma_separated,
374374
self.table_name,
375375
);
@@ -380,7 +380,7 @@ impl QueryTarget for Executor {
380380
.await?
381381
.into_iter()
382382
.map(|r| -> Result<QueryResult> {
383-
let score: f64 = r.try_get(0)?;
383+
let score: f64 = distance_to_similarity(query.similarity_metric, r.try_get(0)?);
384384
let data = self
385385
.key_fields_schema
386386
.iter()
@@ -400,14 +400,24 @@ impl QueryTarget for Executor {
400400
}
401401
}
402402

403-
fn to_similarity_operator(metric: VectorSimilarityMetric) -> &'static str {
403+
fn to_distance_operator(metric: VectorSimilarityMetric) -> &'static str {
404404
match metric {
405405
VectorSimilarityMetric::CosineSimilarity => "<=>",
406406
VectorSimilarityMetric::L2Distance => "<->",
407407
VectorSimilarityMetric::InnerProduct => "<#>",
408408
}
409409
}
410410

411+
fn distance_to_similarity(metric: VectorSimilarityMetric, distance: f64) -> f64 {
412+
match metric {
413+
// cosine distance => cosine similarity
414+
VectorSimilarityMetric::CosineSimilarity => 1.0 - distance,
415+
VectorSimilarityMetric::L2Distance => distance,
416+
// negative inner product => inner product
417+
VectorSimilarityMetric::InnerProduct => -distance,
418+
}
419+
}
420+
411421
pub struct Factory {
412422
db_pools: Mutex<
413423
HashMap<

0 commit comments

Comments
 (0)