Skip to content

Commit 693f77d

Browse files
authored
Merge pull request #132 from nnethercott/add-header-to-normal
Add bias info to SplitPlaneNormal
2 parents c668bb7 + 939ae4a commit 693f77d

18 files changed

+921
-947
lines changed

src/distance/binary_quantized_cosine.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::borrow::Cow;
1+
use std::fmt;
22

33
use bytemuck::{Pod, Zeroable};
44
use rand::Rng;
@@ -20,10 +20,17 @@ pub enum BinaryQuantizedCosine {}
2020

2121
/// The header of `BinaryQuantizedCosine` leaf nodes.
2222
#[repr(C)]
23-
#[derive(Pod, Zeroable, Debug, Clone, Copy)]
23+
#[derive(Pod, Zeroable, Clone, Copy)]
2424
pub struct NodeHeaderBinaryQuantizedCosine {
2525
norm: f32,
2626
}
27+
impl fmt::Debug for NodeHeaderBinaryQuantizedCosine {
28+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29+
f.debug_struct("NodeHeaderBinaryQuantizedCosine")
30+
.field("norm", &format!("{:.4}", self.norm))
31+
.finish()
32+
}
33+
}
2734

2835
impl Distance for BinaryQuantizedCosine {
2936
const DEFAULT_OVERSAMPLING: usize = 3;
@@ -72,7 +79,7 @@ impl Distance for BinaryQuantizedCosine {
7279
fn create_split<'a, R: Rng>(
7380
children: &'a ImmutableSubsetLeafs<Self>,
7481
rng: &mut R,
75-
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
82+
) -> heed::Result<Leaf<'a, Self>> {
7683
let [node_p, node_q] = two_means::<Self, Cosine, R>(rng, children, true)?;
7784
let vector: Vec<f32> =
7885
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
@@ -83,13 +90,10 @@ impl Distance for BinaryQuantizedCosine {
8390
};
8491
Self::normalize(&mut normal);
8592

86-
Ok(normal.vector)
93+
Ok(normal)
8794
}
8895

89-
fn margin_no_header(
90-
p: &UnalignedVector<Self::VectorCodec>,
91-
q: &UnalignedVector<Self::VectorCodec>,
92-
) -> f32 {
93-
dot_product_binary_quantized(p, q)
96+
fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
97+
dot_product_binary_quantized(&p.vector, &q.vector)
9498
}
9599
}

src/distance/binary_quantized_euclidean.rs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::borrow::Cow;
1+
use std::fmt;
22

33
use bytemuck::{Pod, Zeroable};
44
use rand::Rng;
@@ -21,11 +21,18 @@ pub enum BinaryQuantizedEuclidean {}
2121

2222
/// The header of `BinaryQuantizedEuclidean` leaf nodes.
2323
#[repr(C)]
24-
#[derive(Pod, Zeroable, Debug, Clone, Copy)]
24+
#[derive(Pod, Zeroable, Clone, Copy)]
2525
pub struct NodeHeaderBinaryQuantizedEuclidean {
2626
/// An extra constant term to determine the offset of the plane
2727
bias: f32,
2828
}
29+
impl fmt::Debug for NodeHeaderBinaryQuantizedEuclidean {
30+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31+
f.debug_struct("NodeHeaderBinaryQuantizedEuclidean")
32+
.field("bias", &format!("{:.4}", self.bias))
33+
.finish()
34+
}
35+
}
2936

3037
impl Distance for BinaryQuantizedEuclidean {
3138
const DEFAULT_OVERSAMPLING: usize = 3;
@@ -59,29 +66,30 @@ impl Distance for BinaryQuantizedEuclidean {
5966
fn create_split<'a, R: Rng>(
6067
children: &'a ImmutableSubsetLeafs<Self>,
6168
rng: &mut R,
62-
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
69+
) -> heed::Result<Leaf<'a, Self>> {
6370
let [node_p, node_q] = two_means::<Self, Euclidean, R>(rng, children, false)?;
6471
let vector: Vec<f32> =
6572
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
6673
let mut normal = Leaf {
6774
header: NodeHeaderBinaryQuantizedEuclidean { bias: 0.0 },
68-
vector: UnalignedVector::from_slice(&vector),
75+
vector: UnalignedVector::from_vec(vector),
6976
};
7077
Self::normalize(&mut normal);
7178

72-
Ok(Cow::Owned(normal.vector.into_owned()))
79+
normal.header.bias = normal
80+
.vector
81+
.iter()
82+
.zip(UnalignedVector::<BinaryQuantized>::from_vec(node_p.vector.to_vec()).iter())
83+
.zip(UnalignedVector::<BinaryQuantized>::from_vec(node_q.vector.to_vec()).iter())
84+
.map(|((n, p), q)| -n * (p + q) / 2.0)
85+
.sum();
86+
87+
Ok(normal)
7388
}
7489

