Skip to content

Commit 6e70d28

Browse files
authored
Merge pull request #127 from Atry/split-join
Rename zip and unzip to join and split
2 parents 0c8b03e + 20e7e00 commit 6e70d28

File tree

7 files changed

+40
-40
lines changed

7 files changed

+40
-40
lines changed

Expressions/src/main/scala/com/thoughtworks/compute/Expressions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ object Expressions {
194194
}
195195

196196
protected trait TupleTermApi extends ValueTermApi with TupleExpressionApi { this: TupleTerm =>
197-
def unzip: Seq[Element]
197+
def split: Seq[Element]
198198
}
199199

200200
/** @template */
@@ -213,7 +213,7 @@ object Expressions {
213213

214214
def parameter(id: Any, element: ValueType, length: Int): TupleTerm { type Element = element.ThisTerm }
215215

216-
def zip[Element0 <: ValueTerm](elements: Element0*): TupleTerm { type Element = Element0 }
216+
def join[Element0 <: ValueTerm](elements: Element0*): TupleTerm { type Element = Element0 }
217217

218218
}
219219

OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ trait OpenCLKernelBuilder extends AllExpressions {
509509
type FloatTerm <: (ValueTerm with Any) with ClFloatTerm
510510

511511
trait ClTupleTerm extends TupleTermApi with ClValueTerm { thisTupleTerm: TupleTerm =>
512-
def unzip: Seq[Element] = new IndexedSeq[Element] {
512+
def split: Seq[Element] = new IndexedSeq[Element] {
513513

514514
def length: Int = thisTupleTerm.length
515515

@@ -566,7 +566,7 @@ trait OpenCLKernelBuilder extends AllExpressions {
566566
tupleTermFactory[element.ThisTerm].newInstance(element, length, termCode)
567567
}
568568

569-
def zip[Element0 <: ValueTerm](elements: Element0*): TupleTerm {
569+
def join[Element0 <: ValueTerm](elements: Element0*): TupleTerm {
570570
type Element = Element0
571571
} = {
572572
val elementType = elements.head.valueType

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ trait Tensors extends OpenCL {
541541
}
542542
}
543543

544-
def zip(tensors0: Seq[Tensor]): BufferedTensor = {
544+
def join(tensors0: Seq[Tensor]): BufferedTensor = {
545545
def force[A](seq: Seq[A]) = {
546546
seq match {
547547
case seqView: SeqView[A, _] @unchecked =>
@@ -559,7 +559,7 @@ trait Tensors extends OpenCL {
559559
} with BufferedTensor {
560560
private[compute] val doBuffer = {
561561
val elements = tensors.map(_.closure)
562-
enqueueClosure(trees.tuple.zip(elements: _*), headTensor.shape).asInstanceOf[Do[PendingBuffer[Float]]]
562+
enqueueClosure(trees.tuple.join(elements: _*), headTensor.shape).asInstanceOf[Do[PendingBuffer[Float]]]
563563
}.shared
564564
}
565565
}
@@ -937,7 +937,7 @@ trait Tensors extends OpenCL {
937937
/**
938938
* @group delayed
939939
*/
940-
def unzip(dimension: Int): IndexedSeq[Tensor] = {
940+
def split(dimension: Int): IndexedSeq[Tensor] = {
941941
// TODO: override map/reduce to produce less OpenCL C code
942942
val newShape = shape.patch(dimension, Nil, 1)
943943
new IndexedSeq[Tensor] {

Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
116116
doTensors.map { tensors =>
117117
import tensors._
118118
val tensor = Tensor(Seq(Seq(Seq(Seq(1.0f, 5.0f)))))
119-
tensor.unzip(dimension = 3).map(_.toString) should be(Seq("[[[1.0]]]", "[[[5.0]]]"))
119+
tensor.split(dimension = 3).map(_.toString) should be(Seq("[[[1.0]]]", "[[[5.0]]]"))
120120
}
121121
}.run.toScalaFuture
122122

@@ -140,7 +140,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
140140
"convolution" in {
141141
doTensors.flatMap { tensors =>
142142
import tensors.Tensor
143-
import tensors.Tensor.zip
143+
import tensors.Tensor.join
144144
def convolute(input: Tensor /* batchSize × height × width × depth */,
145145
weight: Tensor /* kernelHeight × kernelWidth × depth × filterSize */,
146146
bias: Tensor /* filterSize */ ): Tensor = {
@@ -150,20 +150,20 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
150150
case Array(kernelHeight, kernelWidth, `depth`, filterSize) =>
151151
bias.shape match {
152152
case Array(`filterSize`) =>
153-
val inputSeq: Seq[Tensor /* batchSize × height × width */ ] = input.unzip(dimension = 3)
153+
val inputSeq: Seq[Tensor /* batchSize × height × width */ ] = input.split(dimension = 3)
154154

155155
inputSeq.size should be(depth)
156156
inputSeq.head.shape should be(Array(batchSize, height, width))
157157

158158
val weightSeq: Seq[Seq[Seq[Seq[Tensor]]]] /* filterSize × kernelHeight × kernelWidth × depth */ =
159-
weight.unzip(dimension = 3).map { khKwD =>
159+
weight.split(dimension = 3).map { khKwD =>
160160
khKwD.shape should be(Array(kernelHeight, kernelWidth, depth))
161161

162-
khKwD.unzip(dimension = 0).map { kwD =>
162+
khKwD.split(dimension = 0).map { kwD =>
163163
kwD.shape should be(Array(kernelWidth, depth))
164-
kwD.unzip(dimension = 0).map { d =>
164+
kwD.split(dimension = 0).map { d =>
165165
d.shape should be(Array(depth))
166-
d.unzip(dimension = 0)
166+
d.split(dimension = 0)
167167
}
168168
}
169169
}
@@ -173,7 +173,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
173173
weightSeq.head.head.length should be(kernelWidth)
174174
weightSeq.head.head.head.length should be(depth)
175175

176-
val biasSeq: Seq[Tensor] /* filterSize */ = bias.unzip(dimension = 0)
176+
val biasSeq: Seq[Tensor] /* filterSize */ = bias.split(dimension = 0)
177177

178178
val outputChannels: Seq[Tensor] = weightSeq.view
179179
.zip(biasSeq)
@@ -197,7 +197,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
197197
biasPerFilter.broadcast(Array(batchSize, height, width)) + summands.reduce(_ + _)
198198
}
199199

200-
zip(outputChannels)
200+
join(outputChannels)
201201
case _ =>
202202
throw new IllegalArgumentException
203203
}
@@ -307,7 +307,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
307307
val Array(`j`, k) = matrix2.shape
308308
val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k))
309309

310-
product.unzip(1).reduce(_ + _)
310+
product.split(1).reduce(_ + _)
311311

312312
}
313313

@@ -338,10 +338,10 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
338338

339339
def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = {
340340

341-
val columns1 = matrix1.unzip(1)
341+
val columns1 = matrix1.split(1)
342342

343-
Tensor.zip(matrix2.unzip(1).map { column2: Tensor =>
344-
(columns1 zip column2.unzip(0))
343+
Tensor.join(matrix2.split(1).map { column2: Tensor =>
344+
(columns1 zip column2.split(0))
345345
.map {
346346
case (l: Tensor, r: Tensor) =>
347347
l * r.broadcast(l.shape)

Trees/src/main/scala/com/thoughtworks/compute/Trees.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,7 @@ object Trees {
958958
protected def erasedExport(foreignCategory: Category, context: ExportContext): Category#Term = {
959959
def foreignTerm = {
960960
val foreignElements = elementTrees.map(_.export(foreignCategory, context))
961-
foreignCategory.tuple.zip(foreignElements: _*)
961+
foreignCategory.tuple.join(foreignElements: _*)
962962
}
963963
context.asScala.getOrElseUpdate(this, foreignTerm)
964964
}
@@ -984,7 +984,7 @@ object Trees {
984984

985985
protected def erasedExport(foreignCategory: Category, context: ExportContext) = {
986986
context.asScala
987-
.getOrElseUpdate(this, tuple.export(foreignCategory, context).unzip.apply(index))
987+
.getOrElseUpdate(this, tuple.export(foreignCategory, context).split.apply(index))
988988

989989
}
990990

@@ -1004,7 +1004,7 @@ object Trees {
10041004

10051005
val length: Int
10061006

1007-
def unzip: Seq[Element] = {
1007+
def split: Seq[Element] = {
10081008
new IndexedSeq[Element] {
10091009
def length = thisTuple.length
10101010
def apply(index: Int): Element = {
@@ -1037,7 +1037,7 @@ object Trees {
10371037
tupleTermFactory[element.ThisTerm].newInstance(element, length, parameterTree)
10381038
}
10391039

1040-
def zip[Element0 <: ValueTerm](elements: Element0*): TupleTerm { type Element = Element0 } = {
1040+
def join[Element0 <: ValueTerm](elements: Element0*): TupleTerm { type Element = Element0 } = {
10411041
val elementTrees = elements.map(_.tree.asInstanceOf[Tree { type TermIn[C <: Category] = Element0#TermIn[C] }])
10421042
val zipTree = Concatenate[Element0](elementTrees)
10431043
tupleTermFactory[Element0].newInstance(elements.head.valueType, elements.length, zipTree)

Trees/src/test/scala/com/thoughtworks/compute/TreesSpec.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ final class TreesSpec extends FreeSpec with Matchers {
5555
"tuple.zip" - {
5656
"reflexive" in {
5757
reflexive(
58-
trees.tuple.zip(
58+
trees.tuple.join(
5959
trees.float.parameter("my_id"),
6060
trees.float.literal(2.0f),
6161
trees.float.literal(3.0f)
@@ -65,12 +65,12 @@ final class TreesSpec extends FreeSpec with Matchers {
6565

6666
"sameStructuralDifferentParameterName" in {
6767
sameStructuralDifferentParameterName(
68-
trees.tuple.zip(
68+
trees.tuple.join(
6969
trees.float.parameter("my_id1"),
7070
trees.float.parameter("my_id2"),
7171
trees.float.literal(0.0f)
7272
),
73-
trees.tuple.zip(
73+
trees.tuple.join(
7474
trees.float.parameter("my_id2"),
7575
trees.float.parameter("my_id3"),
7676
trees.float.literal(0.0f)
@@ -80,11 +80,11 @@ final class TreesSpec extends FreeSpec with Matchers {
8080

8181
"differentStructural" in {
8282
differentStructural(
83-
trees.tuple.zip(
83+
trees.tuple.join(
8484
trees.float.literal(1.0f),
8585
trees.float.literal(0.0f)
8686
),
87-
trees.tuple.zip(
87+
trees.tuple.join(
8888
trees.float.literal(0.0f),
8989
trees.float.literal(1.0f)
9090
)

benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ object benchmarks {
8787
val Array(i, j) = matrix1.shape
8888
if (i >= unrollThreshold) {
8989
// unroll j and k
90-
val columns1 = matrix1.unzip(1)
91-
Tensor.zip(matrix2.unzip(1).map { column2: Tensor =>
92-
(columns1 zip column2.unzip(0))
90+
val columns1 = matrix1.split(1)
91+
Tensor.join(matrix2.split(1).map { column2: Tensor =>
92+
(columns1 zip column2.split(0))
9393
.map {
9494
case (l: Tensor, r: Tensor) =>
9595
l * r.broadcast(l.shape)
@@ -100,7 +100,7 @@ object benchmarks {
100100
// unroll only j
101101
val Array(`j`, k) = matrix2.shape
102102
val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k))
103-
product.unzip(1).reduce(_ + _)
103+
product.split(1).reduce(_ + _)
104104
}
105105
}
106106

@@ -380,7 +380,7 @@ object benchmarks {
380380
case Array(kernelHeight, kernelWidth, `depth`, filterSize) =>
381381
bias.shape match {
382382
case Array(`filterSize`) =>
383-
val inputSeq: Seq[Tensor /* batchSize × height × width */ ] = input.unzip(dimension = 3)
383+
val inputSeq: Seq[Tensor /* batchSize × height × width */ ] = input.split(dimension = 3)
384384

385385
if (inputSeq.size != depth) {
386386
throw new IllegalArgumentException
@@ -393,27 +393,27 @@ object benchmarks {
393393
}
394394

395395
val weightSeq: Seq[Seq[Seq[Seq[Tensor]]]] /* filterSize × kernelHeight × kernelWidth × depth */ =
396-
weight.unzip(dimension = 3).map { khKwD =>
396+
weight.split(dimension = 3).map { khKwD =>
397397
khKwD.shape match {
398398
case Array(kernelHeight, kernelWidth, depth) =>
399399
case _ =>
400400
throw new IllegalArgumentException
401401
}
402402

403-
khKwD.unzip(dimension = 0).map { kwD =>
403+
khKwD.split(dimension = 0).map { kwD =>
404404
kwD.shape match {
405405
case Array(kernelWidth, depth) =>
406406
case _ =>
407407
throw new IllegalArgumentException
408408
}
409409

410-
kwD.unzip(dimension = 0).map { d =>
410+
kwD.split(dimension = 0).map { d =>
411411
d.shape match {
412412
case Array(depth) =>
413413
case _ =>
414414
throw new IllegalArgumentException
415415
}
416-
d.unzip(dimension = 0)
416+
d.split(dimension = 0)
417417
}
418418
}
419419
}
@@ -428,7 +428,7 @@ object benchmarks {
428428
throw new IllegalArgumentException
429429
}
430430

431-
val biasSeq: Seq[Tensor] /* filterSize */ = bias.unzip(dimension = 0)
431+
val biasSeq: Seq[Tensor] /* filterSize */ = bias.split(dimension = 0)
432432

433433
val outputChannels: Seq[Tensor] = weightSeq.view
434434
.zip(biasSeq)
@@ -454,7 +454,7 @@ object benchmarks {
454454
biasPerFilter.broadcast(Array(batchSize, height, width)) + summands.reduce(_ + _)
455455
}
456456

457-
Tensor.zip(outputChannels)
457+
Tensor.join(outputChannels)
458458
case _ =>
459459
throw new IllegalArgumentException
460460
}

0 commit comments

Comments
 (0)