@@ -139,6 +139,10 @@ butil::Status VectorIndexUtils::CalcDistanceByFaiss(
139139 return CalcCosineDistanceByFaiss (op_left_vectors, op_right_vectors, is_return_normlize, distances,
140140 result_op_left_vectors, result_op_right_vectors);
141141 }
142+ case pb::common::METRIC_TYPE_HAMMING: {
143+ return CalcHammingDistanceByFaiss (op_left_vectors, op_right_vectors, is_return_normlize, distances,
144+ result_op_left_vectors, result_op_right_vectors);
145+ }
142146 case pb::common::METRIC_TYPE_NONE:
143147 case pb::common::MetricType_INT_MIN_SENTINEL_DO_NOT_USE_:
144148 case pb::common::MetricType_INT_MAX_SENTINEL_DO_NOT_USE_: {
@@ -213,6 +217,17 @@ butil::Status VectorIndexUtils::CalcCosineDistanceByFaiss(
213217 result_op_right_vectors, DoCalcCosineDistanceByFaiss);
214218}
215219
220+ butil::Status VectorIndexUtils::CalcHammingDistanceByFaiss (
221+ const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
222+ const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors, bool is_return_normlize,
223+ std::vector<std::vector<float >>& distances, // NOLINT
224+ std::vector<::dingodb::pb::common::Vector>& result_op_left_vectors, // NOLINT
225+ std::vector<::dingodb::pb::common::Vector>& result_op_right_vectors) // NOLINT
226+ { // NOLINT
227+ return CalcDistanceCore (op_left_vectors, op_right_vectors, is_return_normlize, distances, result_op_left_vectors,
228+ result_op_right_vectors, DoCalcHammingDistanceByFaiss);
229+ }
230+
216231butil::Status VectorIndexUtils::CalcL2DistanceByHnswlib (
217232 const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
218233 const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors, bool is_return_normlize,
@@ -307,6 +322,33 @@ butil::Status VectorIndexUtils::DoCalcCosineDistanceByFaiss(
307322 return butil::Status ();
308323}
309324
325+ butil::Status VectorIndexUtils::DoCalcHammingDistanceByFaiss (
326+ const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
327+ bool is_return_normlize,
328+ float & distance, // NOLINT
329+ dingodb::pb::common::Vector& result_op_left_vectors, // NOLINT
330+ dingodb::pb::common::Vector& result_op_right_vectors) // NOLINT
331+ { // NOLINT
332+ faiss::VectorDistance<faiss::MetricType::METRIC_HAMMING> vector_distance;
333+ vector_distance.d = op_left_vectors.binary_values ().size ();
334+
335+ std::vector<uint8_t > left_vectors = std::vector<uint8_t >(op_left_vectors.binary_values ().size ());
336+ for (int j = 0 ; j < op_left_vectors.binary_values ().size (); j++) {
337+ left_vectors[j] = static_cast <uint8_t >(op_left_vectors.binary_values ()[j][0 ]);
338+ }
339+ std::vector<uint8_t > right_vectors = std::vector<uint8_t >(op_right_vectors.binary_values ().size ());
340+ for (int j = 0 ; j < op_right_vectors.binary_values ().size (); j++) {
341+ right_vectors[j] = static_cast <uint8_t >(op_right_vectors.binary_values ()[j][0 ]);
342+ }
343+
344+ distance = vector_distance (left_vectors.data (), right_vectors.data ());
345+
346+ ResultOpBinaryVectorAssignmentWrapper (op_left_vectors, op_right_vectors, is_return_normlize, result_op_left_vectors,
347+ result_op_right_vectors);
348+
349+ return butil::Status ();
350+ }
351+
310352butil::Status VectorIndexUtils::DoCalcL2DistanceByHnswlib (
311353 const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
312354 bool is_return_normlize,
@@ -386,6 +428,13 @@ void VectorIndexUtils::ResultOpVectorAssignment(dingodb::pb::common::Vector& res
386428 result_op_vectors.set_value_type (::dingodb::pb::common::ValueType::FLOAT);
387429}
388430
431+ void VectorIndexUtils::ResultOpBinaryVectorAssignment (dingodb::pb::common::Vector& result_op_vectors,
432+ const ::dingodb::pb::common::Vector& op_vectors) {
433+ result_op_vectors = op_vectors;
434+ result_op_vectors.set_dimension (result_op_vectors.binary_values ().size () * CHAR_BIT);
435+ result_op_vectors.set_value_type (::dingodb::pb::common::ValueType::UINT8);
436+ }
437+
389438void VectorIndexUtils::ResultOpVectorAssignmentWrapper (const ::dingodb::pb::common::Vector& op_left_vectors,
390439 const ::dingodb::pb::common::Vector& op_right_vectors,
391440 bool is_return_normlize,
@@ -403,6 +452,23 @@ void VectorIndexUtils::ResultOpVectorAssignmentWrapper(const ::dingodb::pb::comm
403452 }
404453}
405454
455+ void VectorIndexUtils::ResultOpBinaryVectorAssignmentWrapper (
456+ const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
457+ bool is_return_normlize,
458+ dingodb::pb::common::Vector& result_op_left_vectors, // NOLINT
459+ dingodb::pb::common::Vector& result_op_right_vectors) // NOLINT
460+ { // NOLINT
461+ if (is_return_normlize) {
462+ if (result_op_left_vectors.binary_values ().empty ()) {
463+ ResultOpBinaryVectorAssignment (result_op_left_vectors, op_left_vectors);
464+ }
465+
466+ if (result_op_right_vectors.binary_values ().empty ()) {
467+ ResultOpBinaryVectorAssignment (result_op_right_vectors, op_right_vectors);
468+ }
469+ }
470+ }
471+
406472void VectorIndexUtils::NormalizeVectorForFaiss (float * x, int32_t d) {
407473 static const float kFloatAccuracy = 0.00001 ;
408474
@@ -446,6 +512,10 @@ butil::Status VectorIndexUtils::CheckVectorDimension(const std::vector<pb::commo
446512 DINGO_LOG (ERROR) << s;
447513 return butil::Status (pb::error::Errno::EVECTOR_INVALID, s);
448514 }
515+ if (vector_with_id.vector ().dimension () != dimension) {
516+ std::string s = fmt::format (" vector dimension not match, {} {}" , vector_with_id.vector ().dimension (), dimension);
517+ return butil::Status (pb::error::Errno::EVECTOR_INVALID, s);
518+ }
449519 }
450520
451521 return butil::Status::OK ();
@@ -486,9 +556,9 @@ template <typename T>
486556std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue (const std::vector<pb::common::VectorWithId>& vector_with_ids,
487557 faiss::idx_t dimension, bool normalize) {
488558 std::unique_ptr<T[]> vectors = nullptr ;
489- if (std::is_same<T, float >::value) {
559+ if constexpr (std::is_same<T, float >::value) {
490560 vectors = std::make_unique<T[]>(vector_with_ids.size () * dimension);
491- } else if (std::is_same<T, uint8_t >::value) {
561+ } else if constexpr (std::is_same<T, uint8_t >::value) {
492562 vectors = std::make_unique<T[]>(vector_with_ids.size () * dimension / CHAR_BIT);
493563 } else {
494564 std::string s = fmt::format (" invalid value typename type" );
@@ -497,8 +567,8 @@ std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue(const std::vector<pb::
497567 }
498568
499569 for (size_t i = 0 ; i < vector_with_ids.size (); ++i) {
500- if (vector_with_ids[i]. vector (). value_type () == pb::common::ValueType::FLOAT ) {
501- if (!std::is_same<T, float >::value ) {
570+ if constexpr (std::is_same<T, float >::value ) {
571+ if (vector_with_ids[i]. vector (). value_type () != pb::common::ValueType::FLOAT ) {
502572 std::string s = fmt::format (" template not match vectors value_type : {}" ,
503573 pb::common::ValueType_Name (vector_with_ids[i].vector ().value_type ()));
504574 DINGO_LOG (ERROR) << s;
@@ -509,15 +579,17 @@ std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue(const std::vector<pb::
509579 if (normalize) {
510580 VectorIndexUtils::NormalizeVectorForFaiss (reinterpret_cast <float *>(vectors.get ()) + i * dimension, dimension);
511581 }
512- } else if (vector_with_ids[i]. vector (). value_type () == pb::common::ValueType::UINT8 ) {
513- if (!std::is_same<T, uint8_t >::value ) {
582+ } else if constexpr (std::is_same<T, uint8_t >::value ) {
583+ if (vector_with_ids[i]. vector (). value_type () != pb::common::ValueType::UINT8 ) {
514584 std::string s = fmt::format (" template not match vectors value_type : {}" ,
515585 pb::common::ValueType_Name (vector_with_ids[i].vector ().value_type ()));
516586 DINGO_LOG (ERROR) << s;
517587 return nullptr ;
518588 }
519589 const auto & vector_value = vector_with_ids[i].vector ().binary_values ();
520- memcpy (vectors.get () + i * dimension / CHAR_BIT, vector_value.data (), dimension / CHAR_BIT);
590+ for (int j = 0 ; j < vector_value.size (); j++) {
591+ vectors.get ()[i * dimension / CHAR_BIT + j] = static_cast <uint8_t >(vector_value[j][0 ]);
592+ }
521593 } else {
522594 std::string s =
523595 fmt::format (" invalid value type : {}" , pb::common::ValueType_Name (vector_with_ids[i].vector ().value_type ()));
@@ -855,8 +927,9 @@ butil::Status VectorIndexUtils::ValidateVectorIndexParameter(
855927 !(ivf_flat_parameter.metric_type () == pb::common::METRIC_TYPE_INNER_PRODUCT) &&
856928 !(ivf_flat_parameter.metric_type () == pb::common::METRIC_TYPE_L2)) {
857929 DINGO_LOG (ERROR) << " ivf_flat_parameter.metric_type is illegal " << ivf_flat_parameter.metric_type ();
858- return butil::Status (pb::error::Errno::EILLEGAL_PARAMTETERS,
859- " ivf_flat_parameter.metric_type is illegal " + std::to_string (ivf_flat_parameter.metric_type ()));
930+ return butil::Status (
931+ pb::error::Errno::EILLEGAL_PARAMTETERS,
932+ " ivf_flat_parameter.metric_type is illegal " + std::to_string (ivf_flat_parameter.metric_type ()));
860933 }
861934
862935 // check ivf_flat_parameter.ncentroids
0 commit comments