7590
fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
7691
p.header.bias + dot_product_binary_quantized(&p.vector, &q.vector)
7792
}
78-
79-
fn margin_no_header(
80-
p: &UnalignedVector<Self::VectorCodec>,
81-
q: &UnalignedVector<Self::VectorCodec>,
82-
) -> f32 {
83-
dot_product_binary_quantized(p, q)
84-
}
8593
}
8694

8795
/// For the binary quantized squared euclidean distance:

src/distance/binary_quantized_manhattan.rs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::borrow::Cow;
1+
use std::fmt;
22

33
use bytemuck::{Pod, Zeroable};
44
use rand::Rng;
@@ -20,11 +20,18 @@ pub enum BinaryQuantizedManhattan {}
2020

2121
/// The header of BinaryQuantizedManhattan leaf nodes.
2222
#[repr(C)]
23-
#[derive(Pod, Zeroable, Debug, Clone, Copy)]
23+
#[derive(Pod, Zeroable, Clone, Copy)]
2424
pub struct NodeHeaderBinaryQuantizedManhattan {
2525
/// An extra constant term to determine the offset of the plane
2626
bias: f32,
2727
}
28+
impl fmt::Debug for NodeHeaderBinaryQuantizedManhattan {
29+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30+
f.debug_struct("NodeHeaderBinaryQuantizedManhattan")
31+
.field("bias", &format!("{:.4}", self.bias))
32+
.finish()
33+
}
34+
}
2835

2936
impl Distance for BinaryQuantizedManhattan {
3037
const DEFAULT_OVERSAMPLING: usize = 3;
@@ -63,29 +70,30 @@ impl Distance for BinaryQuantizedManhattan {
6370
fn create_split<'a, R: Rng>(
6471
children: &'a ImmutableSubsetLeafs<Self>,
6572
rng: &mut R,
66-
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
73+
) -> heed::Result<Leaf<'a, Self>> {
6774
let [node_p, node_q] = two_means::<Self, Manhattan, R>(rng, children, false)?;
6875
let vector: Vec<f32> =
6976
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
7077
let mut normal = Leaf {
7178
header: NodeHeaderBinaryQuantizedManhattan { bias: 0.0 },
72-
vector: UnalignedVector::from_slice(&vector),
79+
vector: UnalignedVector::from_vec(vector),
7380
};
7481
Self::normalize(&mut normal);
7582

76-
Ok(Cow::Owned(normal.vector.into_owned()))
83+
normal.header.bias = normal
84+
.vector
85+
.iter()
86+
.zip(UnalignedVector::<BinaryQuantized>::from_vec(node_p.vector.to_vec()).iter())
87+
.zip(UnalignedVector::<BinaryQuantized>::from_vec(node_q.vector.to_vec()).iter())
88+
.map(|((n, p), q)| -n * (p + q) / 2.0)
89+
.sum();
90+
91+
Ok(normal)
7792
}
7893

7994
fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
8095
p.header.bias + dot_product_binary_quantized(&p.vector, &q.vector)
8196
}
82-
83-
fn margin_no_header(
84-
p: &UnalignedVector<Self::VectorCodec>,
85-
q: &UnalignedVector<Self::VectorCodec>,
86-
) -> f32 {
87-
dot_product_binary_quantized(p, q)
88-
}
8997
}
9098

9199
/// For the binary quantized manhattan distance:

src/distance/cosine.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::borrow::Cow;
1+
use std::fmt;
22

