Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions src/numpy_pandas/np_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,33 +110,45 @@ def linear_equation_solver(A: List[List[float]], b: List[float]) -> List[float]:
"""Solve system of linear equations Ax = b using Gaussian elimination."""
n = len(A)

# Create augmented matrix [A|b]
augmented = [row[:] + [b[i]] for i, row in enumerate(A)]
# Create augmented matrix [A|b] in-place for performance
augmented = [A[i] + [b[i]] for i in range(n)]

# Forward elimination
for i in range(n):
# Find pivot
# Find pivot (maximum in this column)
max_idx = i
max_val = abs(augmented[i][i])
for j in range(i + 1, n):
if abs(augmented[j][i]) > abs(augmented[max_idx][i]):
v = abs(augmented[j][i])
if v > max_val:
max_val = v
max_idx = j

# Swap rows
augmented[i], augmented[max_idx] = augmented[max_idx], augmented[i]
if max_idx != i:
augmented[i], augmented[max_idx] = augmented[max_idx], augmented[i]

# Eliminate below
piv_row = augmented[i]
piv_val = piv_row[i]

# Eliminate rows below
for j in range(i + 1, n):
factor = augmented[j][i] / augmented[i][i]
row = augmented[j]
a = row[i]
if a == 0:
continue
factor = a / piv_val
# Unroll inner loop for fused multiply-subtract
for k in range(i, n + 1):
augmented[j][k] -= factor * augmented[i][k]
row[k] -= factor * piv_row[k]

# Back substitution
x = [0] * n
x = [0.0] * n
for i in range(n - 1, -1, -1):
x[i] = augmented[i][n]
val = augmented[i][n]
ai = augmented[i]
for j in range(i + 1, n):
x[i] -= augmented[i][j] * x[j]
x[i] /= augmented[i][i]
val -= ai[j] * x[j]
x[i] = val / ai[i]

return x

Expand Down