@@ -814,6 +814,41 @@ TEST_P(KnnTest, Cosine) {
814
814
}
815
815
}
816
816
817
+ TEST_P (KnnTest, IP) {
818
+ // Test with normalized unit vectors for IP distance
819
+ // Using unit vectors pointing in different directions
820
+ const pair<float , float > kTestCoords [] = {
821
+ {1 .0f , 0 .0f }, {0 .0f , 1 .0f }, {-1 .0f , 0 .0f }, {0 .0f , -1 .0f }};
822
+
823
+ auto schema = MakeSimpleSchema ({{" pos" , SchemaField::VECTOR}});
824
+ schema.fields [" pos" ].special_params =
825
+ SchemaField::VectorParams{GetParam (), 2 , VectorSimilarity::IP};
826
+ FieldIndices indices{schema, kEmptyOptions , PMR_NS::get_default_resource (), nullptr };
827
+
828
+ for (size_t i = 0 ; i < ABSL_ARRAYSIZE (kTestCoords ); i++) {
829
+ string coords = ToBytes ({kTestCoords [i].first , kTestCoords [i].second });
830
+ MockedDocument doc{Map{{" pos" , coords}}};
831
+ indices.Add (i, doc);
832
+ }
833
+
834
+ SearchAlgorithm algo{};
835
+ QueryParams params;
836
+
837
+ // Query with vector pointing right - should find exact match (highest dot product)
838
+ {
839
+ params[" vec" ] = ToBytes ({1 .0f , 0 .0f });
840
+ algo.Init (" * =>[KNN 1 @pos $vec]" , ¶ms);
841
+ EXPECT_THAT (algo.Search (&indices).ids , testing::UnorderedElementsAre (0 ));
842
+ }
843
+
844
+ // Query with vector pointing up - should find exact match (highest dot product)
845
+ {
846
+ params[" vec" ] = ToBytes ({0 .0f , 1 .0f });
847
+ algo.Init (" * =>[KNN 1 @pos $vec]" , ¶ms);
848
+ EXPECT_THAT (algo.Search (&indices).ids , testing::UnorderedElementsAre (1 ));
849
+ }
850
+ }
851
+
817
852
TEST_P (KnnTest, AddRemove) {
818
853
auto schema = MakeSimpleSchema ({{" pos" , SchemaField::VECTOR}});
819
854
schema.fields [" pos" ].special_params =
@@ -894,12 +929,22 @@ TEST_F(SearchTest, VectorDistanceBasic) {
894
929
EXPECT_GE (cos_dist, 0 .0f );
895
930
EXPECT_LE (cos_dist, 2 .0f ); // Cosine distance range
896
931
932
+ // Test IP distance
933
+ float ip_dist = VectorDistance (vec1.data (), vec2.data (), 3 , VectorSimilarity::IP);
934
+ // IP distance can be negative for non-normalized vectors
935
+ EXPECT_NE (ip_dist, 0 .0f ); // Should be non-zero for different vectors
936
+
897
937
// Test identical vectors
898
938
float l2_same = VectorDistance (vec1.data (), vec1.data (), 3 , VectorSimilarity::L2);
899
939
EXPECT_NEAR (l2_same, 0 .0f , 1e-6 );
900
940
901
941
float cos_same = VectorDistance (vec1.data (), vec1.data (), 3 , VectorSimilarity::COSINE);
902
942
EXPECT_NEAR (cos_same, 0 .0f , 1e-6 );
943
+
944
+ float ip_same = VectorDistance (vec1.data (), vec1.data (), 3 , VectorSimilarity::IP);
945
+ // For identical vectors: IP = 1 - dot_product(v, v) = 1 - ||v||^2
946
+ // For vec1 = {1, 2, 3}: ||v||^2 = 1 + 4 + 9 = 14, so IP = 1 - 14 = -13
947
+ EXPECT_LT (ip_same, 0 .0f ); // Should be negative for non-normalized vectors
903
948
}
904
949
905
950
TEST_F (SearchTest, VectorDistanceConsistency) {
@@ -914,6 +959,10 @@ TEST_F(SearchTest, VectorDistanceConsistency) {
914
959
float cos_dist1 = VectorDistance (vec1.data (), vec2.data (), 5 , VectorSimilarity::COSINE);
915
960
float cos_dist2 = VectorDistance (vec1.data (), vec2.data (), 5 , VectorSimilarity::COSINE);
916
961
EXPECT_EQ (cos_dist1, cos_dist2);
962
+
963
+ float ip_dist1 = VectorDistance (vec1.data (), vec2.data (), 5 , VectorSimilarity::IP);
964
+ float ip_dist2 = VectorDistance (vec1.data (), vec2.data (), 5 , VectorSimilarity::IP);
965
+ EXPECT_EQ (ip_dist1, ip_dist2);
917
966
}
918
967
919
968
static void BM_VectorSearch (benchmark::State& state) {
0 commit comments