Skip to content

Commit 2c06460

Browse files
authored
Normalize to speed up distance calc (#1)
1 parent 3586dc9 commit 2c06460

File tree

14 files changed

+331
-19
lines changed

14 files changed

+331
-19
lines changed

bruteforce.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package search
55

66
import (
7+
"math"
78
"sort"
89

910
"github.com/kelindar/search/internal/cosine/simd"
@@ -36,6 +37,8 @@ func NewIndex[T any]() *Index[T] {
3637

3738
// Add adds a new vector to the search index.
3839
func (b *Index[T]) Add(vx Vector, item T) {
40+
normalize(vx)
41+
3942
b.arr = append(b.arr, entry[T]{
4043
Vector: vx,
4144
Value: item,
@@ -48,10 +51,13 @@ func (b *Index[T]) Search(query Vector, k int) []Result[T] {
4851
return nil
4952
}
5053

54+
// Normalize and quantize the query vector
55+
normalize(query)
56+
5157
var relevance float64
5258
dst := make(minheap[T], 0, k)
5359
for _, v := range b.arr {
54-
simd.Cosine(&relevance, v.Vector, query)
60+
simd.DotProduct(&relevance, query, v.Vector)
5561
result := Result[T]{
5662
entry: v,
5763
Relevance: relevance,
@@ -73,6 +79,21 @@ func (b *Index[T]) Search(query Vector, k int) []Result[T] {
7379
return dst
7480
}
7581

82+
// Normalize normalizes the vector, resulting in a unit vector. This allows us
83+
// to do a simple dot product to calculate the cosine similarity instead of
84+
// the full cosine distance.
85+
func normalize(v []float32) {
86+
norm := float32(0)
87+
for _, x := range v {
88+
norm += x * x
89+
}
90+
91+
norm = float32(math.Sqrt(float64(norm)))
92+
for i := range v {
93+
v[i] /= norm
94+
}
95+
}
96+
7697
// --------------------------------- Heap ---------------------------------
7798

7899
// minheap is a min-heap of top values, ordered by relevance.

bruteforce_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313

1414
/*
1515
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
16-
BenchmarkIndex/search-24 4029 298055 ns/op 272 B/op 3 allocs/op
16+
BenchmarkIndex/search-24 5366 217116 ns/op 272 B/op 3 allocs/op
1717
*/
1818
func BenchmarkIndex(b *testing.B) {
1919
data, err := loadDataset()

internal/cosine/cosine_apple.c

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,15 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u
2626

2727
double cosine_similarity = (double)sum_xy / (double)denominator;
2828
*result = cosine_similarity;
29-
}
29+
}
30+
31+
void f32_dot_product(const float *x, const float *y, double *result, const uint64_t size) {
32+
float sum = 0.0f;
33+
34+
#pragma clang loop vectorize(enable) interleave(enable)
35+
for (uint64_t i = 0; i < size; i++) {
36+
sum += x[i] * y[i];
37+
}
38+
39+
*result = (double)sum;
40+
}

internal/cosine/cosine_avx.c

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u
1313
#pragma clang loop vectorize(enable) interleave_count(2)
1414
for (uint64_t i = 0; i < size; i++) {
1515
sum_xy += x[i] * y[i]; // Sum of x * y
16-
sum_xx += x[i] * x[i]; // Sum of x * x
17-
sum_yy += y[i] * y[i]; // Sum of y * y
16+
sum_xx += x[i] * x[i]; // Sum of x * x
17+
sum_yy += y[i] * y[i]; // Sum of y * y
1818
}
1919

2020
// Calculate the final result
@@ -26,4 +26,15 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u
2626

2727
double cosine_similarity = (double)sum_xy / (double)denominator;
2828
*result = cosine_similarity;
29-
}
29+
}
30+
31+
void f32_dot_product(const float *x, const float *y, double *result, const uint64_t size) {
32+
float sum = 0.0f;
33+
34+
#pragma clang loop vectorize(enable) interleave(enable)
35+
for (uint64_t i = 0; i < size; i++) {
36+
sum += x[i] * y[i];
37+
}
38+
39+
*result = (double)sum;
40+
}

internal/cosine/cosine_neon.c

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,15 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u
2626

2727
double cosine_similarity = (double)sum_xy / (double)denominator;
2828
*result = cosine_similarity;
29-
}
29+
}
30+
31+
void f32_dot_product(const float *x, const float *y, double *result, const uint64_t size) {
32+
float sum = 0.0f;
33+
34+
#pragma clang loop vectorize(enable) interleave(enable)
35+
for (uint64_t i = 0; i < size; i++) {
36+
sum += x[i] * y[i];
37+
}
38+
39+
*result = (double)sum;
40+
}

internal/cosine/simd/cosine_apple.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ import "unsafe"
77

88
//go:noescape,nosplit
99
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)
10+
11+
//go:noescape,nosplit
12+
func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)

internal/cosine/simd/cosine_apple.s

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,65 @@ BB0_11:
114114
WORD $0xfd000040 // str d0, [x2]
115115
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
116116
WORD $0xd65f03c0 // ret
117+
118+
TEXT ·f32_dot_product(SB), $0-32
119+
MOVD x+0(FP), R0
120+
MOVD y+8(FP), R1
121+
MOVD result+16(FP), R2
122+
MOVD size+24(FP), R3
123+
WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! ; 16-byte Folded Spill
124+
WORD $0x910003fd // mov x29, sp
125+
WORD $0xb40000c3 // cbz x3, LBB1_3
126+
WORD $0xf100207f // cmp x3, #8
127+
WORD $0x54000102 // b.hs LBB1_4
128+
WORD $0xd2800008 // mov x8, #0
129+
WORD $0x2f00e400 // movi d0, #0000000000000000
130+
WORD $0x14000018 // b LBB1_7
131+
132+
BB1_3:
133+
WORD $0x2f00e400 // movi d0, #0000000000000000
134+
WORD $0xfd000040 // str d0, [x2]
135+
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
136+
WORD $0xd65f03c0 // ret
137+
138+
BB1_4:
139+
WORD $0x927df068 // and x8, x3, #0xfffffffffffffff8
140+
WORD $0x91004009 // add x9, x0, #16
141+
WORD $0x9100402a // add x10, x1, #16
142+
WORD $0x6f00e400 // movi.2d v0, #0000000000000000
143+
WORD $0xaa0803eb // mov x11, x8
144+
WORD $0x6f00e401 // movi.2d v1, #0000000000000000
145+
146+
BB1_5:
147+
WORD $0xad7f8d22 // ldp q2, q3, [x9, #-16]
148+
WORD $0xad7f9544 // ldp q4, q5, [x10, #-16]
149+
WORD $0x4e22cc80 // fmla.4s v0, v4, v2
150+
WORD $0x4e23cca1 // fmla.4s v1, v5, v3
151+
WORD $0x91008129 // add x9, x9, #32
152+
WORD $0x9100814a // add x10, x10, #32
153+
WORD $0xf100216b // subs x11, x11, #8
154+
WORD $0x54ffff21 // b.ne LBB1_5
155+
WORD $0x4e20d420 // fadd.4s v0, v1, v0
156+
WORD $0x6e20d400 // faddp.4s v0, v0, v0
157+
WORD $0x7e30d800 // faddp.2s s0, v0
158+
WORD $0xeb03011f // cmp x8, x3
159+
WORD $0x54000140 // b.eq LBB1_9
160+
161+
BB1_7:
162+
WORD $0xcb080069 // sub x9, x3, x8
163+
WORD $0xd37ef50a // lsl x10, x8, #2
164+
WORD $0x8b0a0028 // add x8, x1, x10
165+
WORD $0x8b0a000a // add x10, x0, x10
166+
167+
BB1_8:
168+
WORD $0xbc404541 // ldr s1, [x10], #4
169+
WORD $0xbc404502 // ldr s2, [x8], #4
170+
WORD $0x1f010040 // fmadd s0, s2, s1, s0
171+
WORD $0xf1000529 // subs x9, x9, #1
172+
WORD $0x54ffff81 // b.ne LBB1_8
173+
174+
BB1_9:
175+
WORD $0x1e22c000 // fcvt d0, s0
176+
WORD $0xfd000040 // str d0, [x2]
177+
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
178+
WORD $0xd65f03c0 // ret

internal/cosine/simd/cosine_avx.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ import "unsafe"
77

88
//go:noescape,nosplit
99
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)
10+
11+
//go:noescape,nosplit
12+
func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)

