Skip to content

Commit 0cf4c9b

Browse files
committed
Add SqueezeNet test
1 parent f98f790 commit 0cf4c9b

File tree

3 files changed

+51
-6
lines changed

3 files changed

+51
-6
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
language: scala
22
script:
33
- bash get_models.sh
4-
- travis_wait 30 sbt +test -J-Xmx5G
4+
- travis_wait 30 sbt test -J-Xmx5G
55
# - sbt docsJVM/mdoc
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package org.emergentorder.onnx.backends
2+
3+
import java.nio.file.{Files, Paths}
4+
import org.emergentorder.onnx.Tensors._
5+
import org.emergentorder.onnx.backends._
6+
import org.emergentorder.compiletime._
7+
import io.kjaer.compiletime._
8+
9+
10+
import org.scalatest.flatspec.AnyFlatSpec
11+
import org.scalatest.matchers.should._
12+
13+
14+
class ONNXScalaSpec extends AnyFlatSpec with Matchers{
15+
16+
"SqueezeNet ONNX-Scala model" should "predict dummy image class" in {
17+
val squeezenetBytes = Files.readAllBytes(Paths.get("squeezenet1.1.onnx"))
18+
val squeezenet = new ORTModelBackend(squeezenetBytes)
19+
val data = Array.fill(1*3*224*224){42f}
20+
//In NCHW tensor image format
21+
val shape = 1 #: 3 #: 224 #: 224 #: SNil
22+
val tensorShapeDenotation = "Batch" ##: "Channel" ##: "Height" ##: "Width" ##: TSNil
23+
24+
val tensorDenotation: String & Singleton = "Image"
25+
26+
val imageTens = Tensor(data,tensorDenotation,tensorShapeDenotation,shape)
27+
28+
//or as a shorthand if you aren't concerned with enforcing denotations
29+
val imageTensDefaultDenotations = Tensor(data,shape)
30+
val out = squeezenet.fullModel[Float,
31+
"ImageNetClassification",
32+
"Batch" ##: "Class" ##: TSNil,
33+
1 #: 1000 #: SNil](Tuple(imageTens))
34+
35+
//The output shape
36+
assert(out.shape(0) == 1)
37+
assert(out.shape(1) == 1000)
38+
39+
//The highest probability (predicted) class
40+
assert(out.data.indices.maxBy(out.data) == 418)
41+
}
42+
}

build.sbt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import sbtcrossproject.CrossPlugin.autoImport.{crossProject, CrossType}
33
val dottyVersion = "3.0.0-RC1"
44
val scala213Version = "2.13.5"
55
val spireVersion = "0.17.0"
6+
val scalaTestVersion = "3.2.5"
67

78
scalaVersion := dottyVersion
89

@@ -50,11 +51,11 @@ lazy val proto = (crossProject(JSPlatform, JVMPlatform)
5051
PB.protoSources in Compile := Seq(file("proto/src/main/protobuf")),
5152
)
5253

53-
lazy val backends = (crossProject(JSPlatform, JVMPlatform)
54+
lazy val backends = (crossProject(JVMPlatform, JSPlatform)
5455
.crossType(CrossType.Pure) in file("backends"))
5556
.dependsOn(core)
5657
//conditionally enabling/disable based on version, still not working
57-
.enablePlugins(ScalaJSBundlerPlugin)//, ScalablyTypedConverterPlugin)
58+
// .enablePlugins(ScalaJSBundlerPlugin)//, ScalablyTypedConverterPlugin)
5859
.settings(
5960
commonSettings,
6061
name := "onnx-scala-backends",
@@ -77,10 +78,12 @@ lazy val backends = (crossProject(JSPlatform, JVMPlatform)
7778
),
7879
crossScalaVersions := Seq(dottyVersion, scala213Version)
7980
)
80-
.jvmSettings().jsSettings(
81-
scalaJSUseMainModuleInitializer := true, //, //Testing
81+
.jvmSettings(
82+
libraryDependencies += ("org.scalatest" %% "scalatest" % scalaTestVersion) % Test,
83+
).jsSettings(
84+
scalaJSUseMainModuleInitializer := true) //, //Testing
8285
//Seems to be a bundling issue, copying things manually seems to work
83-
npmDependencies in Compile += "onnxjs" -> "0.1.8")
86+
// npmDependencies in Compile += "onnxjs" -> "0.1.8")
8487

8588
lazy val core = (crossProject(JSPlatform, JVMPlatform)
8689
.crossType(CrossType.Pure) in file("core"))

0 commit comments

Comments
 (0)