Skip to content

Commit 0ff0f49

Browse files
committed
add linear equation solver
1 parent 24ebba6 commit 0ff0f49

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

src/main/scala/matrix.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,9 @@ object Matrix:
146146
for j <- 0 until m do
147147
for k <- 0 until p do result(i)(j) += a(i)(k) * b(k)(j)
148148
result
149+
150+
def mul(a: Mat, b: Vec): Vec = mul(a, makeRowMatrix(b)).flatten
151+
152+
def solve(A: Mat, b: Vec): Vec = mul(inverse(A), b)
153+
154+

src/test/scala/testmatrix.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,4 +412,46 @@ class TestMatrix extends munit.FunSuite {
412412

413413
A.assert(Matrix.elementwise_equal(result, expected, 1e-5))
414414
}
415+
test("Equation solve 1"){
416+
val Amat = Array(Array(2.0, 4.0), Array(6.0, 5.0))
417+
val bvec = Array(12.0, 60.0)
418+
val expected = Array(12.857, -3.428)
419+
val result = Matrix.solve(Amat, bvec)
420+
A.assert(Matrix.elementwise_equal(result, expected, 1e-3))
421+
}
422+
test("Equation solve 2"){
423+
val Amat = Array(Array(2.0, 1.0), Array(6.0, 0.0))
424+
val bvec = Array(12.0, 60.0)
425+
val expected = Array(10.0, -8.0)
426+
val result = Matrix.solve(Amat, bvec)
427+
A.assert(Matrix.elementwise_equal(result, expected, 1e-3))
428+
}
429+
test("Equation solve 3"){
430+
val Amat = Array(Array(2.0, 0.0), Array(6.0, 1.0))
431+
val bvec = Array(12.0, 60.0)
432+
val expected = Array(6.0, 24.0)
433+
val result = Matrix.solve(Amat, bvec)
434+
A.assert(Matrix.elementwise_equal(result, expected, 1e-3))
435+
}
436+
test("Equation solve 4"){
437+
val Amat = Array(Array(4.0, 1.0), Array(5.0, 0.0))
438+
val bvec = Array(12.0, 60.0)
439+
val expected = Array(12.0, -36.0)
440+
val result = Matrix.solve(Amat, bvec)
441+
A.assert(Matrix.elementwise_equal(result, expected, 1e-3))
442+
}
443+
test("Equation solve 3x3"){
444+
val Amat = Array(
445+
Array(2.0, 1.0, 3.0),
446+
Array(1.0, 2.0, 4.0),
447+
Array(2.0, 2.0, 4.0))
448+
449+
val b = Array(13.0, 17.0, 18.0)
450+
451+
val expected = Array(1.0, 2.0, 3.0)
452+
453+
val result = Matrix.solve(Amat, b)
454+
455+
A.assert(Matrix.elementwise_equal(result, expected, 1e-3))
456+
}
415457
}

0 commit comments

Comments
 (0)