Skip to content

Commit 592da13

Browse files
authored
Merge pull request #133 from Atry/split-type-annotation
Split should return TransformedTensors
2 parents ce885bc + 25ad61a commit 592da13

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -937,10 +937,10 @@ trait Tensors extends OpenCL {
937937
/**
938938
* @group delayed
939939
*/
940-
def split(dimension: Int): IndexedSeq[Tensor] = {
940+
def split(dimension: Int): IndexedSeq[TransformedTensor] = {
941941
// TODO: override map/reduce to produce less OpenCL C code
942942
val newShape = shape.patch(dimension, Nil, 1)
943-
final class TensorSeq extends IndexedSeq[Tensor] {
943+
final class TensorSeq extends IndexedSeq[TransformedTensor] {
944944

945945
override def stringPrefix = "TensorSeq"
946946

Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
307307
val Array(`j`, k) = matrix2.shape
308308
val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k))
309309

310-
product.split(1).reduce(_ + _)
310+
product.split(1).reduce[Tensor](_ + _)
311311

312312
}
313313

benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ object benchmarks {
100100
// unroll only j
101101
val Array(`j`, k) = matrix2.shape
102102
val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k))
103-
product.split(1).reduce(_ + _)
103+
product.split(1).reduce[Tensor](_ + _)
104104
}
105105
}
106106

0 commit comments

Comments
 (0)