Skip to content

Commit f2cf05f

Browse files
authored
feat(search): Add IP (Inner Product) distance metric support for vector similarity (#5559)
Fixed: #5556
1 parent 6d74f4a commit f2cf05f

File tree

5 files changed

+67
-5
lines changed

5 files changed

+67
-5
lines changed

src/core/search/base.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace dfly::search {
2121

2222
using DocId = uint32_t;
2323

24-
enum class VectorSimilarity { L2, COSINE };
24+
enum class VectorSimilarity { L2, IP, COSINE };
2525

2626
using OwnedFtVector = std::pair<std::unique_ptr<float[]>, size_t /* dimension (size) */>;
2727

src/core/search/search_test.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,41 @@ TEST_P(KnnTest, Cosine) {
814814
}
815815
}
816816

817+
TEST_P(KnnTest, IP) {
818+
// Test with normalized unit vectors for IP distance
819+
// Using unit vectors pointing in different directions
820+
const pair<float, float> kTestCoords[] = {
821+
{1.0f, 0.0f}, {0.0f, 1.0f}, {-1.0f, 0.0f}, {0.0f, -1.0f}};
822+
823+
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
824+
schema.fields["pos"].special_params =
825+
SchemaField::VectorParams{GetParam(), 2, VectorSimilarity::IP};
826+
FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};
827+
828+
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
829+
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
830+
MockedDocument doc{Map{{"pos", coords}}};
831+
indices.Add(i, doc);
832+
}
833+
834+
SearchAlgorithm algo{};
835+
QueryParams params;
836+
837+
// Query with vector pointing right - should find exact match (highest dot product)
838+
{
839+
params["vec"] = ToBytes({1.0f, 0.0f});
840+
algo.Init("* =>[KNN 1 @pos $vec]", &params);
841+
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0));
842+
}
843+
844+
// Query with vector pointing up - should find exact match (highest dot product)
845+
{
846+
params["vec"] = ToBytes({0.0f, 1.0f});
847+
algo.Init("* =>[KNN 1 @pos $vec]", &params);
848+
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1));
849+
}
850+
}
851+
817852
TEST_P(KnnTest, AddRemove) {
818853
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
819854
schema.fields["pos"].special_params =
@@ -894,12 +929,22 @@ TEST_F(SearchTest, VectorDistanceBasic) {
894929
EXPECT_GE(cos_dist, 0.0f);
895930
EXPECT_LE(cos_dist, 2.0f); // Cosine distance range
896931

932+
// Test IP distance
933+
float ip_dist = VectorDistance(vec1.data(), vec2.data(), 3, VectorSimilarity::IP);
934+
// IP distance can be negative for non-normalized vectors
935+
EXPECT_NE(ip_dist, 0.0f); // Should be non-zero for different vectors
936+
897937
// Test identical vectors
898938
float l2_same = VectorDistance(vec1.data(), vec1.data(), 3, VectorSimilarity::L2);
899939
EXPECT_NEAR(l2_same, 0.0f, 1e-6);
900940

901941
float cos_same = VectorDistance(vec1.data(), vec1.data(), 3, VectorSimilarity::COSINE);
902942
EXPECT_NEAR(cos_same, 0.0f, 1e-6);
943+
944+
float ip_same = VectorDistance(vec1.data(), vec1.data(), 3, VectorSimilarity::IP);
945+
// For identical vectors: IP = 1 - dot_product(v, v) = 1 - ||v||^2
946+
// For vec1 = {1, 2, 3}: ||v||^2 = 1 + 4 + 9 = 14, so IP = 1 - 14 = -13
947+
EXPECT_LT(ip_same, 0.0f); // Should be negative for non-normalized vectors
903948
}
904949

905950
TEST_F(SearchTest, VectorDistanceConsistency) {
@@ -914,6 +959,10 @@ TEST_F(SearchTest, VectorDistanceConsistency) {
914959
float cos_dist1 = VectorDistance(vec1.data(), vec2.data(), 5, VectorSimilarity::COSINE);
915960
float cos_dist2 = VectorDistance(vec1.data(), vec2.data(), 5, VectorSimilarity::COSINE);
916961
EXPECT_EQ(cos_dist1, cos_dist2);
962+
963+
float ip_dist1 = VectorDistance(vec1.data(), vec2.data(), 5, VectorSimilarity::IP);
964+
float ip_dist2 = VectorDistance(vec1.data(), vec2.data(), 5, VectorSimilarity::IP);
965+
EXPECT_EQ(ip_dist1, ip_dist2);
917966
}
918967

919968
static void BM_VectorSearch(benchmark::State& state) {

src/core/search/vector_utils.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,15 @@ FAST_MATH float L2Distance(const float* u, const float* v, size_t dims) {
2929
return sqrt(sum);
3030
}
3131

32-
// TODO: Normalize vectors ahead if cosine distance is used
32+
// Inner product distance: 1 - dot_product(u, v)
33+
// For normalized vectors, this is equivalent to cosine distance
34+
FAST_MATH float IPDistance(const float* u, const float* v, size_t dims) {
35+
float sum_uv = 0;
36+
for (size_t i = 0; i < dims; i++)
37+
sum_uv += u[i] * v[i];
38+
return 1.0f - sum_uv;
39+
}
40+
3341
FAST_MATH float CosineDistance(const float* u, const float* v, size_t dims) {
3442
float sum_uv = 0, sum_uu = 0, sum_vv = 0;
3543
for (size_t i = 0; i < dims; i++) {
@@ -71,6 +79,8 @@ float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilari
7179
switch (sim) {
7280
case VectorSimilarity::L2:
7381
return L2Distance(u, v, dims);
82+
case VectorSimilarity::IP:
83+
return IPDistance(u, v, dims);
7484
case VectorSimilarity::COSINE:
7585
return CosineDistance(u, v, dims);
7686
};

src/server/search/doc_index.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ string DocIndexInfo::BuildRestoreCommand() const {
148148
Overloaded info{
149149
[](monostate) {},
150150
[out = &out](const search::SchemaField::VectorParams& params) {
151-
auto sim = params.sim == search::VectorSimilarity::L2 ? "L2" : "COSINE";
151+
auto sim = params.sim == search::VectorSimilarity::L2 ? "L2"
152+
: params.sim == search::VectorSimilarity::IP ? "IP"
153+
: "COSINE";
152154
absl::StrAppend(out, " ", params.use_hnsw ? "HNSW" : "FLAT", " 6 ", "DIM ", params.dim,
153155
" DISTANCE_METRIC ", sim, " INITIAL_CAP ", params.capacity);
154156
},

src/server/search/search_family.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ search::SchemaField::VectorParams ParseVectorParams(CmdArgParser* parser) {
8080
for (size_t i = 0; i * 2 < num_args; i++) {
8181
if (parser->Check("DIM", &params.dim)) {
8282
} else if (parser->Check("DISTANCE_METRIC")) {
83-
params.sim = parser->MapNext("L2", search::VectorSimilarity::L2, "COSINE",
84-
search::VectorSimilarity::COSINE);
83+
params.sim =
84+
parser->MapNext("L2", search::VectorSimilarity::L2, "IP", search::VectorSimilarity::IP,
85+
"COSINE", search::VectorSimilarity::COSINE);
8586
} else if (parser->Check("INITIAL_CAP", &params.capacity)) {
8687
} else if (parser->Check("M", &params.hnsw_m)) {
8788
} else if (parser->Check("EF_CONSTRUCTION", &params.hnsw_ef_construction)) {

0 commit comments

Comments
 (0)