internal/cosine/simd/cosine_avx.s

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,74 @@ LBB0_9:
104104
BYTE $0x5d // pop rbp
105105
WORD $0xf8c5; BYTE $0x77 // vzeroupper
106106
BYTE $0xc3 // ret
107+
108+
TEXT ·f32_dot_product(SB), $0-32
109+
MOVQ x+0(FP), DI
110+
MOVQ y+8(FP), SI
111+
MOVQ result+16(FP), DX
112+
MOVQ size+24(FP), CX
113+
BYTE $0x55 // push rbp
114+
WORD $0x8948; BYTE $0xe5 // mov rbp, rsp
115+
LONG $0xf8e48348 // and rsp, -8
116+
WORD $0x8548; BYTE $0xc9 // test rcx, rcx
117+
JE LBB1_1
118+
LONG $0x20f98348 // cmp rcx, 32
119+
JAE LBB1_5
120+
LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0
121+
WORD $0x3145; BYTE $0xc0 // xor r8d, r8d
122+
JMP LBB1_4
123+
124+
LBB1_1:
125+
LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0
126+
LONG $0x0211fbc5 // vmovsd qword ptr [rdx], xmm0
127+
WORD $0x8948; BYTE $0xec // mov rsp, rbp
128+
BYTE $0x5d // pop rbp
129+
BYTE $0xc3 // ret
130+
131+
LBB1_5:
132+
WORD $0x8949; BYTE $0xc8 // mov r8, rcx
133+
LONG $0xe0e08349 // and r8, -32
134+
LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0
135+
WORD $0xc031 // xor eax, eax
136+
LONG $0xc957f0c5 // vxorps xmm1, xmm1, xmm1
137+
LONG $0xd257e8c5 // vxorps xmm2, xmm2, xmm2
138+
LONG $0xdb57e0c5 // vxorps xmm3, xmm3, xmm3
139+
140+
LBB1_6:
141+
LONG $0x2410fcc5; BYTE $0x86 // vmovups ymm4, ymmword ptr [rsi + 4*rax]
142+
LONG $0x6c10fcc5; WORD $0x2086 // vmovups ymm5, ymmword ptr [rsi + 4*rax + 32]
143+
LONG $0x7410fcc5; WORD $0x4086 // vmovups ymm6, ymmword ptr [rsi + 4*rax + 64]
144+
LONG $0x7c10fcc5; WORD $0x6086 // vmovups ymm7, ymmword ptr [rsi + 4*rax + 96]
145+
LONG $0xb85de2c4; WORD $0x8704 // vfmadd231ps ymm0, ymm4, ymmword ptr [rdi + 4*rax]
146+
LONG $0xb855e2c4; WORD $0x874c; BYTE $0x20 // vfmadd231ps ymm1, ymm5, ymmword ptr [rdi + 4*rax + 32]
147+
LONG $0xb84de2c4; WORD $0x8754; BYTE $0x40 // vfmadd231ps ymm2, ymm6, ymmword ptr [rdi + 4*rax + 64]
148+
LONG $0xb845e2c4; WORD $0x875c; BYTE $0x60 // vfmadd231ps ymm3, ymm7, ymmword ptr [rdi + 4*rax + 96]
149+
LONG $0x20c08348 // add rax, 32
150+
WORD $0x3949; BYTE $0xc0 // cmp r8, rax
151+
JNE LBB1_6
152+
LONG $0xc058f4c5 // vaddps ymm0, ymm1, ymm0
153+
LONG $0xc058ecc5 // vaddps ymm0, ymm2, ymm0
154+
LONG $0xc058e4c5 // vaddps ymm0, ymm3, ymm0
155+
LONG $0x197de3c4; WORD $0x01c1 // vextractf128 xmm1, ymm0, 1
156+
LONG $0xc158f8c5 // vaddps xmm0, xmm0, xmm1
157+
LONG $0x0579e3c4; WORD $0x01c8 // vpermilpd xmm1, xmm0, 1
158+
LONG $0xc158f8c5 // vaddps xmm0, xmm0, xmm1
159+
LONG $0xc816fac5 // vmovshdup xmm1, xmm0
160+
LONG $0xc158fac5 // vaddss xmm0, xmm0, xmm1
161+
WORD $0x3949; BYTE $0xc8 // cmp r8, rcx
162+
JE LBB1_8
163+
164+
LBB1_4:
165+
LONG $0x107aa1c4; WORD $0x860c // vmovss xmm1, dword ptr [rsi + 4*r8]
166+
LONG $0xb971a2c4; WORD $0x8704 // vfmadd231ss xmm0, xmm1, dword ptr [rdi + 4*r8]
167+
WORD $0xff49; BYTE $0xc0 // inc r8
168+
WORD $0x394c; BYTE $0xc1 // cmp rcx, r8
169+
JNE LBB1_4
170+
171+
LBB1_8:
172+
LONG $0xc05afac5 // vcvtss2sd xmm0, xmm0, xmm0
173+
LONG $0x0211fbc5 // vmovsd qword ptr [rdx], xmm0
174+
WORD $0x8948; BYTE $0xec // mov rsp, rbp
175+
BYTE $0x5d // pop rbp
176+
WORD $0xf8c5; BYTE $0x77 // vzeroupper
177+
BYTE $0xc3 // ret

internal/cosine/simd/cosine_neon.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ import "unsafe"
77

88
//go:noescape,nosplit
99
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)
10+
11+
//go:noescape,nosplit
12+
func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)

0 commit comments

Comments
 (0)