Skip to content

Commit a5fe637

Browse files
committed
Move cache and doCache method to BufferedTensor
1 parent aef2f6c commit a5fe637

File tree

1 file changed

+78
-77
lines changed

1 file changed

+78
-77
lines changed

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 78 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -993,82 +993,6 @@ trait Tensors extends OpenCL {
993993

994994
private[compute] def doBuffer: Do[PendingBuffer[closure.JvmValue]]
995995

996-
/** Allocates device-side cache that are managed by the [[https://github.com/ThoughtWorksInc/RAII.scala RAII.scala]] library.
997-
*
998-
* @note This method is similar to [[cache]],
999-
* except the life cycle of the cache can be automatically managed.
1000-
*
1001-
* @group slow
1002-
*/
1003-
def doCache: Do[this.type] = doBuffer.map(Function.const(this))
1004-
1005-
/** Allocates device-side cache for this [[Tensor]], and returns a [[java.lang.AutoCloseable]] to release the cache.
1006-
*
1007-
* @note This method can be called multiple times on one [[Tensor]],
1008-
* only one copy of cache will be allocated,
1009-
* which will be finally released until all [[java.lang.AutoCloseable]] returned by [[cache]] method are closed.
1010-
*
1011-
* @group slow
1012-
*/
1013-
def cache: AutoCloseable = {
1014-
sealed trait State
1015-
case object Openning extends State
1016-
case object EarlyClosed extends State
1017-
case object Closed extends State
1018-
final case class Open(release: UnitContinuation[Unit]) extends State
1019-
1020-
val state = new AtomicReference[State](Openning) with AutoCloseable {
1021-
@tailrec
1022-
final def close(): Unit = {
1023-
get match {
1024-
case Openning =>
1025-
if (compareAndSet(Openning, EarlyClosed)) {
1026-
// Success
1027-
} else {
1028-
close()
1029-
}
1030-
case oldState @ Open(release) =>
1031-
if (compareAndSet(oldState, Closed)) {
1032-
release.safeOnComplete { _: Unit =>
1033-
Trampoline.done(())
1034-
}.run
1035-
} else {
1036-
close()
1037-
}
1038-
case EarlyClosed | Closed =>
1039-
throw new IllegalStateException("The resources associated to this tensor has been released.")
1040-
}
1041-
}
1042-
}
1043-
1044-
doBuffer.safeOnComplete { resource =>
1045-
@tailrec
1046-
def retry(): Trampoline[Unit] = {
1047-
state.get() match {
1048-
case EarlyClosed =>
1049-
if (state.compareAndSet(EarlyClosed, Closed)) {
1050-
resource.release.safeOnComplete { _: Unit =>
1051-
Trampoline.done(())
1052-
}
1053-
} else {
1054-
retry()
1055-
}
1056-
case Openning =>
1057-
if (state.compareAndSet(Openning, Open(resource.release))) {
1058-
Trampoline.done(())
1059-
} else {
1060-
retry()
1061-
}
1062-
case _: Open | Closed =>
1063-
throw new IllegalStateException()
1064-
}
1065-
}
1066-
retry()
1067-
}.run
1068-
1069-
state
1070-
}
1071-
1072996
/**
1073997
* @group slow
1074998
*/
@@ -1252,12 +1176,89 @@ trait Tensors extends OpenCL {
12521176

12531177
trait BufferedTensor extends Tensor {
12541178

1255-
def notInline: BufferedTensor = this
1179+
def notInline: this.type = this
12561180

12571181
@transient
12581182
protected lazy val closure = {
12591183
arrayTerm.extract
12601184
}
1185+
1186+
/** Allocates device-side cache that are managed by the [[https://github.com/ThoughtWorksInc/RAII.scala RAII.scala]] library.
1187+
*
1188+
* @note This method is similar to [[cache]],
1189+
* except the life cycle of the cache can be automatically managed.
1190+
*
1191+
* @group slow
1192+
*/
1193+
def doCache: Do[this.type] = doBuffer.map(Function.const(this))
1194+
1195+
/** Allocates device-side cache for this [[Tensor]], and returns a [[java.lang.AutoCloseable]] to release the cache.
1196+
*
1197+
* @note This method can be called multiple times on one [[Tensor]].
1198+
* Only one copy of cache will be allocated,
1199+
* which will be finally released until all [[java.lang.AutoCloseable]] returned by [[cache]] method are closed.
1200+
*
1201+
* @group slow
1202+
*/
1203+
def cache: AutoCloseable = {
1204+
sealed trait State
1205+
case object Openning extends State
1206+
case object EarlyClosed extends State
1207+
case object Closed extends State
1208+
final case class Open(release: UnitContinuation[Unit]) extends State
1209+
1210+
val state = new AtomicReference[State](Openning) with AutoCloseable {
1211+
@tailrec
1212+
final def close(): Unit = {
1213+
get match {
1214+
case Openning =>
1215+
if (compareAndSet(Openning, EarlyClosed)) {
1216+
// Success
1217+
} else {
1218+
close()
1219+
}
1220+
case oldState @ Open(release) =>
1221+
if (compareAndSet(oldState, Closed)) {
1222+
release.safeOnComplete { _: Unit =>
1223+
Trampoline.done(())
1224+
}.run
1225+
} else {
1226+
close()
1227+
}
1228+
case EarlyClosed | Closed =>
1229+
throw new IllegalStateException("The resources associated to this tensor has been released.")
1230+
}
1231+
}
1232+
}
1233+
1234+
doBuffer.safeOnComplete { resource =>
1235+
@tailrec
1236+
def retry(): Trampoline[Unit] = {
1237+
state.get() match {
1238+
case EarlyClosed =>
1239+
if (state.compareAndSet(EarlyClosed, Closed)) {
1240+
resource.release.safeOnComplete { _: Unit =>
1241+
Trampoline.done(())
1242+
}
1243+
} else {
1244+
retry()
1245+
}
1246+
case Openning =>
1247+
if (state.compareAndSet(Openning, Open(resource.release))) {
1248+
Trampoline.done(())
1249+
} else {
1250+
retry()
1251+
}
1252+
case _: Open | Closed =>
1253+
throw new IllegalStateException()
1254+
}
1255+
}
1256+
retry()
1257+
}.run
1258+
1259+
state
1260+
}
1261+
12611262
}
12621263

12631264
}

0 commit comments

Comments
 (0)