Skip to content

Commit 2e77368

Browse files
committed
Add tests for matrix multiplication
1 parent 2b6f87b commit 2e77368

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,29 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
293293
.run
294294
.toScalaFuture
295295

296+
"matrix multiplication" in doTensors
297+
.map { tensors =>
298+
import tensors._
299+
300+
def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = {
301+
val Array(i, j) = matrix1.shape
302+
val Array(`j`, k) = matrix2.shape
303+
val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k))
304+
305+
product.unzip(1).reduce(_ + _)
306+
307+
}
308+
309+
val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f)))
310+
val matrix2 = Tensor(
311+
Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f)))
312+
313+
matrixMultiply(matrix1, matrix2).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]")
314+
315+
}
316+
.run
317+
.toScalaFuture
318+
296319
"broadcast" in doTensors
297320
.map { tensors =>
298321
import tensors._
@@ -303,4 +326,33 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
303326
}
304327
.run
305328
.toScalaFuture
329+
330+
"unrolled matrix multiplication" in doTensors
331+
.map { tensors =>
332+
import tensors._
333+
334+
def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = {
335+
336+
val columns1 = matrix1.unzip(1)
337+
338+
Tensor.zip(matrix2.unzip(1).map { column2: Tensor =>
339+
(columns1 zip column2.unzip(0))
340+
.map {
341+
case (l: Tensor, r: Tensor) =>
342+
l * r.broadcast(l.shape)
343+
}
344+
.reduce[Tensor](_ + _)
345+
})
346+
}
347+
348+
matrixMultiply(
349+
Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))),
350+
Tensor(
351+
Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f)))
352+
).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]")
353+
354+
}
355+
.run
356+
.toScalaFuture
357+
306358
}

0 commit comments

Comments
 (0)