Skip to content

Commit 3504e57

Browse files
committed
Use new onnxruntime-web npm package for JS build
1 parent 068705f commit 3504e57

File tree

3 files changed

+29
-37
lines changed

3 files changed

+29
-37
lines changed

backends/.js/src/main/scala/Main.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ object Main {
77

88
val t = new ONNXJSOperatorBackend {}
99

10-
t.test
10+
t.test()
1111
println(s"Using Scala.js version ${System.getProperty("java.vm.version")}")
1212
}
1313
}
Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package org.emergentorder.onnx.backends
22

33
import scala.concurrent.duration._
4-
import typings.onnxjs.onnxImplMod.Tensor.{^ => Tensor}
4+
import typings.onnxruntimeWeb.tensorMod.Tensor
5+
import typings.onnxruntimeWeb.tensorMod.Tensor.FloatType
6+
import typings.onnxruntimeWeb.tensorMod.Tensor.DataType
57
//import typings.onnxjs.libTensorMod.Tensor.DataTypeMap.DataTypeMapOps
6-
import typings.onnxjs.onnxImplMod.InferenceSession.{^ => InferenceSession}
8+
import typings.onnxruntimeWeb.mod.InferenceSession
9+
//import typings.onnxruntimeWeb.ort.InferenceSession.{^ => InferenceSession}
710
//import typings.onnxjs.onnxMod.Onnx
811

9-
import typings.onnxjs.onnxImplMod._
12+
//import typings.onnxruntimeWeb.onnxImplMod._
1013

1114
import scala.scalajs.js.|
1215

@@ -16,34 +19,23 @@ trait ONNXJSOperatorBackend {
1619
implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
1720
def test() = {
1821

19-
val session = new InferenceSession()
20-
val url = "relu.onnx"
21-
val modelFuture = session.loadModel(url).toFuture
22+
val session = InferenceSession.create("relu.onnx")
23+
val dataTypes = new FloatType {}
2224

23-
val dataTypes = new typings.onnxjs.libTensorMod.Tensor.FloatType {}
24-
25-
val outputFuture = modelFuture.map { x =>
26-
val inputs = Array(
25+
/*
26+
val inputs = Array(
2727
new Tensor(
28-
scala.scalajs.js.Array[Boolean | Double](
29-
(1 until 61).map(_.toDouble: Boolean | Double).toArray: _*
28+
"float32",
29+
scala.scalajs.js.Array[Double](
30+
(1 until 61).map(_.toDouble: Double).toArray: _*
3031
),
31-
typings.onnxjs.onnxjsStrings.float32,
3232
scala.scalajs.js.Array(3.0, 4.0, 5.0)
33-
): typings.onnxjs.tensorMod.Tensor
34-
);
35-
println("before run")
36-
val res = session.run(scala.scalajs.js.Array(inputs: _*)).toFuture
37-
println("after run")
38-
res
39-
}.flatten
40-
41-
import scala.util.{Success, Failure}
42-
43-
outputFuture onComplete {
44-
case Success(t) => println(t.get("y").get.dims)
45-
case Failure(fail) => println(fail)
46-
}
47-
33+
)
34+
)
35+
*/
36+
//println("before run")
37+
//val res = session.run(scala.scalajs.js.Array(inputs: _*))
38+
//println("after run")
39+
//res
4840
}
4941
}

build.sbt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,14 @@ lazy val backends = (crossProject(JVMPlatform, JSPlatform)
5252
.crossType(CrossType.Pure) in file("backends"))
5353
.dependsOn(core)
5454
//conditionally enabling/disable based on version, still not working
55-
// .enablePlugins(ScalaJSBundlerPlugin)//, ScalablyTypedConverterPlugin)
55+
// .enablePlugins(ScalablyTypedConverterPlugin)
5656
.settings(
5757
commonSettings,
5858
name := "onnx-scala-backends",
59-
excludeFilter in unmanagedSources := (CrossVersion
60-
.partialVersion(scalaVersion.value) match {
61-
case _ => "Main.scala" | "ONNXJSOperatorBackend.scala"
62-
}),
59+
// excludeFilter in unmanagedSources := (CrossVersion
60+
// .partialVersion(scalaVersion.value) match {
61+
// case _ => "Main.scala" | "ONNXJSOperatorBackend.scala"
62+
// }),
6363
// scalacOptions ++= { if (isDotty.value) Seq("-source:3.0-migration") else Nil },
6464
libraryDependencies ++= Seq(
6565
"com.microsoft.onnxruntime" % "onnxruntime" % "1.9.0"
@@ -70,12 +70,12 @@ lazy val backends = (crossProject(JVMPlatform, JSPlatform)
7070
//TODO: move to utest
7171
libraryDependencies += ("org.scalatest" %% "scalatest" % scalaTestVersion) % Test
7272
)
73-
.jsSettings(scalaJSUseMainModuleInitializer := true) //, //Testing
74-
//npmDependencies in Compile += "onnxjs" -> "0.1.8")
73+
.jsSettings(scalaJSUseMainModuleInitializer := true, //, //Testing
74+
npmDependencies in Compile += "onnxruntime-web" -> "1.9.0")
7575
//Seems to be a bundling issue, copying things manually seems to work
7676
//TODO NEW: try JS, bundler and converter beta/RC are out
7777
// npmDependencies in Compile += "onnxjs" -> "0.1.8")
78-
//.jsConfigure { project => project.enablePlugins(ScalaJSBundlerPlugin)} //ScalablyTypedConverterPlugin)}
78+
.jsConfigure { project => project.enablePlugins(ScalablyTypedConverterPlugin)}
7979
//ScalaJSBundlerPlugin)} //,ScalablyTypedConverterPlugin) }
8080

8181
lazy val core = (crossProject(JSPlatform, JVMPlatform)

0 commit comments

Comments
 (0)