Skip to content

Commit 7e6c25e

Browse files
authored
Merge pull request #112 from Atry/fix-broadcast
Fix broadcast method
2 parents 83901f4 + 2e77368 commit 7e6c25e

File tree

3 files changed

+64
-2
lines changed

3 files changed

+64
-2
lines changed

OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ trait OpenCLKernelBuilder extends AllExpressions {
298298
def extract: Element = {
299299
val numberOfRows = originalShape.length
300300
val numberOfColumns = matrix.length / numberOfRows
301-
if (matrix.length % numberOfRows != 0) {
301+
if (matrix.length != numberOfRows * numberOfColumns) {
302302
throw new IllegalStateException()
303303
}
304304

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ trait Tensors extends OpenCL {
711711
if (i < length) {
712712
shape(i) match {
713713
case di if di == newShape(i) =>
714-
matrix1(i * (length + 1) + i) = 1.0
714+
matrix1(i * (newLength + 1) + i) = 1.0
715715
case 1 =>
716716
case _ =>
717717
throw new IllegalArgumentException(

Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,66 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
293293
.run
294294
.toScalaFuture
295295

296+
"matrix multiplication" in doTensors
297+
.map { tensors =>
298+
import tensors._
299+
300+
def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = {
301+
val Array(i, j) = matrix1.shape
302+
val Array(`j`, k) = matrix2.shape
303+
val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k))
304+
305+
product.unzip(1).reduce(_ + _)
306+
307+
}
308+
309+
val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f)))
310+
val matrix2 = Tensor(
311+
Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f)))
312+
313+
matrixMultiply(matrix1, matrix2).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]")
314+
315+
}
316+
.run
317+
.toScalaFuture
318+
319+
"broadcast" in doTensors
320+
.map { tensors =>
321+
import tensors._
322+
323+
val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f)))
324+
matrix1.broadcast(Array(2, 3, 4)).toString should be(
325+
"[[[1.0,1.0,1.0,1.0],[2.0,2.0,2.0,2.0],[3.0,3.0,3.0,3.0]],[[4.0,4.0,4.0,4.0],[5.0,5.0,5.0,5.0],[6.0,6.0,6.0,6.0]]]")
326+
}
327+
.run
328+
.toScalaFuture
329+
330+
"unrolled matrix multiplication" in doTensors
331+
.map { tensors =>
332+
import tensors._
333+
334+
def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = {
335+
336+
val columns1 = matrix1.unzip(1)
337+
338+
Tensor.zip(matrix2.unzip(1).map { column2: Tensor =>
339+
(columns1 zip column2.unzip(0))
340+
.map {
341+
case (l: Tensor, r: Tensor) =>
342+
l * r.broadcast(l.shape)
343+
}
344+
.reduce[Tensor](_ + _)
345+
})
346+
}
347+
348+
matrixMultiply(
349+
Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))),
350+
Tensor(
351+
Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f)))
352+
).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]")
353+
354+
}
355+
.run
356+
.toScalaFuture
357+
296358
}

0 commit comments

Comments
 (0)