Skip to content

Commit 3a39309

Browse files
authored
Merge pull request #129 from Atry/nonInline
Move `cache` and `doCache` method to `BufferedTensor` and rename BufferedTensor to NonInlineTensor
2 parents aef2f6c + 3024524 commit 3a39309

File tree

2 files changed

+99
-98
lines changed

2 files changed

+99
-98
lines changed

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 94 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -422,12 +422,12 @@ trait Tensors extends OpenCL {
422422
}
423423

424424
def apply[A](elements: A, padding: Float = 0.0f)(
425-
implicit tensorBuilder: TensorBuilder.Aux[A, Float]): BufferedTensor = {
425+
implicit tensorBuilder: TensorBuilder.Aux[A, Float]): NonInlineTensor = {
426426
val padding0 = padding
427427
new {
428428
val shape: Array[Int] = tensorBuilder.shape(elements).toArray
429429
val padding: Float = padding0
430-
} with BufferedTensor {
430+
} with NonInlineTensor {
431431
private[compute] val doBuffer = {
432432
Do(TryT(ResourceT(UnitContinuation.delay {
433433
val data = tensorBuilder.flatten(elements).toArray
@@ -454,13 +454,13 @@ trait Tensors extends OpenCL {
454454
} with InlineTensor
455455
}
456456

457-
def random(shape: Array[Int], seed: Int = Random.nextInt(), padding: Float = 0.0f): BufferedTensor = {
457+
def random(shape: Array[Int], seed: Int = Random.nextInt(), padding: Float = 0.0f): NonInlineTensor = {
458458
val shape0 = shape
459459
val padding0 = padding
460460
new {
461461
val padding = padding0
462462
val shape = shape0
463-
} with BufferedTensor {
463+
} with NonInlineTensor {
464464
private[compute] val doBuffer: Do[PendingBuffer[Float]] = {
465465
val size = shape.product
466466
allocateBuffer[Float](size).flatMap { buffer =>
@@ -475,13 +475,13 @@ trait Tensors extends OpenCL {
475475
}
476476

477477
/** Generate random numbers in normal distribution. */
478-
def randomNormal(shape: Array[Int], seed: Int = Random.nextInt(), padding: Float = 0.0f): BufferedTensor = {
478+
def randomNormal(shape: Array[Int], seed: Int = Random.nextInt(), padding: Float = 0.0f): NonInlineTensor = {
479479
val shape0 = shape
480480
val padding0 = padding
481481
new {
482482
val padding = padding0
483483
val shape = shape0
484-
} with BufferedTensor {
484+
} with NonInlineTensor {
485485
private[compute] val doBuffer: Do[PendingBuffer[Float]] = {
486486
val size = shape.product
487487
val paddingSize = if (size % 2 == 1) {
@@ -541,7 +541,7 @@ trait Tensors extends OpenCL {
541541
}
542542
}
543543

544-
def join(tensors0: Seq[Tensor]): BufferedTensor = {
544+
def join(tensors0: Seq[Tensor]): NonInlineTensor = {
545545
def force[A](seq: Seq[A]) = {
546546
seq match {
547547
case seqView: SeqView[A, _] @unchecked =>
@@ -556,7 +556,7 @@ trait Tensors extends OpenCL {
556556
new {
557557
val shape = headTensor.shape :+ tensors.length
558558
val padding: Float = headTensor.padding
559-
} with BufferedTensor {
559+
} with NonInlineTensor {
560560
private[compute] val doBuffer = {
561561
val elements = tensors.map(_.closure)
562562
enqueueClosure(trees.tuple.join(elements: _*), headTensor.shape).asInstanceOf[Do[PendingBuffer[Float]]]
@@ -591,9 +591,9 @@ trait Tensors extends OpenCL {
591591
/**
592592
* @group delayed
593593
*/
594-
def notInline: BufferedTensor
594+
def nonInline: NonInlineTensor
595595

596-
private def reduce(programs: MonoidPrograms): BufferedTensor = {
596+
private def reduce(programs: MonoidPrograms): NonInlineTensor = {
597597
new {
598598
val padding: Float = thisTensor.padding
599599

@@ -683,7 +683,7 @@ trait Tensors extends OpenCL {
683683
}
684684
}
685685
}.shared
686-
} with BufferedTensor {
686+
} with NonInlineTensor {
687687
def shape: Array[Int] = Tensors.ScalarShape
688688
}
689689
}
@@ -771,15 +771,15 @@ trait Tensors extends OpenCL {
771771
/**
772772
* @group delayed
773773
*/
774-
def reshape(newShape: Array[Int]): BufferedTensor = {
774+
def reshape(newShape: Array[Int]): NonInlineTensor = {
775775
if (newShape.product != shape.product) {
776776
throw new IllegalArgumentException
777777
}
778778
new {
779779
val padding: Float = thisTensor.padding
780780
val shape: Array[Int] = newShape
781781
private[compute] val doBuffer: Do[PendingBuffer[Float]] = thisTensor.doBuffer
782-
} with BufferedTensor
782+
} with NonInlineTensor
783783
}
784784

785785
/**
@@ -993,82 +993,6 @@ trait Tensors extends OpenCL {
993993

994994
private[compute] def doBuffer: Do[PendingBuffer[closure.JvmValue]]
995995

996-
/** Allocates a device-side cache that is managed by the [[https://github.com/ThoughtWorksInc/RAII.scala RAII.scala]] library.
997-
*
998-
* @note This method is similar to [[cache]],
999-
* except the life cycle of the cache can be automatically managed.
1000-
*
1001-
* @group slow
1002-
*/
1003-
def doCache: Do[this.type] = doBuffer.map(Function.const(this))
1004-
1005-
/** Allocates device-side cache for this [[Tensor]], and returns a [[java.lang.AutoCloseable]] to release the cache.
1006-
*
1007-
* @note This method can be called multiple times on one [[Tensor]],
1008-
* only one copy of cache will be allocated,
1009-
* which will be released only after all the [[java.lang.AutoCloseable]]s returned by the [[cache]] method are closed.
1010-
*
1011-
* @group slow
1012-
*/
1013-
def cache: AutoCloseable = {
1014-
sealed trait State
1015-
case object Openning extends State
1016-
case object EarlyClosed extends State
1017-
case object Closed extends State
1018-
final case class Open(release: UnitContinuation[Unit]) extends State
1019-
1020-
val state = new AtomicReference[State](Openning) with AutoCloseable {
1021-
@tailrec
1022-
final def close(): Unit = {
1023-
get match {
1024-
case Openning =>
1025-
if (compareAndSet(Openning, EarlyClosed)) {
1026-
// Success
1027-
} else {
1028-
close()
1029-
}
1030-
case oldState @ Open(release) =>
1031-
if (compareAndSet(oldState, Closed)) {
1032-
release.safeOnComplete { _: Unit =>
1033-
Trampoline.done(())
1034-
}.run
1035-
} else {
1036-
close()
1037-
}
1038-
case EarlyClosed | Closed =>
1039-
throw new IllegalStateException("The resources associated to this tensor has been released.")
1040-
}
1041-
}
1042-
}
1043-
1044-
doBuffer.safeOnComplete { resource =>
1045-
@tailrec
1046-
def retry(): Trampoline[Unit] = {
1047-
state.get() match {
1048-
case EarlyClosed =>
1049-
if (state.compareAndSet(EarlyClosed, Closed)) {
1050-
resource.release.safeOnComplete { _: Unit =>
1051-
Trampoline.done(())
1052-
}
1053-
} else {
1054-
retry()
1055-
}
1056-
case Openning =>
1057-
if (state.compareAndSet(Openning, Open(resource.release))) {
1058-
Trampoline.done(())
1059-
} else {
1060-
retry()
1061-
}
1062-
case _: Open | Closed =>
1063-
throw new IllegalStateException()
1064-
}
1065-
}
1066-
retry()
1067-
}.run
1068-
1069-
state
1070-
}
1071-
1072996
/**
1073997
* @group slow
1074998
*/
@@ -1225,12 +1149,12 @@ trait Tensors extends OpenCL {
12251149
enqueueClosure(closure, shape)
12261150
}.shared
12271151

1228-
def notInline: BufferedTensor =
1152+
def nonInline: NonInlineTensor =
12291153
new {
12301154
val padding: Float = thisInlineTensor.padding
12311155
private[compute] val doBuffer: Do[PendingBuffer[Float]] = thisInlineTensor.doBuffer
12321156
val shape: Array[Int] = thisInlineTensor.shape
1233-
} with BufferedTensor
1157+
} with NonInlineTensor
12341158
}
12351159

12361160
trait TransformedTensor extends InlineTensor {
@@ -1250,14 +1174,91 @@ trait Tensors extends OpenCL {
12501174

12511175
}
12521176

1253-
trait BufferedTensor extends Tensor {
1177+
trait NonInlineTensor extends Tensor {
12541178

1255-
def notInline: BufferedTensor = this
1179+
def nonInline: this.type = this
12561180

12571181
@transient
12581182
protected lazy val closure = {
12591183
arrayTerm.extract
12601184
}
1185+
1186+
/** Allocates a device-side cache that is managed by the [[https://github.com/ThoughtWorksInc/RAII.scala RAII.scala]] library.
1187+
*
1188+
* @note This method is similar to [[cache]],
1189+
* except the life cycle of the cache can be automatically managed.
1190+
*
1191+
* @group slow
1192+
*/
1193+
def doCache: Do[this.type] = doBuffer.map(Function.const(this))
1194+
1195+
/** Allocates device-side cache for this [[Tensor]], and returns a [[java.lang.AutoCloseable]] to release the cache.
1196+
*
1197+
* @note This method can be called multiple times on one [[Tensor]].
1198+
* Only one copy of cache will be allocated,
1199+
* which will be finally released until all [[java.lang.AutoCloseable]] returned by [[cache]] method are closed.
1200+
*
1201+
* @group slow
1202+
*/
1203+
def cache: AutoCloseable = {
1204+
sealed trait State
1205+
case object Openning extends State
1206+
case object EarlyClosed extends State
1207+
case object Closed extends State
1208+
final case class Open(release: UnitContinuation[Unit]) extends State
1209+
1210+
val state = new AtomicReference[State](Openning) with AutoCloseable {
1211+
@tailrec
1212+
final def close(): Unit = {
1213+
get match {
1214+
case Openning =>
1215+
if (compareAndSet(Openning, EarlyClosed)) {
1216+
// Success
1217+
} else {
1218+
close()
1219+
}
1220+
case oldState @ Open(release) =>
1221+
if (compareAndSet(oldState, Closed)) {
1222+
release.safeOnComplete { _: Unit =>
1223+
Trampoline.done(())
1224+
}.run
1225+
} else {
1226+
close()
1227+
}
1228+
case EarlyClosed | Closed =>
1229+
throw new IllegalStateException("The resources associated to this tensor has been released.")
1230+
}
1231+
}
1232+
}
1233+
1234+
doBuffer.safeOnComplete { resource =>
1235+
@tailrec
1236+
def retry(): Trampoline[Unit] = {
1237+
state.get() match {
1238+
case EarlyClosed =>
1239+
if (state.compareAndSet(EarlyClosed, Closed)) {
1240+
resource.release.safeOnComplete { _: Unit =>
1241+
Trampoline.done(())
1242+
}
1243+
} else {
1244+
retry()
1245+
}
1246+
case Openning =>
1247+
if (state.compareAndSet(Openning, Open(resource.release))) {
1248+
Trampoline.done(())
1249+
} else {
1250+
retry()
1251+
}
1252+
case _: Open | Closed =>
1253+
throw new IllegalStateException()
1254+
}
1255+
}
1256+
retry()
1257+
}.run
1258+
1259+
state
1260+
}
1261+
12611262
}
12621263

12631264
}

benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ object benchmarks {
105105
}
106106

107107
def doBenchmark(): Do[() => Array[Float]] = {
108-
val weight: BufferedTensor = Tensor.randomNormal(Array(inputDepth, outputDepth))
108+
val weight: NonInlineTensor = Tensor.randomNormal(Array(inputDepth, outputDepth))
109109

110-
val input: BufferedTensor = Tensor.randomNormal(Array(batchSize, inputDepth))
110+
val input: NonInlineTensor = Tensor.randomNormal(Array(batchSize, inputDepth))
111111

112112
weight.doCache.flatMap { weight =>
113113
input.doCache.map { input =>
@@ -233,7 +233,7 @@ object benchmarks {
233233
trait Benchmarks extends BenchmarkTensors {
234234

235235
def doBenchmark(): Do[() => Float] = {
236-
val input: BufferedTensor = Tensor.randomNormal(Array.fill(numberOfDimensions)(size))
236+
val input: NonInlineTensor = Tensor.randomNormal(Array.fill(numberOfDimensions)(size))
237237

238238
input.doCache.map { input =>
239239
{ () =>
@@ -365,7 +365,7 @@ object benchmarks {
365365

366366
trait Benchmarks extends BenchmarkTensors {
367367

368-
final case class ConvolutionalLayer(weight: BufferedTensor, bias: BufferedTensor) {
368+
final case class ConvolutionalLayer(weight: NonInlineTensor, bias: NonInlineTensor) {
369369
def forward(input: Tensor): Tensor = {
370370
convolute(input, weight, bias)
371371
}
@@ -467,7 +467,7 @@ object benchmarks {
467467
}
468468

469469
def doBenchmark(): Do[() => Array[Float]] = {
470-
val input: BufferedTensor = Tensor.randomNormal(Array(batchSize, imageHeight, imageWidth, depth))
470+
val input: NonInlineTensor = Tensor.randomNormal(Array(batchSize, imageHeight, imageWidth, depth))
471471
val layers = (for (i <- (0 until numberOfLayers).view) yield {
472472
ConvolutionalLayer(weight = Tensor.randomNormal(Array(kernelHeight, kernelWidth, depth, depth)),
473473
bias = Tensor.randomNormal(Array(depth)))

0 commit comments

Comments
 (0)