Skip to content

Commit 912c8cb

Browse files
committed
refactor functions in a more functional way
1 parent 597b10b commit 912c8cb

File tree

2 files changed

+12
-21
lines changed

2 files changed

+12
-21
lines changed

src/main/scala/matrix.scala

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,11 @@ object Matrix:
130130
identity
131131

132132
def colminmax(a: Mat, dirs: Array[Direction]): Vec =
133-
val n = a(0).length
134-
val result = Array.fill(n)(0.0)
135-
for j <- 0 until n do
136-
val col = getcolat(a, j)
137-
val currentdir = dirs(j)
138-
result(j) = currentdir match
139-
case Direction.Minimize => col.min
140-
case Direction.Maximize => col.max
141-
result
133+
Array.tabulate(a(0).length){ (colindex) =>
134+
dirs(colindex) match
135+
case Direction.Minimize => Matrix.getcolat(a, colindex).min
136+
case Direction.Maximize => Matrix.getcolat(a, colindex).max
137+
}
142138

143139
def inversedirections(dirs: Array[Direction]): Array[Direction] =
144140
dirs.map {
@@ -160,11 +156,9 @@ object Matrix:
160156
val n = a.length
161157
val m = b(0).length
162158
val p = b.length
163-
val result = Array.fill(n, m)(0.0)
164-
for i <- 0 until n do
165-
for j <- 0 until m do
166-
for k <- 0 until p do result(i)(j) += a(i)(k) * b(k)(j)
167-
result
159+
Array.tabulate(n, m) { (i, j) =>
160+
(0 until p).map(k => a(i)(k) * b(k)(j)).sum
161+
}
168162

169163
def mul(a: Mat, b: Vec): Vec = mul(a, makeRowMatrix(b)).flatten
170164

@@ -191,11 +185,9 @@ object Matrix:
191185

192186
def replicate(a: Mat): Mat =
193187
val (rows, cols) = Matrix.size(a)
194-
val result = Matrix.zeros(rows, cols)
195-
for i <- 0 until rows do
196-
for j <- 0 until cols do
197-
result(i)(j) = a(i)(j)
198-
result
188+
Array.tabulate(rows, cols) { (i, j) =>
189+
a(i)(j)
190+
}
199191

200192

201193

src/main/scala/statistics.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ object Statistics:
2828
)) / (a.length - 1.0)
2929

3030
def correlation(a: Mat): Mat =
31-
val n = a.length
32-
val m = a(0).length
31+
val (n, m) = Matrix.size(a)
3332
Array.tabulate(m, m)((i, j) => correlation(Matrix.getcolat(a, i), Matrix.getcolat(a, j)))
3433

3534
def euclideanDistance(a: Vec, b: Vec): Double =

0 commit comments

Comments
 (0)