Skip to content

Commit 4054be1

Browse files
authored
Cut 10% of query time with this one weird trick. (#96)
Remove `std::sqrt` from `L2`: because the square root is monotonically increasing on non-negative values, the sum of squares has the same ordering as the square root of the sum of squares. When actual distances are required, we can take the square root at that point. This seems to save around 10% of query time.
1 parent 92fecad commit 4054be1

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

src/include/defs.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,8 @@ inline auto sum_of_squares(V const& a, U const& b) {
7777
*/
7878
template <class V, class U>
7979
inline auto L2(V const& a, U const& b) {
80-
return std::sqrt(sum_of_squares(a, b));
80+
// return std::sqrt(sum_of_squares(a, b));
81+
return sum_of_squares(a, b);
8182
}
8283

8384
/**

src/include/scoring.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -95,9 +95,9 @@ void gemm_scores(const Matrix1& A, const Matrix2& B, Matrix3& C, unsigned nthrea
9595
CblasColMajor, M, N, 1.0, &beta_ones[0], 1, &beta[0], 1, C.data(), M);
9696

9797
stdx::execution::parallel_policy par{nthreads};
98-
stdx::for_each(std::move(par), begin(raveled_C), end(raveled_C), [](auto& a) {
99-
a = sqrt(a);
100-
});
98+
// stdx::for_each(std::move(par), begin(raveled_C), end(raveled_C), [](auto& a) {
99+
// a = sqrt(a);
100+
// });
101101
}
102102

103103
template <class Matrix1, class Matrix2, class Matrix3>

src/include/test/unit_gemm.cc

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -277,10 +277,10 @@ TEST_CASE("gemm: row major test", "[sgemm]") {
277277
L2_span[j][i] = L2(B_span[j], A_span[i]);
278278
}
279279
}
280-
CHECK(std::abs(L2_span[0][0] - 10.3923) < .0001);
281-
CHECK(std::abs(L2_span[1][0] - 15.5884) < .0001);
282-
CHECK(std::abs(L2_span[0][1] - 5.1961) < .0001);
283-
CHECK(std::abs(L2_span[1][1] - 10.3923) < .0001);
280+
CHECK(std::abs(L2_span[0][0] - 10.3923 * 10.3923) < .001);
281+
CHECK(std::abs(L2_span[1][0] - 15.5884 * 15.5884) < .002);
282+
CHECK(std::abs(L2_span[0][1] - 5.1961 * 5.1961) < .002);
283+
CHECK(std::abs(L2_span[1][1] - 10.3923 * 10.3923) < .001);
284284

285285
/****************************************************************
286286
*

0 commit comments

Comments
 (0)