Skip to content

Commit 923d014

Browse files
authored
Optimizations for Python bindings and sum_of_squares (#102)
This PR has two small optimizations:
* Passes `indices` and `ids` by reference rather than by value in `module.cc`. This seems to give a substantial speedup for Python IVF query times.
* Removes the explicit casts to `float` in `sum_of_squares`. This gives a 30% speedup (tested on an EC2 r6id.2xlarge instance).
1 parent a6bfc30 commit 923d014

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -109,8 +109,8 @@ static void declare_qv_query_heap_infinite_ram(py::module& m, const std::string&
109109
[](const ColMajorMatrix<T>& parts,
110110
const ColMajorMatrix<float>& centroids,
111111
const ColMajorMatrix<float>& query_vectors,
112-
std::vector<Id_Type> indices,
113-
std::vector<Id_Type> ids,
112+
std::vector<Id_Type>& indices,
113+
std::vector<Id_Type>& ids,
114114
size_t nprobe,
115115
size_t k_nn,
116116
bool nth,
@@ -137,7 +137,7 @@ static void declare_qv_query_heap_finite_ram(py::module& m, const std::string& s
137137
const std::string& parts_uri,
138138
const ColMajorMatrix<float>& centroids,
139139
const ColMajorMatrix<float>& query_vectors,
140-
std::vector<Id_Type> indices,
140+
std::vector<Id_Type>& indices,
141141
const std::string& ids_uri,
142142
size_t nprobe,
143143
size_t k_nn,
@@ -167,8 +167,8 @@ static void declare_nuv_query_heap_infinite_ram(py::module& m, const std::string
167167
[](const ColMajorMatrix<T>& parts,
168168
const ColMajorMatrix<float>& centroids,
169169
const ColMajorMatrix<float>& query_vectors,
170-
std::vector<Id_Type> indices,
171-
std::vector<Id_Type> ids,
170+
std::vector<Id_Type>& indices,
171+
std::vector<Id_Type>& ids,
172172
size_t nprobe,
173173
size_t k_nn,
174174
bool nth,
@@ -195,7 +195,7 @@ static void declare_nuv_query_heap_finite_ram(py::module& m, const std::string&
195195
const std::string& parts_uri,
196196
const ColMajorMatrix<float>& centroids,
197197
const ColMajorMatrix<float>& query_vectors,
198-
std::vector<Id_Type> indices,
198+
std::vector<Id_Type>& indices,
199199
const std::string& ids_uri,
200200
size_t nprobe,
201201
size_t k_nn,

src/include/defs.h

Lines changed: 25 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,17 +57,41 @@
5757
* @param b
5858
* @return
5959
*/
60+
template <class V, class U>
inline auto sum_of_squares(V const& a, U const& b) {
  // Returns, as a float, the sum over i in [0, size(a)) of (a[i] - b[i])^2.
  // Only size(a) is consulted, so b must provide at least that many elements.
  //
  // The formerly disabled (`#if 0`) alternative is gone: it could never be
  // compiled in, and the compile-time selection below supersedes it.

  // Type in which `a[i] - b[i]` is evaluated after the usual arithmetic
  // conversions. Elements narrower than int are promoted to int, so their
  // difference is exact and signed. Unsigned elements of int width or wider
  // stay unsigned, and a negative difference would wrap around to a huge
  // positive value before ever reaching `float`.
  using difference_type = decltype(a[0] - b[0]);
  constexpr bool difference_wraps =
      static_cast<difference_type>(-1) > static_cast<difference_type>(0);

  float sum{0.0};
  size_t size_a = size(a);

  for (size_t i = 0; i < size_a; ++i) {
    float diff;
    if constexpr (difference_wraps) {
      // Convert each operand first so the subtraction is done in float and
      // keeps its sign.
      diff = static_cast<float>(a[i]) - static_cast<float>(b[i]);
    } else {
      // Deliberately no per-element casts here: subtracting in the native
      // type and converting once was reported as measurably faster.
      diff = a[i] - b[i];
    }
    sum += diff * diff;
  }
  return sum;
}
94+
7195
/**
7296
* @brief Compute L2 distance between two vectors.
7397
* @tparam V

0 commit comments

Comments
 (0)