@@ -293,6 +293,29 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
293
293
.run
294
294
.toScalaFuture
295
295
296
+ " matrix multiplication" in doTensors
297
+ .map { tensors =>
298
+ import tensors ._
299
+
300
+ def matrixMultiply (matrix1 : Tensor , matrix2 : Tensor ): Tensor = {
301
+ val Array (i, j) = matrix1.shape
302
+ val Array (`j`, k) = matrix2.shape
303
+ val product = matrix1.broadcast(Array (i, j, k)) * matrix2.reshape(Array (1 , j, k)).broadcast(Array (i, j, k))
304
+
305
+ product.unzip(1 ).reduce(_ + _)
306
+
307
+ }
308
+
309
+ val matrix1 = Tensor (Array (Array (1.0f , 2.0f , 3.0f ), Array (4.0f , 5.0f , 6.0f )))
310
+ val matrix2 = Tensor (
311
+ Array (Array (7.0f , 8.0f , 9.0f , 10.0f ), Array (11.0f , 12.0f , 13.0f , 14.0f ), Array (15.0f , 16.0f , 17.0f , 18.0f )))
312
+
313
+ matrixMultiply(matrix1, matrix2).toString should be(" [[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]" )
314
+
315
+ }
316
+ .run
317
+ .toScalaFuture
318
+
296
319
" broadcast" in doTensors
297
320
.map { tensors =>
298
321
import tensors ._
@@ -303,4 +326,33 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
303
326
}
304
327
.run
305
328
.toScalaFuture
329
+
330
+ " unrolled matrix multiplication" in doTensors
331
+ .map { tensors =>
332
+ import tensors ._
333
+
334
+ def matrixMultiply (matrix1 : Tensor , matrix2 : Tensor ): Tensor = {
335
+
336
+ val columns1 = matrix1.unzip(1 )
337
+
338
+ Tensor .zip(matrix2.unzip(1 ).map { column2 : Tensor =>
339
+ (columns1 zip column2.unzip(0 ))
340
+ .map {
341
+ case (l : Tensor , r : Tensor ) =>
342
+ l * r.broadcast(l.shape)
343
+ }
344
+ .reduce[Tensor ](_ + _)
345
+ })
346
+ }
347
+
348
+ matrixMultiply(
349
+ Tensor (Array (Array (1.0f , 2.0f , 3.0f ), Array (4.0f , 5.0f , 6.0f ))),
350
+ Tensor (
351
+ Array (Array (7.0f , 8.0f , 9.0f , 10.0f ), Array (11.0f , 12.0f , 13.0f , 14.0f ), Array (15.0f , 16.0f , 17.0f , 18.0f )))
352
+ ).toString should be(" [[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]" )
353
+
354
+ }
355
+ .run
356
+ .toScalaFuture
357
+
306
358
}
0 commit comments