|
1 | 1 | use diesel::deserialize::{self, FromSql}; |
| 2 | +use diesel::expression::{AppearsOnTable, Expression, SelectableExpression, ValidGrouping}; |
2 | 3 | use diesel::pg::{Pg, PgValue}; |
3 | | -use diesel::query_builder::QueryId; |
| 4 | +use diesel::query_builder::{AstPass, QueryFragment, QueryId}; |
4 | 5 | use diesel::serialize::{self, IsNull, Output, ToSql}; |
5 | 6 | use diesel::sql_types::SqlType; |
| 7 | +use diesel::QueryResult; |
6 | 8 | use std::convert::TryFrom; |
7 | 9 | use std::io::Write; |
8 | 10 |
|
@@ -32,6 +34,42 @@ impl FromSql<HalfVectorType, Pg> for HalfVector { |
32 | 34 | } |
33 | 35 | } |
34 | 36 |
|
| 37 | +#[derive(Debug, Clone, Copy, QueryId, ValidGrouping)] |
| 38 | +pub struct HalfVecCast<Expr> { |
| 39 | + pub expr: Expr, |
| 40 | + pub dim: usize, |
| 41 | +} |
| 42 | + |
| 43 | +impl<Expr> HalfVecCast<Expr> { |
| 44 | + pub fn new(expr: Expr, dim: usize) -> Self { |
| 45 | + Self { expr, dim } |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +impl<Expr> Expression for HalfVecCast<Expr> |
| 50 | +where |
| 51 | + Expr: Expression, |
| 52 | +{ |
| 53 | + type SqlType = HalfVectorType; |
| 54 | +} |
| 55 | + |
| 56 | +impl<Expr> QueryFragment<Pg> for HalfVecCast<Expr> |
| 57 | +where |
| 58 | + Expr: QueryFragment<Pg>, |
| 59 | +{ |
| 60 | + fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> { |
| 61 | + out.push_sql("("); |
| 62 | + self.expr.walk_ast(out.reborrow())?; |
| 63 | + out.push_sql(")::halfvec("); |
| 64 | + out.push_sql(&self.dim.to_string()); |
| 65 | + out.push_sql(")"); |
| 66 | + Ok(()) |
| 67 | + } |
| 68 | +} |
| 69 | + |
| 70 | +impl<Expr, QS> AppearsOnTable<QS> for HalfVecCast<Expr> where Expr: AppearsOnTable<QS> {} |
| 71 | +impl<Expr, QS> SelectableExpression<QS> for HalfVecCast<Expr> where Expr: SelectableExpression<QS> {} |
| 72 | + |
35 | 73 | #[cfg(test)] |
36 | 74 | mod tests { |
37 | 75 | use crate::{HalfVector, VectorExpressionMethods}; |
|
0 commit comments