Skip to content

Commit 7611e8f

Browse files
committed
Bump ort-web version, dotty to 3.1.1, remove fuseOps
1 parent 448cb4f commit 7611e8f

File tree

2 files changed

+2
-35
lines changed

2 files changed

+2
-35
lines changed

backends/.jvm/src/main/scala/ORTOperatorBackend.scala

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ import ORTTensorUtils._
1717

1818
trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
1919

20-
// Java map performs better
21-
val sessionCache = new java.util.LinkedHashMap[String, ModelProto]
22-
2320
val env = OrtEnvironment.getEnvironment()
2421

2522
val coreCount = java.lang.Runtime.getRuntime().availableProcessors()
@@ -130,10 +127,6 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
130127
val modelProto = opToModelProto(opName, inputs, attrs)
131128

132129
val result: Tensor[T, Tuple3[Tt, Td, S]] = callByteArrayOp(modelProto.toByteArray, inputs)
133-
sessionCache.computeIfAbsent(
134-
opName + inputs.toArray.toList.toString + attrs.toString,
135-
_ => modelToPersist(modelProto, result.toString.hashCode.toString)
136-
)
137130
result
138131
}
139132

@@ -145,33 +138,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
145138
mod.clearGraph.withGraph(graphToPersist)
146139
}
147140

148-
// WARNING: not referentially transparent
149-
// Limitation: same reference cannot appear multiple times in a single op internally to the fused graph
150-
def fuseOps: ModelProto = {
151-
val cacheValues = sessionCache.values.asScala.toList
152-
if cacheValues.size == 0 then return ModelProto()
153-
val nodes = cacheValues.map(_.getGraph.node).fold(Seq[NodeProto]())((x, y) => x ++ y)
154-
val nodeOutputs = cacheValues
155-
.map(_.getGraph.output)
156-
.fold(Seq[ValueInfoProto]())((x, y) => x ++ y)
157-
.map(_.getName)
158-
val inputs = cacheValues
159-
.map(_.getGraph)
160-
.filter(!_.node(0).opType.equals(Some("Constant")))
161-
.map(_.input)
162-
.fold(Seq[ValueInfoProto]())((x, y) => x ++ y)
163-
.filter(z => !nodeOutputs.contains(z.getName))
164-
.distinct
165-
val outputs = cacheValues(cacheValues.size - 1).getGraph.output
166-
val modelProto = (cacheValues.head.clearGraph).withGraph(
167-
(new GraphProto).withNode(nodes).withInput(inputs).withOutput(outputs)
168-
)
169-
sessionCache.clear
170-
modelProto
171-
}
172-
173141
override def close(): Unit = {
174142
env.close
175-
// super.close
176143
}
177144
}

build.sbt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sbtcrossproject.CrossPlugin.autoImport.{crossProject, CrossType}
22

33
//val dottyVersion = dottyLatestNightlyBuild.get
4-
val dottyVersion = "3.1.2-RC1"
4+
val dottyVersion = "3.1.1"
55
val spireVersion = "0.18.0-M3"
66
val scalaTestVersion = "3.2.11"
77

@@ -69,7 +69,7 @@ lazy val backends = (crossProject(JVMPlatform, JSPlatform)
6969
)
7070
.jsSettings(
7171
scalaJSUseMainModuleInitializer := true, // , //Testing
72-
Compile / npmDependencies += "onnxruntime-web" -> "1.10.0"
72+
Compile / npmDependencies += "onnxruntime-web" -> "1.11.0"
7373
)
7474
.jsConfigure { project => project.enablePlugins(ScalablyTypedConverterPlugin) }
7575

0 commit comments

Comments
 (0)