Skip to content

Commit a4acf04

Browse files
committed
Refactoring and API polish
1 parent 9b6be99 commit a4acf04

File tree

4 files changed

+63
-78
lines changed

4 files changed

+63
-78
lines changed

source/ports/scala_port/src/main/scala/Caller.scala

Lines changed: 58 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import scala.concurrent.Future
55
import scala.concurrent.ExecutionContext
66

77
import com.sun.jna._, ptr.PointerByReference
8-
import java.util.concurrent.{LinkedBlockingQueue, ConcurrentHashMap}
8+
import java.util.concurrent.{ConcurrentLinkedQueue, ConcurrentHashMap}
99
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
1010

1111
/** `Caller` creates a new thread on which:
@@ -29,17 +29,23 @@ object Caller {
2929

3030
private case class UniqueCall(call: Call, id: Int)
3131

32+
private case class LoadCommand(
33+
namespace: Option[String],
34+
runtime: Runtime,
35+
filePaths: Vector[String]
36+
)
37+
38+
private val runningInMetacall = System.getProperty("java.polyglot.name") == "metacall"
39+
3240
private def callLoop() = {
33-
if (System.getProperty("java.polyglot.name") != "metacall")
41+
if (!runningInMetacall)
3442
Bindings.instance.metacall_initialize()
3543

3644
while (!closed.get) try {
3745
if (!scriptsQueue.isEmpty()) {
38-
val Script(filePath, runtime, namespace) = scriptsQueue.take()
46+
val LoadCommand(namespace, runtime, paths) = scriptsQueue.poll()
3947
val handleRef = namespace.map(_ => new PointerByReference())
40-
41-
Loader.loadFileUnsafe(runtime, filePath, handleRef)
42-
48+
Loader.loadFilesUnsafe(runtime, paths, handleRef)
4349
handleRef.zip(namespace) match {
4450
case Some((handleRef, namespace)) =>
4551
namespaceHandles.put(
@@ -48,48 +54,59 @@ object Caller {
4854
)
4955
case None => ()
5056
}
51-
}
52-
53-
if (!callQueue.isEmpty() && scriptsQueue.isEmpty()) {
54-
val UniqueCall(Call(namespace, fnName, args), id) = callQueue.take()
57+
} else if (!callQueue.isEmpty()) {
58+
val UniqueCall(Call(namespace, fnName, args), id) = callQueue.poll()
5559
val result = callUnsafe(namespace, fnName, args)
5660
callResultMap.put(id, result)
5761
}
5862
} catch {
59-
case e: Throwable => {
60-
Console.err.println(e)
61-
// TODO: Add a `setOnError` method and call it here
62-
}
63+
case e: Throwable => Console.err.println(e)
6364
}
6465

65-
if (System.getProperty("java.polyglot.name") != "metacall")
66+
if (!runningInMetacall)
6667
Bindings.instance.metacall_destroy()
6768
}
6869

6970
private val closed = new AtomicBoolean(false)
70-
private val callQueue = new LinkedBlockingQueue[UniqueCall]()
71+
private val callQueue = new ConcurrentLinkedQueue[UniqueCall]()
7172
private val callResultMap = new ConcurrentHashMap[Int, Value]()
7273
private val callCounter = new AtomicInteger(0)
73-
private val scriptsQueue = new LinkedBlockingQueue[Script]()
74+
private val scriptsQueue = new ConcurrentLinkedQueue[LoadCommand]()
7475
private val namespaceHandles =
7576
new ConcurrentHashMap[String, PointerByReference]()
7677

77-
def loadFile(runtime: Runtime, filePath: String, namespace: Option[String]): Unit = {
78-
if (closed.get())
79-
throw new Exception(s"Trying to load script $filePath while the caller is closed")
78+
def loadFiles(
79+
runtime: Runtime,
80+
filePaths: Vector[String],
81+
namespace: Option[String] = None
82+
): Unit = {
83+
if (closed.get()) {
84+
val scriptsStr =
85+
if (filePaths.length == 1) "script " + filePaths.head
86+
else "scripts " + filePaths.mkString(", ")
87+
throw new Exception(
88+
s"Trying to load scripts $scriptsStr while the caller is closed"
89+
)
90+
}
8091

81-
scriptsQueue.put(Script(filePath, runtime, namespace))
92+
scriptsQueue.add(LoadCommand(namespace, runtime, filePaths))
8293
while (!scriptsQueue.isEmpty()) ()
8394
}
8495

96+
def loadFile(
97+
runtime: Runtime,
98+
filePath: String,
99+
namespace: Option[String] = None
100+
): Unit = loadFiles(runtime, Vector(filePath), namespace)
101+
85102
def loadFile(runtime: Runtime, filePath: String, namespace: String): Unit =
86103
loadFile(runtime, filePath, Some(namespace))
87104

88105
def loadFile(runtime: Runtime, filePath: String): Unit =
89106
loadFile(runtime, filePath, None)
90107

91108
def start(): Unit = {
92-
if (System.getProperty("java.polyglot.name") != "metacall")
109+
if (!runningInMetacall)
93110
new Thread(() => callLoop()).start()
94111
else
95112
callLoop()
@@ -99,9 +116,9 @@ object Caller {
99116

100117
/** Calls a loaded function.
101118
* WARNING: Should only be used from within the caller thread.
102-
* @param namespace The script/module file where the function is defined
103119
* @param fnName The name of the function to call
104120
* @param args A list of `Value`s to be passed as arguments to the function
121+
* @param namespace The script/module file where the function is defined
105122
* @return The function's return value, or `InvalidValue` in case of an error
106123
*/
107124
private def callUnsafe(
@@ -146,57 +163,51 @@ object Caller {
146163
}
147164

148165
/** Calls a loaded function.
149-
* @param namespace The script/module file where the function is defined
150166
* @param fnName The name of the function to call
151167
* @param args A list of `Value`s to be passed as arguments to the function
168+
* @param namespace The script/module file where the function is defined
152169
* @return The function's return value, or `InvalidValue` in case of an error
153170
*/
154-
def callV(namespace: Option[String], fnName: String, args: List[Value])(implicit
171+
def callV(fnName: String, args: List[Value], namespace: Option[String] = None)(implicit
155172
ec: ExecutionContext
156173
): Future[Value] =
157-
Future(blocking.callV(namespace, fnName, args))
174+
Future(blocking.callV(fnName, args, namespace))
158175

159-
def call[A](namespace: Option[String], fnName: String, args: A)(implicit
176+
def call[A](fnName: String, args: A, namespace: Option[String] = None)(implicit
160177
AA: Args[A],
161178
ec: ExecutionContext
162179
): Future[Value] =
163-
Future(blocking.call[A](namespace, fnName, args))
164-
165-
def call[A](fnName: String, args: A)(implicit
166-
AA: Args[A],
167-
ec: ExecutionContext
168-
): Future[Value] =
169-
call[A](None, fnName, args)
180+
callV(fnName, AA.from(args), namespace)
170181

171182
def call[A](namespace: String, fnName: String, args: A)(implicit
172183
AA: Args[A],
173184
ec: ExecutionContext
174-
): Future[Value] =
175-
call[A](Some(namespace), fnName, args)
185+
): Future[Value] = call[A](fnName, args, Some(namespace))
176186

177187
/** Blocking versions of the methods on [[Caller]]. Do not use them if you don't *need* to. */
178188
object blocking {
179189

180190
/** Calls a loaded function.
181-
* @param namespace The script/module file where the function is defined
182191
* @param fnName The name of the function to call
183192
* @param args A list of `Value`s to be passed as arguments to the function
193+
* @param namespace The script/module file where the function is defined
184194
* @return The function's return value, or `InvalidValue` in case of an error
185195
*/
186-
def callV(namespace: Option[String], fnName: String, args: List[Value]): Value = {
196+
def callV(
197+
fnName: String,
198+
args: List[Value],
199+
namespace: Option[String] = None
200+
): Value = {
187201
val call = Call(namespace, fnName, args)
188202

189-
// TODO: This trick works but it may overflow, we should do a test
190-
// executing calls to a function without parameters nor return, with a size
191-
// greater of sizeof(int), for example Integer.MAX_VALUE + 15, in order to see what happens
192203
val callId = callCounter.getAndIncrement()
193204

194-
if (callId == Int.MaxValue)
205+
if (callId == Int.MaxValue - 1)
195206
callCounter.set(0)
196207

197208
val uniqueCall = UniqueCall(call, callId)
198209

199-
callQueue.put(uniqueCall)
210+
callQueue.add(uniqueCall)
200211

201212
var result: Value = null
202213

@@ -209,34 +220,17 @@ object Caller {
209220
}
210221

211222
/** Calls a loaded function
212-
* @param namespace The script/module file where the function is defined
213223
* @param fnName The name of the function to call
214224
* @param args A product (tuple, case class, single value) to be passed as arguments to the function
225+
* @param namespace The script/module file where the function is defined
215226
* @return The function's return value, or `InvalidValue` in case of an error
216227
*/
217228
def call[A](
218-
namespace: Option[String],
219229
fnName: String,
220-
args: A
230+
args: A,
231+
namespace: Option[String] = None
221232
)(implicit AA: Args[A]): Value =
222-
callV(namespace, fnName, AA.from(args))
223-
224-
/** Calls a loaded function.
225-
* @param fnName The name of the function to call
226-
* @param args A product (tuple, case class, single value) to be passed as arguments to the function
227-
* @return The function's return value, or `InvalidValue` in case of an error
228-
*/
229-
def call[A](fnName: String, args: A)(implicit AA: Args[A]): Value =
230-
call[A](None, fnName, args)
231-
232-
/** Calls a loaded function.
233-
* @param namespace The script/module file where the function is defined
234-
* @param fnName The name of the function to call
235-
* @param args A product (tuple, case class, single value) to be passed as arguments to the function
236-
* @return The function's return value, or `InvalidValue` in case of an error
237-
*/
238-
def call[A](namespace: String, fnName: String, args: A)(implicit AA: Args[A]): Value =
239-
call[A](Some(namespace), fnName, args)
233+
blocking.callV(fnName, AA.from(args), namespace)
240234

241235
}
242236

source/ports/scala_port/src/main/scala/Loader.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,6 @@ private[metacall] object Loader {
3232
if (code != 0)
3333
throw new Exception("Failed to load scripts: " + filePaths.mkString(" "))
3434
}
35-
36-
def loadFileUnsafe(
37-
runtime: Runtime,
38-
filePath: String,
39-
handleRef: Option[PointerByReference]
40-
) =
41-
loadFilesUnsafe(runtime, Vector(filePath), handleRef)
42-
4335
}
4436

4537
sealed trait Runtime

source/ports/scala_port/src/main/scala/util.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package metacall
33
import com.sun.jna._
44

55
object util {
6-
case class Script(filePath: String, runtime: Runtime, namespace: Option[String])
6+
case class Script(runtime: Runtime, filePath: String)
77

88
private[metacall] class SizeT(value: Long)
99
extends IntegerType(Native.SIZE_T_SIZE, value) {

source/ports/scala_port/src/test/scala/MetaCallSpec.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -481,13 +481,12 @@ class MetaCallSpec extends AnyFlatSpec {
481481
Caller.loadFile(Runtime.Python, "./src/test/scala/scripts/s2.py", Some("s2"))
482482

483483
assert(
484-
Caller.blocking.call(Some("s1"), "fn_in_s1", ()) == StringValue("Hello from s1")
484+
Caller.blocking.call("fn_in_s1", (), Some("s1")) == StringValue("Hello from s1")
485485
)
486486
}
487487

488488
"Caller" should "call functions and clean up arguments and returned pointers" in {
489489
val ret = Caller.blocking.callV(
490-
None,
491490
"hello_scala_from_python",
492491
List(StringValue("Hello "), StringValue("Scala!"))
493492
)
@@ -501,7 +500,7 @@ class MetaCallSpec extends AnyFlatSpec {
501500
case _ => NullValue
502501
}
503502

504-
val ret = Caller.blocking.callV(None, "apply_fn_to_one", fnVal :: Nil)
503+
val ret = Caller.blocking.callV("apply_fn_to_one", fnVal :: Nil)
505504

506505
assert(ret == LongValue(2L))
507506
}
@@ -531,7 +530,7 @@ class MetaCallSpec extends AnyFlatSpec {
531530

532531
val resSum = rangeValues
533532
.traverse { range =>
534-
Future(Caller.blocking.callV(None, "sumList", range :: Nil)) map {
533+
Future(Caller.blocking.callV("sumList", range :: Nil)) map {
535534
case n: NumericValue[_] => n.long.value
536535
case other => fail("Returned value should be a number, but got " + other)
537536
}
@@ -552,7 +551,7 @@ class MetaCallSpec extends AnyFlatSpec {
552551

553552
val resSum = rangeValues
554553
.traverse { range =>
555-
Caller.callV(None, "sumList", range :: Nil) map {
554+
Caller.callV("sumList", range :: Nil) map {
556555
case n: NumericValue[_] => n.long.value
557556
case other => fail("Returned value should be a number, but got " + other)
558557
}

0 commit comments

Comments
 (0)