Skip to content

Commit 2b6f87b

Browse files
committed
Fix broadcast method
1 parent 83901f4 commit 2b6f87b

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ trait OpenCLKernelBuilder extends AllExpressions {
298298
def extract: Element = {
299299
val numberOfRows = originalShape.length
300300
val numberOfColumns = matrix.length / numberOfRows
301-
if (matrix.length % numberOfRows != 0) {
301+
if (matrix.length != numberOfRows * numberOfColumns) {
302302
throw new IllegalStateException()
303303
}
304304

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ trait Tensors extends OpenCL {
711711
if (i < length) {
712712
shape(i) match {
713713
case di if di == newShape(i) =>
714-
matrix1(i * (length + 1) + i) = 1.0
714+
matrix1(i * (newLength + 1) + i) = 1.0
715715
case 1 =>
716716
case _ =>
717717
throw new IllegalArgumentException(

Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,14 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
293293
.run
294294
.toScalaFuture
295295

296+
"broadcast" in doTensors
297+
.map { tensors =>
298+
import tensors._
299+
300+
val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f)))
301+
matrix1.broadcast(Array(2, 3, 4)).toString should be(
302+
"[[[1.0,1.0,1.0,1.0],[2.0,2.0,2.0,2.0],[3.0,3.0,3.0,3.0]],[[4.0,4.0,4.0,4.0],[5.0,5.0,5.0,5.0],[6.0,6.0,6.0,6.0]]]")
303+
}
304+
.run
305+
.toScalaFuture
296306
}

0 commit comments

Comments
 (0)