Skip to content

Commit 1ace76d

Browse files
authored
feat(query): Support vector functions inner_product, vector_dims and vector_norm (#18414)
1 parent 00a0693 commit 1ace76d

File tree

13 files changed

+724
-41
lines changed

13 files changed

+724
-41
lines changed

โ€Žsrc/common/vector/src/distance.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,21 @@ pub fn l2_distance(lhs: &[f32], rhs: &[f32]) -> Result<f32> {
6666
.sqrt())
6767
}
6868

69+
pub fn inner_product(lhs: &[f32], rhs: &[f32]) -> Result<f32> {
70+
if lhs.len() != rhs.len() {
71+
return Err(ErrorCode::InvalidArgument(format!(
72+
"Vector length not equal: {:} != {:}",
73+
lhs.len(),
74+
rhs.len(),
75+
)));
76+
}
77+
78+
let a = ArrayView::from(lhs);
79+
let b = ArrayView::from(rhs);
80+
81+
Ok((&a * &b).sum())
82+
}
83+
6984
pub fn cosine_distance_64(lhs: &[f64], rhs: &[f64]) -> Result<f64> {
7085
if lhs.len() != rhs.len() {
7186
return Err(ErrorCode::InvalidArgument(format!(
@@ -115,3 +130,23 @@ pub fn l2_distance_64(lhs: &[f64], rhs: &[f64]) -> Result<f64> {
115130
.sum::<f64>()
116131
.sqrt())
117132
}
133+
134+
pub fn inner_product_64(lhs: &[f64], rhs: &[f64]) -> Result<f64> {
135+
if lhs.len() != rhs.len() {
136+
return Err(ErrorCode::InvalidArgument(format!(
137+
"Vector length not equal: {:} != {:}",
138+
lhs.len(),
139+
rhs.len(),
140+
)));
141+
}
142+
143+
let a = ArrayView::from(lhs);
144+
let b = ArrayView::from(rhs);
145+
146+
Ok((&a * &b).sum())
147+
}
148+
149+
pub fn vector_norm(vector: &[f32]) -> f32 {
150+
let a = ArrayView::from(vector);
151+
(&a * &a).sum().sqrt()
152+
}

โ€Žsrc/common/vector/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ mod distance;
1616

1717
pub use distance::cosine_distance;
1818
pub use distance::cosine_distance_64;
19+
pub use distance::inner_product;
20+
pub use distance::inner_product_64;
1921
pub use distance::l1_distance;
2022
pub use distance::l1_distance_64;
2123
pub use distance::l2_distance;
2224
pub use distance::l2_distance_64;
25+
pub use distance::vector_norm;

โ€Žsrc/query/ast/src/ast/expr.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,6 +1532,8 @@ pub enum BinaryOperator {
15321532
BitwiseXor,
15331533
BitwiseShiftLeft,
15341534
BitwiseShiftRight,
1535+
CosineDistance,
1536+
L1Distance,
15351537
L2Distance,
15361538
}
15371539

@@ -1560,6 +1562,8 @@ impl BinaryOperator {
15601562
BinaryOperator::BitwiseShiftLeft => "bit_shift_left".to_string(),
15611563
BinaryOperator::BitwiseShiftRight => "bit_shift_right".to_string(),
15621564
BinaryOperator::Caret => "pow".to_string(),
1565+
BinaryOperator::CosineDistance => "cosine_distance".to_string(),
1566+
BinaryOperator::L1Distance => "l1_distance".to_string(),
15631567
BinaryOperator::L2Distance => "l2_distance".to_string(),
15641568
BinaryOperator::LikeAny(_) => "like_any".to_string(),
15651569
BinaryOperator::Like(_) => "like".to_string(),
@@ -1667,6 +1671,12 @@ impl Display for BinaryOperator {
16671671
BinaryOperator::BitwiseShiftRight => {
16681672
write!(f, ">>")
16691673
}
1674+
BinaryOperator::CosineDistance => {
1675+
write!(f, "<=>")
1676+
}
1677+
BinaryOperator::L1Distance => {
1678+
write!(f, "<+>")
1679+
}
16701680
BinaryOperator::L2Distance => {
16711681
write!(f, "<->")
16721682
}

โ€Žsrc/query/ast/src/parser/expr.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ const fn binary_affix(op: &BinaryOperator) -> Affix {
398398
BinaryOperator::BitwiseOr => Affix::Infix(Precedence(22), Associativity::Left),
399399
BinaryOperator::BitwiseAnd => Affix::Infix(Precedence(22), Associativity::Left),
400400
BinaryOperator::BitwiseXor => Affix::Infix(Precedence(22), Associativity::Left),
401+
BinaryOperator::CosineDistance => Affix::Infix(Precedence(22), Associativity::Left),
402+
BinaryOperator::L1Distance => Affix::Infix(Precedence(22), Associativity::Left),
401403
BinaryOperator::L2Distance => Affix::Infix(Precedence(22), Associativity::Left),
402404
BinaryOperator::BitwiseShiftLeft => Affix::Infix(Precedence(23), Associativity::Left),
403405
BinaryOperator::BitwiseShiftRight => Affix::Infix(Precedence(23), Associativity::Left),
@@ -1541,6 +1543,8 @@ pub fn binary_op(i: Input) -> IResult<BinaryOperator> {
15411543
value(BinaryOperator::Div, rule! { DIV }),
15421544
value(BinaryOperator::Modulo, rule! { "%" }),
15431545
value(BinaryOperator::StringConcat, rule! { "||" }),
1546+
value(BinaryOperator::CosineDistance, rule! { "<=>" }),
1547+
value(BinaryOperator::L1Distance, rule! { "<+>" }),
15441548
value(BinaryOperator::L2Distance, rule! { "<->" }),
15451549
value(BinaryOperator::Gt, rule! { ">" }),
15461550
value(BinaryOperator::Lt, rule! { "<" }),

โ€Žsrc/query/ast/src/parser/token.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,9 @@ pub enum TokenKind {
803803
SECONDARY,
804804
#[token("ROLES", ignore(ascii_case))]
805805
ROLES,
806-
/// L2DISTANCE op, from https://github.com/pgvector/pgvector
806+
/// L1DISTANCE and L2DISTANCE op, from https://github.com/pgvector/pgvector
807+
#[token("<+>")]
808+
L1DISTANCE,
807809
#[token("<->")]
808810
L2DISTANCE,
809811
#[token("LEADING", ignore(ascii_case))]
@@ -1537,6 +1539,7 @@ impl TokenKind {
15371539
| Abs
15381540
| SquareRoot
15391541
| CubeRoot
1542+
| L1DISTANCE
15401543
| L2DISTANCE
15411544
| Placeholder
15421545
| QuestionOr

โ€Žsrc/query/ast/tests/it/testdata/expr-error.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ error:
5252
--> SQL:1:10
5353
|
5454
1 | CAST(col1)
55-
| ---- ^ unexpected `)`, expecting `AS`, `,`, `(`, `IS`, `NOT`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `REGEXP`, `RLIKE`, `SOUNDS`, <BitWiseOr>, <BitWiseAnd>, <BitWiseXor>, <ShiftLeft>, <ShiftRight>, `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, <Factorial>, <SquareRoot>, <BitWiseNot>, <CubeRoot>, <Abs>, `CAST`, `TRY_CAST`, `::`, `POSITION`, `IdentVariable`, `DATEADD`, or 44 more ...
55+
| ---- ^ unexpected `)`, expecting `AS`, `,`, `(`, `IS`, `NOT`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `REGEXP`, `RLIKE`, `SOUNDS`, <BitWiseOr>, <BitWiseAnd>, <BitWiseXor>, <ShiftLeft>, <ShiftRight>, `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, <Factorial>, <SquareRoot>, <BitWiseNot>, <CubeRoot>, <Abs>, `CAST`, `TRY_CAST`, `::`, `POSITION`, or 46 more ...
5656
| |
5757
| while parsing `CAST(... AS ...)`
5858
| while parsing expression
@@ -81,7 +81,7 @@ error:
8181
1 | $ abc + 3
8282
| ^
8383
| |
84-
| unexpected `$`, expecting `IS`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `NOT`, `REGEXP`, `RLIKE`, `SOUNDS`, <BitWiseOr>, <BitWiseAnd>, <BitWiseXor>, <ShiftLeft>, <ShiftRight>, `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, <Factorial>, <SquareRoot>, <BitWiseNot>, <CubeRoot>, <Abs>, `CAST`, `TRY_CAST`, `::`, `POSITION`, `IdentVariable`, `DATEADD`, `DATE_ADD`, `DATE_DIFF`, `DATEDIFF`, or 42 more ...
84+
| unexpected `$`, expecting `IS`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `NOT`, `REGEXP`, `RLIKE`, `SOUNDS`, <BitWiseOr>, <BitWiseAnd>, <BitWiseXor>, <ShiftLeft>, <ShiftRight>, `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, <Factorial>, <SquareRoot>, <BitWiseNot>, <CubeRoot>, <Abs>, `CAST`, `TRY_CAST`, `::`, `POSITION`, `IdentVariable`, `DATEADD`, `DATE_ADD`, or 44 more ...
8585
| while parsing expression
8686

8787

โ€Žsrc/query/ast/tests/it/testdata/stmt-error.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ error:
556556
--> SQL:1:41
557557
|
558558
1 | SELECT * FROM t GROUP BY GROUPING SETS ()
559-
| ------ ^ unexpected `)`, expecting `(`, `IS`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `NOT`, `REGEXP`, `RLIKE`, `SOUNDS`, <BitWiseOr>, <BitWiseAnd>, <BitWiseXor>, <ShiftLeft>, <ShiftRight>, `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, <Factorial>, <SquareRoot>, <BitWiseNot>, <CubeRoot>, <Abs>, `CAST`, `TRY_CAST`, `::`, `POSITION`, `IdentVariable`, `DATEADD`, `DATE_ADD`, `DATE_DIFF`, or 42 more ...
559+
| ------ ^ unexpected `)`, expecting `(`, `IS`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `NOT`, `REGEXP`, `RLIKE`, `SOUNDS`, <BitWiseOr>, <BitWiseAnd>, <BitWiseXor>, <ShiftLeft>, <ShiftRight>, `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, <Factorial>, <SquareRoot>, <BitWiseNot>, <CubeRoot>, <Abs>, `CAST`, `TRY_CAST`, `::`, `POSITION`, `IdentVariable`, `DATEADD`, or 44 more ...
560560
| |
561561
| while parsing `SELECT ...`
562562

@@ -978,7 +978,7 @@ error:
978978
--> SQL:1:65
979979
|
980980
1 | CREATE FUNCTION IF NOT EXISTS isnotempty AS(p) -> not(is_null(p)
981-
| ------ -- ---- ^ unexpected end of input, expecting `)`, `(`, `WITHIN`, `IGNORE`, `RESPECT`, `OVER`, `IS`, `NOT`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `REGEXP`, `RLIKE`, `SOUNDS`, <BitWiseOr>, <BitWiseAnd>, <BitWiseXor>, <ShiftLeft>, <ShiftRight>, `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, <Factorial>, <SquareRoot>, <BitWiseNot>, <CubeRoot>, <Abs>, `CAST`, `TRY_CAST`, `::`, or 48 more ...
981+
| ------ -- ---- ^ unexpected end of input, expecting `)`, `(`, `WITHIN`, `IGNORE`, `RESPECT`, `OVER`, `IS`, `NOT`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `REGEXP`, `RLIKE`, `SOUNDS`, <BitWiseOr>, <BitWiseAnd>, <BitWiseXor>, <ShiftLeft>, <ShiftRight>, `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, <Factorial>, <SquareRoot>, <BitWiseNot>, <CubeRoot>, <Abs>, `CAST`, or 50 more ...
982982
| | | | |
983983
| | | | while parsing `(<expr> [, ...])`
984984
| | | while parsing expression

โ€Žsrc/query/expression/src/types/vector.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,13 @@ impl VectorScalar {
288288
VectorScalar::Float32(vals) => vals.len() * 4,
289289
}
290290
}
291+
292+
pub const fn dimension(&self) -> usize {
293+
match self {
294+
VectorScalar::Int8(vals) => vals.len(),
295+
VectorScalar::Float32(vals) => vals.len(),
296+
}
297+
}
291298
}
292299

293300
impl<'a> VectorScalarRef<'a> {
@@ -312,6 +319,13 @@ impl<'a> VectorScalarRef<'a> {
312319
}
313320
}
314321

322+
pub const fn dimension(&self) -> usize {
323+
match self {
324+
VectorScalarRef::Int8(vals) => vals.len(),
325+
VectorScalarRef::Float32(vals) => vals.len(),
326+
}
327+
}
328+
315329
pub fn data_type(&self) -> VectorDataType {
316330
match self {
317331
VectorScalarRef::Int8(vals) => VectorDataType::Int8(vals.len() as u64),

0 commit comments

Comments
ย (0)