Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/distance/angular.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::borrow::Cow;

use bytemuck::{Pod, Zeroable};
use rand::Rng;

Expand Down Expand Up @@ -67,15 +65,15 @@ impl Distance for Angular {
fn create_split<'a, R: Rng>(
children: &'a ImmutableSubsetLeafs<Self>,
rng: &mut R,
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
) -> heed::Result<Leaf<'static, Self>> {
let [node_p, node_q] = two_means(rng, children, true)?;
let vector: Vec<f32> =
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
let unaligned_vector = UnalignedVector::from_vec(vector);
let mut normal = Leaf { header: NodeHeaderAngular { norm: 0.0 }, vector: unaligned_vector };
Self::normalize(&mut normal);

Ok(normal.vector)
Ok(normal)
}

fn margin_no_header(
Expand Down
6 changes: 2 additions & 4 deletions src/distance/binary_quantized_angular.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::borrow::Cow;

use bytemuck::{Pod, Zeroable};
use rand::Rng;

Expand Down Expand Up @@ -72,7 +70,7 @@ impl Distance for BinaryQuantizedAngular {
fn create_split<'a, R: Rng>(
children: &'a ImmutableSubsetLeafs<Self>,
rng: &mut R,
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
) -> heed::Result<Leaf<'static, Self>> {
let [node_p, node_q] = two_means::<Self, Angular, R>(rng, children, true)?;
let vector: Vec<f32> =
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
Expand All @@ -83,7 +81,7 @@ impl Distance for BinaryQuantizedAngular {
};
Self::normalize(&mut normal);

Ok(normal.vector)
Ok(normal)
}

fn margin_no_header(
Expand Down
28 changes: 22 additions & 6 deletions src/distance/binary_quantized_euclidean.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::borrow::Cow;

use bytemuck::{Pod, Zeroable};
use rand::Rng;

use super::{two_means_binary_quantized as two_means, Euclidean};
use super::{two_means_binary_quantized as two_means, Euclidean, NodeHeaderEuclidean};
use crate::distance::Distance;
use crate::node::Leaf;
use crate::parallel::ImmutableSubsetLeafs;
Expand Down Expand Up @@ -59,17 +57,35 @@ impl Distance for BinaryQuantizedEuclidean {
fn create_split<'a, R: Rng>(
children: &'a ImmutableSubsetLeafs<Self>,
rng: &mut R,
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
) -> heed::Result<Leaf<'static, Self>> {
let [node_p, node_q] = two_means::<Self, Euclidean, R>(rng, children, false)?;
let vector: Vec<f32> =
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
let mut normal = Leaf {
header: NodeHeaderBinaryQuantizedEuclidean { bias: 0.0 },
header: NodeHeaderBinaryQuantizedEuclidean::zeroed(),
vector: UnalignedVector::from_slice(&vector),
};
Self::normalize(&mut normal);

Ok(Cow::Owned(normal.vector.into_owned()))
normal.header.bias = normal
.vector
.iter()
.zip(
UnalignedVector::<BinaryQuantized>::from_slice(
&node_p.vector.iter().collect::<Vec<_>>(),
)
.iter(),
)
.zip(
UnalignedVector::<BinaryQuantized>::from_slice(
&node_q.vector.iter().collect::<Vec<_>>(),
)
.iter(),
)
.map(|((n, p), q)| -n * (p + q) / 2.0)
.sum();

Ok(normal.into_owned())
}

fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
Expand Down
26 changes: 22 additions & 4 deletions src/distance/binary_quantized_manhattan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::borrow::Cow;
use bytemuck::{Pod, Zeroable};
use rand::Rng;

use super::{two_means_binary_quantized as two_means, Manhattan};
use super::{two_means_binary_quantized as two_means, Manhattan, NodeHeaderManhattan};
use crate::distance::Distance;
use crate::node::Leaf;
use crate::parallel::ImmutableSubsetLeafs;
Expand Down Expand Up @@ -63,17 +63,35 @@ impl Distance for BinaryQuantizedManhattan {
fn create_split<'a, R: Rng>(
children: &'a ImmutableSubsetLeafs<Self>,
rng: &mut R,
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
) -> heed::Result<Leaf<'static, Self>> {
let [node_p, node_q] = two_means::<Self, Manhattan, R>(rng, children, false)?;
let vector: Vec<f32> =
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
let mut normal = Leaf {
header: NodeHeaderBinaryQuantizedManhattan { bias: 0.0 },
header: NodeHeaderBinaryQuantizedManhattan::zeroed(),
vector: UnalignedVector::from_slice(&vector),
};
Self::normalize(&mut normal);

