11package cask .internal
22
33import java .io .{InputStream , PrintWriter , StringWriter }
4-
54import scala .collection .generic .CanBuildFrom
65import scala .collection .mutable
76import java .io .OutputStream
8-
7+ import java .lang .invoke .{MethodHandles , MethodType }
8+ import java .util .concurrent .{Executor , ExecutorService , ForkJoinPool , ThreadFactory }
99import scala .annotation .switch
1010import scala .concurrent .{ExecutionContext , Future , Promise }
11+ import scala .util .Try
12+ import scala .util .control .NonFatal
1113
1214object Util {
15+ private val lookup = MethodHandles .lookup()
16+
17+ import cask .util .Logger .Console .globalLogger
18+
19+ /**
20+ * Create a virtual thread executor with the given executor as the scheduler.
21+ * */
22+ def createVirtualThreadExecutor (executor : Executor ): Option [ExecutorService ] = {
23+ (for {
24+ factory <- Try (createVirtualThreadFactory(" cask-handler-executor" , executor))
25+ executor <- Try (createNewThreadPerTaskExecutor(factory))
26+ } yield executor).toOption
27+ }
28+
29+ /**
30+ * Create a default cask virtual thread executor if possible.
31+ * */
32+ def createDefaultCaskVirtualThreadExecutor : Option [ExecutorService ] = {
33+ for {
34+ scheduler <- getDefaultVirtualThreadScheduler
35+ executor <- createVirtualThreadExecutor(scheduler)
36+ } yield executor
37+ }
38+
39+ /**
40+ * Try to get the default virtual thread scheduler, or null if not supported.
41+ * */
42+ def getDefaultVirtualThreadScheduler : Option [ForkJoinPool ] = {
43+ try {
44+ val virtualThreadClass = Class .forName(" java.lang.VirtualThread" )
45+ val privateLookup = MethodHandles .privateLookupIn(virtualThreadClass, lookup)
46+ val defaultSchedulerField = privateLookup.findStaticVarHandle(virtualThreadClass, " DEFAULT_SCHEDULER" , classOf [ForkJoinPool ])
47+ Option (defaultSchedulerField.get().asInstanceOf [ForkJoinPool ])
48+ } catch {
49+ case NonFatal (e) =>
50+ // --add-opens java.base/java.lang=ALL-UNNAMED
51+ globalLogger.exception(e)
52+ None
53+ }
54+ }
55+
56+ def createNewThreadPerTaskExecutor (threadFactory : ThreadFactory ): ExecutorService = {
57+ try {
58+ val executorsClazz = ClassLoader .getSystemClassLoader.loadClass(" java.util.concurrent.Executors" )
59+ val newThreadPerTaskExecutorMethod = lookup.findStatic(
60+ executorsClazz,
61+ " newThreadPerTaskExecutor" ,
62+ MethodType .methodType(classOf [ExecutorService ], classOf [ThreadFactory ]))
63+ newThreadPerTaskExecutorMethod.invoke(threadFactory)
64+ .asInstanceOf [ExecutorService ]
65+ } catch {
66+ case NonFatal (e) =>
67+ globalLogger.exception(e)
68+ throw new UnsupportedOperationException (" Failed to create newThreadPerTaskExecutor." , e)
69+ }
70+ }
71+
72+ /**
73+ * Create a virtual thread factory with a executor, the executor will be used as the scheduler of
74+ * virtual thread.
75+ *
76+ * The executor should run task on platform threads.
77+ *
78+ * returns null if not supported.
79+ */
80+ def createVirtualThreadFactory (prefix : String ,
81+ executor : Executor ): ThreadFactory =
82+ try {
83+ val builderClass = ClassLoader .getSystemClassLoader.loadClass(" java.lang.Thread$Builder" )
84+ val ofVirtualClass = ClassLoader .getSystemClassLoader.loadClass(" java.lang.Thread$Builder$OfVirtual" )
85+ val ofVirtualMethod = lookup.findStatic(classOf [Thread ], " ofVirtual" , MethodType .methodType(ofVirtualClass))
86+ var builder = ofVirtualMethod.invoke()
87+ if (executor != null ) {
88+ val clazz = builder.getClass
89+ val privateLookup = MethodHandles .privateLookupIn(
90+ clazz,
91+ lookup
92+ )
93+ val schedulerFieldSetter = privateLookup
94+ .findSetter(clazz, " scheduler" , classOf [Executor ])
95+ schedulerFieldSetter.invoke(builder, executor)
96+ }
97+ val nameMethod = lookup.findVirtual(ofVirtualClass, " name" ,
98+ MethodType .methodType(ofVirtualClass, classOf [String ], classOf [Long ]))
99+ val factoryMethod = lookup.findVirtual(builderClass, " factory" , MethodType .methodType(classOf [ThreadFactory ]))
100+ builder = nameMethod.invoke(builder, prefix + " -virtual-thread-" , 0L )
101+ factoryMethod.invoke(builder).asInstanceOf [ThreadFactory ]
102+ } catch {
103+ case NonFatal (e) =>
104+ globalLogger.exception(e)
105+ // --add-opens java.base/java.lang=ALL-UNNAMED
106+ throw new UnsupportedOperationException (" Failed to create virtual thread factory." , e)
107+ }
108+
13109 def firstFutureOf [T ](futures : Seq [Future [T ]])(implicit ec : ExecutionContext ) = {
14110 val p = Promise [T ]
15111 futures.foreach(_.foreach(p.trySuccess))
16112 p.future
17113 }
114+
18115 /**
19- * Convert a string to a C&P-able literal. Basically
20- * copied verbatim from the uPickle source code.
21- */
116+ * Convert a string to a C&P-able literal. Basically
117+ * copied verbatim from the uPickle source code.
118+ */
22119 def literalize (s : IndexedSeq [Char ], unicode : Boolean = true ) = {
23120 val sb = new StringBuilder
24121 sb.append('"' )
@@ -47,29 +144,30 @@ object Util {
47144 def transferTo (in : InputStream , out : OutputStream ) = {
48145 val buffer = new Array [Byte ](8192 )
49146
50- while ({
51- in.read(buffer) match {
147+ while ( {
148+ in.read(buffer) match {
52149 case - 1 => false
53150 case n =>
54151 out.write(buffer, 0 , n)
55152 true
56153 }
57154 }) ()
58155 }
156+
59157 def pluralize (s : String , n : Int ) = {
60158 if (n == 1 ) s else s + " s"
61159 }
62160
63161 /**
64- * Splits a string into path segments; automatically removes all
65- * leading/trailing slashes, and ignores empty path segments.
66- *
67- * Written imperatively for performance since it's used all over the place.
68- */
162+ * Splits a string into path segments; automatically removes all
163+ * leading/trailing slashes, and ignores empty path segments.
164+ *
165+ * Written imperatively for performance since it's used all over the place.
166+ */
69167 def splitPath (p : String ): collection.IndexedSeq [String ] = {
70168 val pLength = p.length
71169 var i = 0
72- while (i < pLength && p(i) == '/' ) i += 1
170+ while (i < pLength && p(i) == '/' ) i += 1
73171 var segmentStart = i
74172 val out = mutable.ArrayBuffer .empty[String ]
75173
@@ -81,7 +179,7 @@ object Util {
81179 segmentStart = i + 1
82180 }
83181
84- while (i < pLength){
182+ while (i < pLength) {
85183 if (p(i) == '/' ) complete()
86184 i += 1
87185 }
@@ -96,33 +194,35 @@ object Util {
96194 pw.flush()
97195 trace.toString
98196 }
197+
99198 def softWrap (s : String , leftOffset : Int , maxWidth : Int ) = {
100199 val oneLine = s.linesIterator.mkString(" " ).split(' ' )
101200
102201 lazy val indent = " " * leftOffset
103202
104203 val output = new StringBuilder (oneLine.head)
105204 var currentLineWidth = oneLine.head.length
106- for (chunk <- oneLine.tail){
205+ for (chunk <- oneLine.tail) {
107206 val addedWidth = currentLineWidth + chunk.length + 1
108- if (addedWidth > maxWidth){
207+ if (addedWidth > maxWidth) {
109208 output.append(" \n " + indent)
110209 output.append(chunk)
111210 currentLineWidth = chunk.length
112- } else {
211+ } else {
113212 currentLineWidth = addedWidth
114213 output.append(' ' )
115214 output.append(chunk)
116215 }
117216 }
118217 output.mkString
119218 }
219+
120220 def sequenceEither [A , B , M [X ] <: TraversableOnce [X ]](in : M [Either [A , B ]])(
121221 implicit cbf : CanBuildFrom [M [Either [A , B ]], B , M [B ]]): Either [A , M [B ]] = {
122222 in.foldLeft[Either [A , mutable.Builder [B , M [B ]]]](Right (cbf(in))) {
123- case (acc, el) =>
124- for (a <- acc; e <- el) yield a += e
125- }
223+ case (acc, el) =>
224+ for (a <- acc; e <- el) yield a += e
225+ }
126226 .map(_.result())
127227 }
128228}
0 commit comments