@@ -422,12 +422,12 @@ trait Tensors extends OpenCL {
422
422
}
423
423
424
424
def apply [A ](elements : A , padding : Float = 0.0f )(
425
- implicit tensorBuilder : TensorBuilder .Aux [A , Float ]): BufferedTensor = {
425
+ implicit tensorBuilder : TensorBuilder .Aux [A , Float ]): NonInlineTensor = {
426
426
val padding0 = padding
427
427
new {
428
428
val shape : Array [Int ] = tensorBuilder.shape(elements).toArray
429
429
val padding : Float = padding0
430
- } with BufferedTensor {
430
+ } with NonInlineTensor {
431
431
private [compute] val doBuffer = {
432
432
Do (TryT (ResourceT (UnitContinuation .delay {
433
433
val data = tensorBuilder.flatten(elements).toArray
@@ -454,13 +454,13 @@ trait Tensors extends OpenCL {
454
454
} with InlineTensor
455
455
}
456
456
457
- def random (shape : Array [Int ], seed : Int = Random .nextInt(), padding : Float = 0.0f ): BufferedTensor = {
457
+ def random (shape : Array [Int ], seed : Int = Random .nextInt(), padding : Float = 0.0f ): NonInlineTensor = {
458
458
val shape0 = shape
459
459
val padding0 = padding
460
460
new {
461
461
val padding = padding0
462
462
val shape = shape0
463
- } with BufferedTensor {
463
+ } with NonInlineTensor {
464
464
private [compute] val doBuffer : Do [PendingBuffer [Float ]] = {
465
465
val size = shape.product
466
466
allocateBuffer[Float ](size).flatMap { buffer =>
@@ -475,13 +475,13 @@ trait Tensors extends OpenCL {
475
475
}
476
476
477
477
/** Generate random numbers in normal distribution. */
478
- def randomNormal (shape : Array [Int ], seed : Int = Random .nextInt(), padding : Float = 0.0f ): BufferedTensor = {
478
+ def randomNormal (shape : Array [Int ], seed : Int = Random .nextInt(), padding : Float = 0.0f ): NonInlineTensor = {
479
479
val shape0 = shape
480
480
val padding0 = padding
481
481
new {
482
482
val padding = padding0
483
483
val shape = shape0
484
- } with BufferedTensor {
484
+ } with NonInlineTensor {
485
485
private [compute] val doBuffer : Do [PendingBuffer [Float ]] = {
486
486
val size = shape.product
487
487
val paddingSize = if (size % 2 == 1 ) {
@@ -541,7 +541,7 @@ trait Tensors extends OpenCL {
541
541
}
542
542
}
543
543
544
- def join (tensors0 : Seq [Tensor ]): BufferedTensor = {
544
+ def join (tensors0 : Seq [Tensor ]): NonInlineTensor = {
545
545
def force [A ](seq : Seq [A ]) = {
546
546
seq match {
547
547
case seqView : SeqView [A , _] @ unchecked =>
@@ -556,7 +556,7 @@ trait Tensors extends OpenCL {
556
556
new {
557
557
val shape = headTensor.shape :+ tensors.length
558
558
val padding : Float = headTensor.padding
559
- } with BufferedTensor {
559
+ } with NonInlineTensor {
560
560
private [compute] val doBuffer = {
561
561
val elements = tensors.map(_.closure)
562
562
enqueueClosure(trees.tuple.join(elements : _* ), headTensor.shape).asInstanceOf [Do [PendingBuffer [Float ]]]
@@ -591,9 +591,9 @@ trait Tensors extends OpenCL {
591
591
/** Returns this tensor as a [[NonInlineTensor]].
  *
  * Implementations decide whether a new tensor object is created or the receiver itself is returned.
  *
  * @group delayed
  */
def nonInline: NonInlineTensor
595
595
596
- private def reduce (programs : MonoidPrograms ): BufferedTensor = {
596
+ private def reduce (programs : MonoidPrograms ): NonInlineTensor = {
597
597
new {
598
598
val padding : Float = thisTensor.padding
599
599
@@ -683,7 +683,7 @@ trait Tensors extends OpenCL {
683
683
}
684
684
}
685
685
}.shared
686
- } with BufferedTensor {
686
+ } with NonInlineTensor {
687
687
def shape : Array [Int ] = Tensors .ScalarShape
688
688
}
689
689
}
@@ -771,15 +771,15 @@ trait Tensors extends OpenCL {
771
771
/** Returns a tensor that reuses this tensor's buffer under a different shape.
  *
  * The `padding` and `doBuffer` of this tensor are reused unchanged; only `shape` is replaced by `newShape`.
  *
  * @param newShape the shape of the returned tensor; its element count (`product`) must equal
  *                 the element count of this tensor's `shape`
  * @throws java.lang.IllegalArgumentException if `newShape.product` differs from `shape.product`
  * @group delayed
  */
def reshape(newShape: Array[Int]): NonInlineTensor = {
  if (newShape.product != shape.product) {
    // Build the diagnostic only on the failure path.
    val currentShape = shape.mkString("[", ", ", "]")
    val requestedShape = newShape.mkString("[", ", ", "]")
    throw new IllegalArgumentException(
      "Cannot reshape a tensor of shape " + currentShape + " (" + shape.product + " elements) to " +
        requestedShape + " (" + newShape.product + " elements)"
    )
  }
  new {
    val padding: Float = thisTensor.padding
    val shape: Array[Int] = newShape
    private[compute] val doBuffer: Do[PendingBuffer[Float]] = thisTensor.doBuffer
  } with NonInlineTensor
}
784
784
785
785
/**
@@ -993,82 +993,6 @@ trait Tensors extends OpenCL {
993
993
994
994
/** Abstract member: a `Do` that produces a `PendingBuffer` of `closure.JvmValue` elements.
  *
  * Visible only inside the `compute` package.
  * NOTE(review): presumably this is the device-side buffer holding this tensor's elements — confirm.
  */
private [compute] def doBuffer : Do [PendingBuffer [closure.JvmValue ]]
995
995
996
- /** Allocates device-side cache that are managed by the [[https://github.com/ThoughtWorksInc/RAII.scala RAII.scala ]] library.
997
- *
998
- * @note This method is similar to [[cache ]],
999
- * except the life cycle of the cache can be automatically managed.
1000
- *
1001
- * @group slow
1002
- */
1003
- def doCache : Do [this .type ] = doBuffer.map(Function .const(this ))
1004
-
1005
- /** Allocates device-side cache for this [[Tensor ]], and returns a [[java.lang.AutoCloseable ]] to release the cache.
1006
- *
1007
- * @note This method can be called multiple times on one [[Tensor ]],
1008
- * only one copy of cache will be allocated,
1009
- * which will be finally released until all [[java.lang.AutoCloseable ]] returned by [[cache ]] method are closed.
1010
- *
1011
- * @group slow
1012
- */
1013
- def cache : AutoCloseable = {
1014
- sealed trait State
1015
- case object Openning extends State
1016
- case object EarlyClosed extends State
1017
- case object Closed extends State
1018
- final case class Open (release : UnitContinuation [Unit ]) extends State
1019
-
1020
- val state = new AtomicReference [State ](Openning ) with AutoCloseable {
1021
- @ tailrec
1022
- final def close (): Unit = {
1023
- get match {
1024
- case Openning =>
1025
- if (compareAndSet(Openning , EarlyClosed )) {
1026
- // Success
1027
- } else {
1028
- close()
1029
- }
1030
- case oldState @ Open (release) =>
1031
- if (compareAndSet(oldState, Closed )) {
1032
- release.safeOnComplete { _ : Unit =>
1033
- Trampoline .done(())
1034
- }.run
1035
- } else {
1036
- close()
1037
- }
1038
- case EarlyClosed | Closed =>
1039
- throw new IllegalStateException (" The resources associated to this tensor has been released." )
1040
- }
1041
- }
1042
- }
1043
-
1044
- doBuffer.safeOnComplete { resource =>
1045
- @ tailrec
1046
- def retry (): Trampoline [Unit ] = {
1047
- state.get() match {
1048
- case EarlyClosed =>
1049
- if (state.compareAndSet(EarlyClosed , Closed )) {
1050
- resource.release.safeOnComplete { _ : Unit =>
1051
- Trampoline .done(())
1052
- }
1053
- } else {
1054
- retry()
1055
- }
1056
- case Openning =>
1057
- if (state.compareAndSet(Openning , Open (resource.release))) {
1058
- Trampoline .done(())
1059
- } else {
1060
- retry()
1061
- }
1062
- case _ : Open | Closed =>
1063
- throw new IllegalStateException ()
1064
- }
1065
- }
1066
- retry()
1067
- }.run
1068
-
1069
- state
1070
- }
1071
-
1072
996
/**
1073
997
* @group slow
1074
998
*/
@@ -1225,12 +1149,12 @@ trait Tensors extends OpenCL {
1225
1149
enqueueClosure(closure, shape)
1226
1150
}.shared
1227
1151
1228
/** Wraps this inline tensor as a [[NonInlineTensor]].
  *
  * The returned tensor takes its `padding`, `doBuffer` and `shape` directly from this tensor.
  */
def nonInline: NonInlineTensor =
  new {
    // Early definitions: these members are initialized before the NonInlineTensor body runs.
    val padding: Float = thisInlineTensor.padding
    private[compute] val doBuffer: Do[PendingBuffer[Float]] = thisInlineTensor.doBuffer
    val shape: Array[Int] = thisInlineTensor.shape
  } with NonInlineTensor
1234
1158
}
1235
1159
1236
1160
trait TransformedTensor extends InlineTensor {
@@ -1250,14 +1174,91 @@ trait Tensors extends OpenCL {
1250
1174
1251
1175
}
1252
1176
1253
/** A [[Tensor]] whose `closure` is obtained by extracting its `arrayTerm`.
  *
  * Besides the members inherited from [[Tensor]], it offers [[doCache]] and [[cache]],
  * both of which are driven by `doBuffer`.
  */
trait NonInlineTensor extends Tensor {

  /** Returns this tensor itself. */
  def nonInline: this.type = this

  @transient
  protected lazy val closure = {
    arrayTerm.extract
  }

  /** Allocates a device-side cache that is managed by the
    * [[https://github.com/ThoughtWorksInc/RAII.scala RAII.scala]] library.
    *
    * @note This method is similar to [[cache]],
    *       except that the life cycle of the cache is managed automatically.
    *
    * @return a `Do` that runs `doBuffer` and then yields this tensor
    * @group slow
    */
  def doCache: Do[this.type] = doBuffer.map(Function.const(this))

  /** Allocates a device-side cache for this [[Tensor]], and returns a [[java.lang.AutoCloseable]]
    * that releases the cache.
    *
    * @note This method can be called multiple times on one [[Tensor]].
    *       Only one copy of the cache will be allocated,
    *       which is finally released once every [[java.lang.AutoCloseable]] returned by [[cache]] has been closed.
    *
    * @note Calling `close()` more than once on the returned handle throws
    *       a [[java.lang.IllegalStateException]].
    *
    * @group slow
    */
  def cache: AutoCloseable = {
    // Life cycle of one handle returned by `cache`:
    //   Opening     - `doBuffer` has not completed yet and `close()` has not been called.
    //   EarlyClosed - `close()` was called before `doBuffer` completed.
    //   Open        - `doBuffer` completed; `release` is the continuation that frees the resource.
    //   Closed      - the release continuation has been started.
    sealed trait State
    case object Opening extends State
    case object EarlyClosed extends State
    case object Closed extends State
    final case class Open(release: UnitContinuation[Unit]) extends State

    val state = new AtomicReference[State](Opening) with AutoCloseable {
      @tailrec
      final def close(): Unit = {
        get match {
          case Opening =>
            // The resource is not available yet: only record the intent to close.
            // The completion handler below performs the actual release.
            if (compareAndSet(Opening, EarlyClosed)) {
              // Success
            } else {
              // Lost a race against the completion handler; re-read the state.
              close()
            }
          case oldState @ Open(release) =>
            if (compareAndSet(oldState, Closed)) {
              release.safeOnComplete { _: Unit =>
                Trampoline.done(())
              }.run
            } else {
              close()
            }
          case EarlyClosed | Closed =>
            throw new IllegalStateException(" The resources associated to this tensor has been released." )
        }
      }
    }

    doBuffer.safeOnComplete { resource =>
      @tailrec
      def retry(): Trampoline[Unit] = {
        state.get() match {
          case EarlyClosed =>
            // `close()` already ran, so release the freshly acquired resource right away.
            if (state.compareAndSet(EarlyClosed, Closed)) {
              resource.release.safeOnComplete { _: Unit =>
                Trampoline.done(())
              }
            } else {
              retry()
            }
          case Opening =>
            // Publish the release continuation so that a later `close()` can run it.
            if (state.compareAndSet(Opening, Open(resource.release))) {
              Trampoline.done(())
            } else {
              retry()
            }
          case _: Open | Closed =>
            // The completion handler is the only code that leaves Opening/EarlyClosed,
            // so these states are not expected here.
            throw new IllegalStateException()
        }
      }
      retry()
    }.run

    state
  }

}
1262
1263
1263
1264
}
0 commit comments