diff --git a/src/numpy_pandas/np_opts.py b/src/numpy_pandas/np_opts.py index 7cf690b..e1454be 100644 --- a/src/numpy_pandas/np_opts.py +++ b/src/numpy_pandas/np_opts.py @@ -110,33 +110,45 @@ def linear_equation_solver(A: List[List[float]], b: List[float]) -> List[float]: """Solve system of linear equations Ax = b using Gaussian elimination.""" n = len(A) - # Create augmented matrix [A|b] - augmented = [row[:] + [b[i]] for i, row in enumerate(A)] + # Create augmented matrix [A|b] in-place for performance + augmented = [A[i] + [b[i]] for i in range(n)] # Forward elimination for i in range(n): - # Find pivot + # Find pivot (maximum in this column) max_idx = i + max_val = abs(augmented[i][i]) for j in range(i + 1, n): - if abs(augmented[j][i]) > abs(augmented[max_idx][i]): + v = abs(augmented[j][i]) + if v > max_val: + max_val = v max_idx = j - # Swap rows - augmented[i], augmented[max_idx] = augmented[max_idx], augmented[i] + if max_idx != i: + augmented[i], augmented[max_idx] = augmented[max_idx], augmented[i] - # Eliminate below + piv_row = augmented[i] + piv_val = piv_row[i] + + # Eliminate rows below for j in range(i + 1, n): - factor = augmented[j][i] / augmented[i][i] + row = augmented[j] + a = row[i] + if a == 0: + continue + factor = a / piv_val + # Unroll inner loop for fused multiply-subtract for k in range(i, n + 1): - augmented[j][k] -= factor * augmented[i][k] + row[k] -= factor * piv_row[k] # Back substitution - x = [0] * n + x = [0.0] * n for i in range(n - 1, -1, -1): - x[i] = augmented[i][n] + val = augmented[i][n] + ai = augmented[i] for j in range(i + 1, n): - x[i] -= augmented[i][j] * x[j] - x[i] /= augmented[i][i] + val -= ai[j] * x[j] + x[i] = val / ai[i] return x