Skip to content

Commit 0a3969a

Browse files
BohuTANGb41shmergify[bot]
authored
feat: add cosine_distance for vector similarity compute (#10737)
* feat: add cosine_distance for vector similarity compute * fix cosine_distance return value * cosine_distance support empty array * change the distance from vec to one * fix function list.txt * change to ndarray * change to ndarray * change to array view avoid mem copy * add logic test * fix unit test --------- Co-authored-by: baishen <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 7df8459 commit 0a3969a

File tree

16 files changed

+287
-0
lines changed

16 files changed

+287
-0
lines changed

Cargo.lock

Lines changed: 38 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ members = [
2626
"src/common/tracing",
2727
"src/common/storage",
2828
"src/common/profile",
29+
"src/common/vector",
2930
# Query
3031
"src/query/ast",
3132
"src/query/codegen",

src/common/base/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,6 @@ pub mod containers;
2828
pub mod mem_allocator;
2929
pub mod rangemap;
3030
pub mod runtime;
31+
3132
pub use runtime::match_join_handle;
3233
pub use runtime::set_alloc_error_hook;

src/common/vector/Cargo.toml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
[package]
2+
name = "common-vector"
3+
version = { workspace = true }
4+
authors = { workspace = true }
5+
license = { workspace = true }
6+
publish = { workspace = true }
7+
edition = { workspace = true }
8+
9+
[lib]
10+
doctest = false
11+
test = false
12+
13+
[dependencies] # In alphabetical order
14+
common-exception = { path = "../exception" }
15+
16+
ndarray = "0.15.6"
17+
18+
[build-dependencies]
19+
20+
[features]
21+
22+
[dev-dependencies]
23+
approx = "0.5.1"

src/common/vector/src/distance.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use common_exception::ErrorCode;
16+
use common_exception::Result;
17+
use ndarray::ArrayView;
18+
19+
pub fn cosine_distance(from: &[f32], to: &[f32]) -> Result<f32> {
20+
if from.len() != to.len() {
21+
return Err(ErrorCode::InvalidArgument(format!(
22+
"Vector length not equal: {:} != {:}",
23+
from.len(),
24+
to.len(),
25+
)));
26+
}
27+
28+
let a = ArrayView::from(from);
29+
let b = ArrayView::from(to);
30+
let aa_sum = (&a * &a).sum();
31+
let bb_sum = (&b * &b).sum();
32+
33+
Ok((&a * &b).sum() / ((aa_sum).sqrt() * (bb_sum).sqrt()))
34+
}

src/common/vector/src/lib.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
mod distance;
16+
17+
pub use distance::cosine_distance;
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use common_vector::cosine_distance;
16+
17+
#[test]
18+
fn test_cosine() {
19+
{
20+
let x: Vec<f32> = (1..9).map(|v| v as f32).collect();
21+
let y: Vec<f32> = (100..108).map(|v| v as f32).collect();
22+
let d = cosine_distance(&x, &y).unwrap();
23+
// from scipy.spatial.distance.cosine
24+
approx::assert_relative_eq!(d, 0.900_957);
25+
}
26+
27+
{
28+
let x = vec![3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0];
29+
let y = vec![2.0, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0];
30+
let d = cosine_distance(&x, &y).unwrap();
31+
// from sklearn.metrics.pairwise import cosine_similarity
32+
approx::assert_relative_eq!(d, 0.873_580_6);
33+
}
34+
35+
{
36+
let x = vec![3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0];
37+
let y = vec![2.0, 54.0];
38+
let d = cosine_distance(&x, &y);
39+
assert!(d.is_err());
40+
}
41+
}

src/common/vector/tests/it/main.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
mod distance;

src/query/functions/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ common-exception = { path = "../../common/exception" }
1717
common-expression = { path = "../expression" }
1818
common-hashtable = { path = "../../common/hashtable" }
1919
common-io = { path = "../../common/io" }
20+
common-vector = { path = "../../common/vector" }
2021
jsonb = { workspace = true }
2122

2223
# Crates.io dependencies

src/query/functions/src/scalars/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ mod map;
2525
mod math;
2626
mod tuple;
2727
mod variant;
28+
mod vector;
2829

2930
mod comparison;
3031
mod decimal;
@@ -55,4 +56,5 @@ pub fn register(registry: &mut FunctionRegistry) {
5556
hash::register(registry);
5657
other::register(registry);
5758
decimal::register(registry);
59+
vector::register(registry);
5860
}

0 commit comments

Comments
 (0)