diff --git a/chromadb/utils/distance_functions.py b/chromadb/utils/distance_functions.py index e4c95f87832..c8669798862 100644 --- a/chromadb/utils/distance_functions.py +++ b/chromadb/utils/distance_functions.py @@ -20,12 +20,20 @@ def cosine(x: Vector, y: Vector) -> float: NORM_EPS = 1e-30 if x.dtype == np.float16 or y.dtype == np.float16: NORM_EPS = 1e-7 - return cast( - float, - ( - 1.0 - np.dot(x, y) / ((np.linalg.norm(x) * np.linalg.norm(y)) + NORM_EPS) - ).item(), - ) + + # Avoid redundant norm calculations: compute each norm only once + x_norm = np.linalg.norm(x) + y_norm = np.linalg.norm(y) + denom = (x_norm * y_norm) + NORM_EPS + + # Use np.dot for both 1D and multidimensional (row vector) cases + dot = np.dot(x, y) + + # Use fused multiply-add if available for numerical precision, but stay with simple version for performance + # Directly perform subtraction and division without additional .item() dereference if already scalar + # But in either case, guarantee returning a Python float + result = 1.0 - dot / denom + return cast(float, float(result)) def ip(x: Vector, y: Vector) -> float: