Skip to content

Commit 1b96684

Browse files
committed
Implement casting for Vector to HalfVec
1 parent bab73fc commit 1b96684

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

src/diesel_ext/expression_methods.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ use diesel::expression::{AsExpression, Expression};
22
use diesel::pg::Pg;
33
use diesel::sql_types::{Double, SqlType};
44

5+
#[cfg(feature = "halfvec")]
6+
use crate::diesel_ext::halfvec::HalfVecCast;
7+
58
diesel::infix_operator!(L2Distance, " <-> ", Double, backend: Pg);
69
diesel::infix_operator!(MaxInnerProduct, " <#> ", Double, backend: Pg);
710
diesel::infix_operator!(CosineDistance, " <=> ", Double, backend: Pg);
@@ -10,6 +13,11 @@ diesel::infix_operator!(HammingDistance, " <~> ", Double, backend: Pg);
1013
diesel::infix_operator!(JaccardDistance, " <%> ", Double, backend: Pg);
1114

1215
pub trait VectorExpressionMethods: Expression + Sized {
16+
#[cfg(feature = "halfvec")]
17+
fn cast_to_halfvec(self, dim: usize) -> HalfVecCast<Self> {
18+
HalfVecCast::new(self, dim)
19+
}
20+
1321
fn l2_distance<T>(self, other: T) -> L2Distance<Self, T::Expression>
1422
where
1523
Self::SqlType: SqlType,

src/diesel_ext/halfvec.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use diesel::deserialize::{self, FromSql};
2+
use diesel::expression::{AppearsOnTable, Expression, SelectableExpression, ValidGrouping};
23
use diesel::pg::{Pg, PgValue};
3-
use diesel::query_builder::QueryId;
4+
use diesel::query_builder::{AstPass, QueryFragment, QueryId};
45
use diesel::serialize::{self, IsNull, Output, ToSql};
56
use diesel::sql_types::SqlType;
7+
use diesel::QueryResult;
68
use std::convert::TryFrom;
79
use std::io::Write;
810

@@ -32,6 +34,42 @@ impl FromSql<HalfVectorType, Pg> for HalfVector {
3234
}
3335
}
3436

37+
#[derive(Debug, Clone, Copy, QueryId, ValidGrouping)]
38+
pub struct HalfVecCast<Expr> {
39+
pub expr: Expr,
40+
pub dim: usize,
41+
}
42+
43+
impl<Expr> HalfVecCast<Expr> {
44+
pub fn new(expr: Expr, dim: usize) -> Self {
45+
Self { expr, dim }
46+
}
47+
}
48+
49+
impl<Expr> Expression for HalfVecCast<Expr>
50+
where
51+
Expr: Expression,
52+
{
53+
type SqlType = HalfVectorType;
54+
}
55+
56+
impl<Expr> QueryFragment<Pg> for HalfVecCast<Expr>
57+
where
58+
Expr: QueryFragment<Pg>,
59+
{
60+
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
61+
out.push_sql("(");
62+
self.expr.walk_ast(out.reborrow())?;
63+
out.push_sql(")::halfvec(");
64+
out.push_sql(&self.dim.to_string());
65+
out.push_sql(")");
66+
Ok(())
67+
}
68+
}
69+
70+
impl<Expr, QS> AppearsOnTable<QS> for HalfVecCast<Expr> where Expr: AppearsOnTable<QS> {}
71+
impl<Expr, QS> SelectableExpression<QS> for HalfVecCast<Expr> where Expr: SelectableExpression<QS> {}
72+
3573
#[cfg(test)]
3674
mod tests {
3775
use crate::{HalfVector, VectorExpressionMethods};

0 commit comments

Comments
 (0)