33
use bytemuck::{Pod, Zeroable};
44
use rand::Rng;
@@ -18,10 +18,15 @@ pub enum Cosine {}
1818

1919
/// The header of Cosine leaf nodes.
2020
#[repr(C)]
21-
#[derive(Pod, Zeroable, Debug, Clone, Copy)]
21+
#[derive(Pod, Zeroable, Clone, Copy)]
2222
pub struct NodeHeaderCosine {
2323
norm: f32,
2424
}
25+
impl fmt::Debug for NodeHeaderCosine {
26+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27+
f.debug_struct("NodeHeaderCosine").field("norm", &format!("{:.4}", self.norm)).finish()
28+
}
29+
}
2530

2631
impl Distance for Cosine {
2732
type Header = NodeHeaderCosine;
@@ -68,21 +73,18 @@ impl Distance for Cosine {
6873
fn create_split<'a, R: Rng>(
6974
children: &'a ImmutableSubsetLeafs<Self>,
7075
rng: &mut R,
71-
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
76+
) -> heed::Result<Leaf<'a, Self>> {
7277
let [node_p, node_q] = two_means(rng, children, true)?;
7378
let vector: Vec<f32> =
7479
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
7580
let unaligned_vector = UnalignedVector::from_vec(vector);
7681
let mut normal = Leaf { header: NodeHeaderCosine { norm: 0.0 }, vector: unaligned_vector };
7782
Self::normalize(&mut normal);
7883

79-
Ok(normal.vector)
84+
Ok(normal)
8085
}
8186

82-
fn margin_no_header(
83-
p: &UnalignedVector<Self::VectorCodec>,
84-
q: &UnalignedVector<Self::VectorCodec>,
85-
) -> f32 {
86-
dot_product(p, q)
87+
fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
88+
dot_product(&p.vector, &q.vector)
8789
}
8890
}

src/distance/dot_product.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::borrow::Cow;
1+
use std::fmt;
22

33
use bytemuck::{Pod, Zeroable};
44
use heed::{RwPrefix, RwTxn};
@@ -21,12 +21,20 @@ pub enum DotProduct {}
2121

2222
/// The header of DotProduct leaf nodes.
2323
#[repr(C)]
24-
#[derive(Pod, Zeroable, Debug, Clone, Copy)]
24+
#[derive(Pod, Zeroable, Clone, Copy)]
2525
pub struct NodeHeaderDotProduct {
2626
extra_dim: f32,
2727
/// An extra constant term to determine the offset of the plane
2828
norm: f32,
2929
}
30+
impl fmt::Debug for NodeHeaderDotProduct {
31+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32+
f.debug_struct("NodeHeaderDotProduct")
33+
.field("extra_dim", &format!("{:.4}", self.extra_dim))
34+
.field("norm", &format!("{:.4}", self.norm))
35+
.finish()
36+
}
37+
}
3038

3139
impl Distance for DotProduct {
3240
type Header = NodeHeaderDotProduct;
@@ -90,7 +98,7 @@ impl Distance for DotProduct {
9098
fn create_split<'a, R: Rng>(
9199
children: &'a ImmutableSubsetLeafs<Self>,
92100
rng: &mut R,
93-
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
101+
) -> heed::Result<Leaf<'a, Self>> {
94102
let [node_p, node_q] = two_means(rng, children, true)?;
95103
let vector: Vec<f32> =
96104
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
@@ -101,20 +109,13 @@ impl Distance for DotProduct {
101109
normal.header.extra_dim = node_p.header.extra_dim - node_q.header.extra_dim;
102110
Self::normalize(&mut normal);
103111

104-
Ok(normal.vector)
112+
Ok(normal)
105113
}
106114

107115
fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
108116
dot_product(&p.vector, &q.vector) + p.header.extra_dim * q.header.extra_dim
109117
}
110118

