diff --git a/src/numerical/linear_algebra.py b/src/numerical/linear_algebra.py index 0c56ca3..d881301 100644 --- a/src/numerical/linear_algebra.py +++ b/src/numerical/linear_algebra.py @@ -30,14 +30,15 @@ def matrix_inverse(matrix: np.ndarray) -> np.ndarray: raise ValueError("Matrix must be square") n = matrix.shape[0] identity = np.eye(n) - augmented = np.hstack((matrix, identity)) + augmented = np.hstack((matrix.astype(float, copy=False), identity)) for i in range(n): pivot = augmented[i, i] - augmented[i] = augmented[i] / pivot - for j in range(n): - if i != j: - factor = augmented[j, i] - augmented[j] = augmented[j] - factor * augmented[i] + augmented[i] /= pivot + # Vectorized row operation for j != i + mask = np.arange(n) != i + factors = augmented[mask, i] + # Subtract factors[:, None] * augmented[i] from augmented[mask] + augmented[mask] -= factors[:, np.newaxis] * augmented[i] return augmented[:, n:]