Skip to content

Commit 119719d

Browse files
authored
Add mat mul (#75)
* Added matmul and pairwise distance * Added matmul and pairwise distance * Added matmul and pairwise distance * Added matmul and pairwise distance * Added matmul and pairwise distance * Added matmul and pairwise distance
1 parent a6624dd commit 119719d

File tree

4 files changed

+417
-2
lines changed

4 files changed

+417
-2
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ repos:
5050
additional_dependencies: ["tomli"]
5151
- repo: https://github.com/astral-sh/ruff-pre-commit
5252
# Ruff version.
53-
rev: v0.14.13
53+
rev: v0.15.0
5454
hooks:
5555
# Run the linter.
5656
- id: ruff-check

albucore/functions.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,3 +1081,125 @@ def uint8_wrapper(img: ImageType, *args: Any, **kwargs: Any) -> ImageType:
10811081
return to_float(result) if input_dtype != np.uint8 else result
10821082

10831083
return uint8_wrapper
1084+
1085+
1086+
def matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Multiply two matrices with NumPy's ``@`` operator.

    Thin wrapper around ``a @ b``, provided for API consistency and to make
    explicit that this is the recommended replacement for ``cv2.gemm`` in
    geometric-transformation code.

    Benchmark notes recorded by the authors (macOS ARM, not re-measured here):
        - Small matrices (2x2, 10x10): similar to cv2.gemm (~1.0x)
        - Large matrices (1024x1024, 2048x2048): similar to cv2.gemm (~1.0x)
        - Tall/skinny TPS matrices: similar to cv2.gemm (0.93-1.02x)
        - uint8 inputs are accepted (reported as unsupported by cv2.gemm)

    Args:
        a: First matrix, shape (M, K), dtype float32, float64, or uint8.
        b: Second matrix, shape (K, N), dtype float32, float64, or uint8.

    Returns:
        Result matrix, shape (M, N). The output dtype follows NumPy's
        type-resolution rules for ``@``:
            - float32 @ float32 -> float32
            - float64 @ float64 -> float64
            - uint8 @ uint8 -> uint8 (NOT widened; sums that exceed 255
              wrap around modulo 256)

    Examples:
        >>> import numpy as np
        >>> from albucore import matmul
        >>>
        >>> # ThinPlateSpline pairwise distance computation
        >>> points1 = np.random.randn(10000, 2).astype(np.float32)  # Target points
        >>> points2 = np.random.randn(10, 2).astype(np.float32)  # Control points
        >>> dot_matrix = matmul(points1, points2.T)  # (10000, 10)
        >>>
        >>> # TPS coordinate transformation
        >>> kernel = np.random.randn(10000, 10).astype(np.float32)
        >>> weights = np.random.randn(10, 2).astype(np.float32)
        >>> transformed = matmul(kernel, weights)  # (10000, 2)

    Note:
        Integer inputs are not promoted. Cast uint8 operands to a wider dtype
        (e.g. ``a.astype(np.int32)``) before calling if the true, non-wrapped
        product is required.

    Use Cases:
        - ThinPlateSpline geometric transformation (3 uses in AlbumentationsX)
        - Macenko stain normalization for medical imaging (1 use in AlbumentationsX)
    """
    return a @ b
1135+
1136+
1137+
def pairwise_distances_squared(
    points1: np.ndarray,
    points2: np.ndarray,
) -> np.ndarray:
    """Compute squared pairwise Euclidean distances between two point sets.

    The backend is chosen from the size of the output matrix:
        - n1 * n2 < 1000: ``ss.cdist`` with ``metric="sqeuclidean"``
        - n1 * n2 >= 1000: NumPy, using the expansion
          ||a - b||² = ||a||² + ||b||² - 2(a·b)

    Benchmark notes recorded by the authors (macOS ARM, not re-measured here):
        - Small (10x10): simsimd 5.93x faster than cv2
        - Medium (100x100): NumPy 1.05x faster than cv2
        - Large (1000x100): NumPy similar to cv2 (~1.0x)

    Args:
        points1: First set of points, shape (N, D). Converted to a
            C-contiguous float32 array before use.
        points2: Second set of points, shape (M, D). Converted to a
            C-contiguous float32 array before use.

    Returns:
        Matrix of squared distances, shape (N, M), dtype float32.
        Element [i, j] contains ||points1[i] - points2[j]||².
        If either point set is empty, an empty (N, M) array is returned.

    Raises:
        ValueError: If either input is not 2-D after conversion, or if the
            two inputs disagree on the point dimensionality D.

    Examples:
        >>> import numpy as np
        >>> from albucore import pairwise_distances_squared
        >>> # Control points for thin plate spline
        >>> src_points = np.array([[0, 0], [1, 0], [0, 1]], dtype=np.float32)
        >>> dst_points = np.array([[0.1, 0.1], [0.9, 0.1]], dtype=np.float32)
        >>> distances_sq = pairwise_distances_squared(src_points, dst_points)
        >>> distances_sq.shape
        (3, 2)

    Note:
        Returns SQUARED distances (not Euclidean distances). This is often
        what is needed (e.g. for RBF kernels in TPS) and avoids the sqrt.
        For actual Euclidean distances use ``np.sqrt(result)``.

        The float32 expansion can yield tiny negative values (e.g. -1e-6)
        through rounding, so the result is clamped to be >= 0.
    """
    points1 = np.ascontiguousarray(points1, dtype=np.float32)
    points2 = np.ascontiguousarray(points2, dtype=np.float32)

    # Fail early with a clear message instead of an opaque AxisError/matmul
    # error (NumPy path) or backend-specific behaviour (small-set path).
    if points1.ndim != 2 or points2.ndim != 2:
        raise ValueError(
            "points1 and points2 must be 2-D arrays of shape (N, D) and (M, D); "
            f"got shapes {points1.shape} and {points2.shape}",
        )
    if points1.shape[1] != points2.shape[1]:
        raise ValueError(
            "points1 and points2 must have the same point dimensionality D; "
            f"got {points1.shape[1]} and {points2.shape[1]}",
        )

    n1, n2 = points1.shape[0], points2.shape[0]

    # Nothing to compute for an empty point set; skip both backends.
    if n1 == 0 or n2 == 0:
        return np.zeros((n1, n2), dtype=np.float32)

    # Small outputs go to the `ss` backend (presumably the module-level simsimd
    # import; the original notes cite a 5.93x speed-up). Larger ones use NumPy.
    if n1 * n2 < 1000:
        result = np.asarray(ss.cdist(points1, points2, metric="sqeuclidean"), dtype=np.float32)
        # Clamp to zero to absorb small negative values from rounding.
        np.maximum(result, 0.0, out=result)
        return result

    # NumPy path: ||a-b||² = ||a||² + ||b||² - 2(a·b)
    p1_squared = (points1 * points1).sum(axis=1, keepdims=True)  # (N, 1)
    p2_squared = (points2 * points2).sum(axis=1)[None, :]  # (1, M)
    cross = points1 @ points2.T  # (N, M)
    cross *= 2  # in place: exact in floating point, avoids an (N, M) temporary

    result = p1_squared + p2_squared  # broadcasts to (N, M)
    result -= cross
    # Clamp to zero to absorb small negative values from rounding.
    np.maximum(result, 0.0, out=result)
    return result

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = [ "setuptools>=45", "wheel" ]
55

66
[project]
77
name = "albucore"
8-
version = "0.0.36"
8+
version = "0.0.37"
99

1010
description = "High-performance image processing functions for deep learning and computer vision."
1111
readme = "README.md"
@@ -154,3 +154,8 @@ disallow_untyped_defs = true
154154
warn_return_any = true
155155
strict_equality = true
156156
warn_unreachable = true
157+
158+
[tool.pytest.ini_options]
159+
markers = [
160+
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
161+
]

0 commit comments

Comments
 (0)