Ok(Cow::Owned(normal.vector.into_owned()))
normal.header.bias = normal
.vector
.iter()
.zip(
UnalignedVector::<BinaryQuantized>::from_slice(
&node_p.vector.iter().collect::<Vec<_>>(),
)
.iter(),
)
.zip(
UnalignedVector::<BinaryQuantized>::from_slice(
&node_q.vector.iter().collect::<Vec<_>>(),
)
.iter(),
)
.map(|((n, p), q)| -n * (p + q) / 2.0)
.sum();

Ok(normal.into_owned())
}

fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
Expand Down
4 changes: 2 additions & 2 deletions src/distance/dot_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl Distance for DotProduct {
fn create_split<'a, R: Rng>(
children: &'a ImmutableSubsetLeafs<Self>,
rng: &mut R,
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
) -> heed::Result<Leaf<'static, Self>> {
let [node_p, node_q] = two_means(rng, children, true)?;
let vector: Vec<f32> =
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
Expand All @@ -101,7 +101,7 @@ impl Distance for DotProduct {
normal.header.extra_dim = node_p.header.extra_dim - node_q.header.extra_dim;
Self::normalize(&mut normal);

Ok(normal.vector)
Ok(normal)
}

fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
Expand Down
6 changes: 2 additions & 4 deletions src/distance/euclidean.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::borrow::Cow;

use bytemuck::{Pod, Zeroable};
use rand::Rng;

Expand Down Expand Up @@ -50,7 +48,7 @@ impl Distance for Euclidean {
fn create_split<'a, R: Rng>(
children: &'a ImmutableSubsetLeafs<Self>,
rng: &mut R,
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
) -> heed::Result<Leaf<'static, Self>> {
let [node_p, node_q] = two_means(rng, children, false)?;
let vector: Vec<_> =
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
Expand All @@ -68,7 +66,7 @@ impl Distance for Euclidean {
.map(|((n, p), q)| -n * (p + q) / 2.0)
.sum();

Ok(normal.vector)
Ok(normal)
}

fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
Expand Down
4 changes: 2 additions & 2 deletions src/distance/manhattan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl Distance for Manhattan {
fn create_split<'a, R: Rng>(
children: &'a ImmutableSubsetLeafs<Self>,
rng: &mut R,
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
) -> heed::Result<Leaf<'static, Self>> {
let [node_p, node_q] = two_means(rng, children, false)?;
let vector: Vec<_> =
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
Expand All @@ -71,7 +71,7 @@ impl Distance for Manhattan {
.map(|((n, p), q)| -n * (p + q) / 2.0)
.sum();

Ok(normal.vector)
Ok(normal)
}

fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
Expand Down
12 changes: 4 additions & 8 deletions src/distance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ mod dot_product;
mod euclidean;
mod manhattan;

fn new_leaf<D: Distance>(vec: Vec<f32>) -> Leaf<'static, D> {
pub fn new_leaf<D: Distance>(vec: Vec<f32>) -> Leaf<'static, D> {
let vector = UnalignedVector::from_vec(vec);
Leaf { header: D::new_header(&vector), vector }
}
Expand Down Expand Up @@ -97,7 +97,7 @@ pub trait Distance: Send + Sync + Sized + Clone + fmt::Debug + 'static {
fn create_split<'a, R: Rng>(
children: &'a ImmutableSubsetLeafs<Self>,
rng: &mut R,
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>>;
) -> heed::Result<Leaf<'static, Self>>;

fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
Self::margin_no_header(&p.vector, &q.vector)
Expand All @@ -108,12 +108,8 @@ pub trait Distance: Send + Sync + Sized + Clone + fmt::Debug + 'static {
q: &UnalignedVector<Self::VectorCodec>,
) -> f32;

fn side<R: Rng>(
normal_plane: &UnalignedVector<Self::VectorCodec>,
node: &Leaf<Self>,
rng: &mut R,
) -> Side {
let dot = Self::margin_no_header(&node.vector, normal_plane);
fn side<R: Rng>(normal_plane: &Leaf<Self>, node: &Leaf<Self>, rng: &mut R) -> Side {
let dot = Self::margin(normal_plane, node);
if dot > 0.0 {
Side::Right
} else if dot < 0.0 {
Expand Down
10 changes: 7 additions & 3 deletions src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl fmt::Debug for ItemIds<'_> {
pub struct SplitPlaneNormal<'a, D: Distance> {
pub left: NodeId,
pub right: NodeId,
pub normal: Cow<'a, UnalignedVector<D::VectorCodec>>,
pub normal: Leaf<'a, D>,
}

impl<D: Distance> fmt::Debug for SplitPlaneNormal<'_, D> {
Expand Down Expand Up @@ -153,7 +153,8 @@ impl<'a, D: Distance> BytesEncode<'a> for NodeCodec<D> {
bytes.push(SPLIT_PLANE_NORMAL_TAG);
bytes.extend_from_slice(&left.to_bytes());
bytes.extend_from_slice(&right.to_bytes());
bytes.extend_from_slice(normal.as_bytes());
bytes.extend_from_slice(bytes_of(&normal.header));
bytes.extend_from_slice(normal.vector.as_bytes());
}
Node::Descendants(Descendants { descendants }) => {
bytes.push(DESCENDANTS_TAG);
Expand All @@ -179,8 +180,11 @@ impl<'a, D: Distance> BytesDecode<'a> for NodeCodec<D> {
[SPLIT_PLANE_NORMAL_TAG, bytes @ ..] => {
let (left, bytes) = NodeId::from_bytes(bytes);
let (right, bytes) = NodeId::from_bytes(bytes);
let (header_bytes, remaining) = bytes.split_at(size_of::<D::Header>());
let header = pod_read_unaligned(header_bytes);
let vector = UnalignedVector::<D::VectorCodec>::from_bytes(remaining)?;
Ok(Node::SplitPlaneNormal(SplitPlaneNormal {
normal: UnalignedVector::<D::VectorCodec>::from_bytes(bytes)?,
normal: Leaf { header, vector },
left,
right,
}))
Expand Down
4 changes: 2 additions & 2 deletions src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl<'t, D: Distance> Reader<'t, D> {
Node::SplitPlaneNormal(SplitPlaneNormal { normal, left, right }) => {
let left = recursive_depth(rtxn, database, index, left)?;
let right = recursive_depth(rtxn, database, index, right)?;
let is_zero_normal = normal.is_zero() as usize;
let is_zero_normal = normal.vector.is_zero() as usize;

Ok(TreeStats {
depth: 1 + left.depth.max(right.depth),
Expand Down Expand Up @@ -258,7 +258,7 @@ impl<'t, D: Distance> Reader<'t, D> {
}
}
Node::SplitPlaneNormal(SplitPlaneNormal { normal, left, right }) => {
let margin = D::margin_no_header(&normal, &query_leaf.vector);
let margin = D::margin(&normal, &query_leaf);
queue.push((OrderedFloat(D::pq_distance(dist, margin, Side::Left)), left));
queue.push((OrderedFloat(D::pq_distance(dist, margin, Side::Right)), right));
}
Expand Down
6 changes: 3 additions & 3 deletions src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rayon::iter::repeatn;
use rayon::prelude::*;
use roaring::RoaringBitmap;

use crate::distance::Distance;
use crate::distance::{new_leaf, Distance};
use crate::internals::{KeyCodec, Side};
use crate::item_iter::ItemIter;
use crate::node::{Descendants, ItemIds, Leaf, SplitPlaneNormal};
Expand Down Expand Up @@ -563,7 +563,7 @@ impl<D: Distance> Writer<D> {
let mut left_ids = RoaringBitmap::new();
let mut right_ids = RoaringBitmap::new();

if normal.is_zero() {
if normal.vector.is_zero() {
randomly_split_children(rng, to_insert, &mut left_ids, &mut right_ids);
} else {
for leaf in to_insert {
Expand Down Expand Up @@ -730,7 +730,7 @@ impl<D: Distance> Writer<D> {
let mut children_left = RoaringBitmap::new();
let mut children_right = RoaringBitmap::new();
randomly_split_children(rng, item_indices, &mut children_left, &mut children_right);
UnalignedVector::reset(&mut normal);
UnalignedVector::reset(&mut normal.vector);

(children_left, children_right)
} else {
Expand Down