111-
fn margin_no_header(
112-
p: &UnalignedVector<Self::VectorCodec>,
113-
q: &UnalignedVector<Self::VectorCodec>,
114-
) -> f32 {
115-
dot_product(p, q)
116-
}
117-
118119
fn preprocess(
119120
wtxn: &mut RwTxn,
120121
new_iter: impl for<'a> Fn(

src/distance/euclidean.rs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::borrow::Cow;
1+
use std::fmt;
22

33
use bytemuck::{Pod, Zeroable};
44
use rand::Rng;
@@ -19,11 +19,16 @@ pub enum Euclidean {}
1919

2020
/// The header of Euclidean leaf nodes.
2121
#[repr(C)]
22-
#[derive(Pod, Zeroable, Debug, Clone, Copy)]
22+
#[derive(Pod, Zeroable, Clone, Copy)]
2323
pub struct NodeHeaderEuclidean {
2424
/// An extra constant term to determine the offset of the plane
2525
bias: f32,
2626
}
27+
impl fmt::Debug for NodeHeaderEuclidean {
28+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29+
f.debug_struct("NodeHeaderEuclidean").field("bias", &format!("{:.4}", self.bias)).finish()
30+
}
31+
}
2732

2833
impl Distance for Euclidean {
2934
type Header = NodeHeaderEuclidean;
@@ -50,7 +55,7 @@ impl Distance for Euclidean {
5055
fn create_split<'a, R: Rng>(
5156
children: &'a ImmutableSubsetLeafs<Self>,
5257
rng: &mut R,
53-
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
58+
) -> heed::Result<Leaf<'a, Self>> {
5459
let [node_p, node_q] = two_means(rng, children, false)?;
5560
let vector: Vec<_> =
5661
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
@@ -68,17 +73,10 @@ impl Distance for Euclidean {
6873
.map(|((n, p), q)| -n * (p + q) / 2.0)
6974
.sum();
7075

71-
Ok(normal.vector)
72-
}
73-
74-
fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
75-
p.header.bias + dot_product(&p.vector, &q.vector)
76+
Ok(normal)
7677
}
7778

78-
fn margin_no_header(
79-
p: &UnalignedVector<Self::VectorCodec>,
80-
q: &UnalignedVector<Self::VectorCodec>,
81-
) -> f32 {
82-
dot_product(p, q)
79+
fn margin(n: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
80+
n.header.bias + dot_product(&n.vector, &q.vector)
8381
}
8482
}

src/distance/manhattan.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::borrow::Cow;
1+
use std::fmt;
22

33
use bytemuck::{Pod, Zeroable};
44
use rand::Rng;
@@ -18,11 +18,16 @@ pub enum Manhattan {}
1818

1919
/// The header of Manhattan leaf nodes.
2020
#[repr(C)]
21-
#[derive(Pod, Zeroable, Debug, Clone, Copy)]
21+
#[derive(Pod, Zeroable, Clone, Copy)]
2222
pub struct NodeHeaderManhattan {
2323
/// An extra constant term to determine the offset of the plane
2424
bias: f32,
2525
}
26+
impl fmt::Debug for NodeHeaderManhattan {
27+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28+
f.debug_struct("NodeHeaderManhattan").field("bias", &format!("{:.4}", self.bias)).finish()
29+
}
30+
}
2631

2732
impl Distance for Manhattan {
2833
type Header = NodeHeaderManhattan;
@@ -53,7 +58,7 @@ impl Distance for Manhattan {
5358
fn create_split<'a, R: Rng>(
5459
children: &'a ImmutableSubsetLeafs<Self>,
5560
rng: &mut R,
56-
) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>> {
61+
) -> heed::Result<Leaf<'a, Self>> {
5762
let [node_p, node_q] = two_means(rng, children, false)?;
5863
let vector: Vec<_> =
5964
node_p.vector.iter().zip(node_q.vector.iter()).map(|(p, q)| p - q).collect();
@@ -71,17 +76,10 @@ impl Distance for Manhattan {
7176
.map(|((n, p), q)| -n * (p + q) / 2.0)
7277
.sum();
7378

74-
Ok(normal.vector)
79+
Ok(normal)
7580
}
7681

7782
fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
7883
p.header.bias + dot_product(&p.vector, &q.vector)
7984
}
80-
81-
fn margin_no_header(
82-
p: &UnalignedVector<Self::VectorCodec>,
83-
q: &UnalignedVector<Self::VectorCodec>,
84-
) -> f32 {
85-
dot_product(p, q)
86-
}
8785
}

0 commit comments

Comments
 (0)