Skip to content

Commit 635699b

Browse files
committed
add weightise function
1 parent dfd6a1d commit 635699b

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

src/main/scala/matrix.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,5 @@ object Matrix:
132132
case Direction.Maximize => Direction.Minimize
133133
}
134134

135+
def weightise(a: Mat, w: Vec): Mat =
136+
a.transpose.zip(w).map((row, weight) => row.map(value => value * weight)).transpose

src/test/scala/testmatrix.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,5 +352,21 @@ class TestMatrix extends munit.FunSuite {
352352
Direction.Maximize)
353353
direction.zip(expected).foreach{case (a, b) => A.assertNotEquals(a, b)}
354354
}
355+
test("Weightise"){
356+
val mat = Array(
357+
Array(1.0, 5.0, 6.0, 10.0, 10.0),
358+
Array(-1.0, 10.0, 9.0, 11.0, 11.0),
359+
Array(9.0, 17.0, 12.0, 12.0, 12.0)
360+
)
361+
val weights = Array(0.1, 0.2, 0.3, 0.3, 0.1)
362+
val weighted = Matrix.weightise(mat, weights)
363+
val expected = Array(
364+
Array( 0.1, 1.0, 1.8, 3.0, 1.0),
365+
Array(-0.1, 2.0, 2.7, 3.3, 1.1),
366+
Array( 0.9, 3.4, 3.6, 3.6, 1.2)
367+
)
368+
A.assert(Matrix.elementwise_equal(weighted, expected, 1e-3))
369+
}
355370

356371
}
372+

0 commit comments

Comments
 (0)