Skip to content

Commit 8af786c

Browse files
authored
chore: change cosine_distance to 1 - e (#10844)
* chore: change cosine_distance to 1 - e * fix ut * fix sql logic test
1 parent 42b5060 commit 8af786c

File tree

5 files changed

+8
-8
lines changed

5 files changed

+8
-8
lines changed

src/common/vector/src/distance.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@ pub fn cosine_distance(from: &[f32], to: &[f32]) -> Result<f32> {
3030
let aa_sum = (&a * &a).sum();
3131
let bb_sum = (&b * &b).sum();
3232

33-
Ok((&a * &b).sum() / ((aa_sum).sqrt() * (bb_sum).sqrt()))
33+
Ok(1.0 - (&a * &b).sum() / ((aa_sum).sqrt() * (bb_sum).sqrt()))
3434
}

src/common/vector/tests/it/distance.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ fn test_cosine() {
2121
let y: Vec<f32> = (100..108).map(|v| v as f32).collect();
2222
let d = cosine_distance(&x, &y).unwrap();
2323
// from scipy.spatial.distance.cosine
24-
approx::assert_relative_eq!(d, 0.900_957);
24+
approx::assert_relative_eq!(d, 1.0 - 0.900_957);
2525
}
2626

2727
{
2828
let x = vec![3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0];
2929
let y = vec![2.0, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0];
3030
let d = cosine_distance(&x, &y).unwrap();
3131
// from sklearn.metrics.pairwise import cosine_similarity
32-
approx::assert_relative_eq!(d, 0.873_580_6);
32+
approx::assert_relative_eq!(d, 1.0 - 0.873_580_6);
3333
}
3434

3535
{

src/query/functions/tests/it/scalars/testdata/vector.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ evaluation:
88
| Type | Float32 | Float32 | Float32 |
99
| Domain | {0..=2} | {3..=5} | Unknown |
1010
| Row 0 | 0 | 3 | NaN |
11-
| Row 1 | 1 | 4 | 1 |
12-
| Row 2 | 2 | 5 | 1 |
11+
| Row 1 | 1 | 4 | 0 |
12+
| Row 2 | 2 | 5 | 0 |
1313
+--------+---------+---------+---------+
1414
evaluation (internal):
1515
+--------+----------------------+
1616
| Column | Data |
1717
+--------+----------------------+
1818
| a | Float32([0, 1, 2]) |
1919
| b | Float32([3, 4, 5]) |
20-
| Output | Float32([NaN, 1, 1]) |
20+
| Output | Float32([NaN, 0, 0]) |
2121
+--------+----------------------+
2222

2323

tests/sqllogictests/suites/query/02_function/02_0063_function_vector

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
query F
33
select cosine_distance([3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0], [2.0, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0]) as sim
44
----
5-
0.8735807
5+
0.1264193

website/blog/2023-03-24-databend-weekly-86.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ Databend has added a new function called `cosine_distance`. This function accept
8787
```sql
8888
select cosine_distance([3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0], [2.0, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0]) as sim
8989
----
90-
0.8735807
90+
0.1264193
9191
```
9292

9393
The Rust implementation efficiently performs calculations by utilizing the `ArrayView` type from the [ndarray](https://crates.io/crates/ndarray) crate.

0 commit comments

Comments
 (0)