@@ -17,9 +17,6 @@ import ORTTensorUtils._
17
17
18
18
trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
19
19
20
- // Java map performs better
21
- val sessionCache = new java.util.LinkedHashMap [String , ModelProto ]
22
-
23
20
val env = OrtEnvironment .getEnvironment()
24
21
25
22
val coreCount = java.lang.Runtime .getRuntime().availableProcessors()
@@ -130,10 +127,6 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
130
127
val modelProto = opToModelProto(opName, inputs, attrs)
131
128
132
129
val result : Tensor [T , Tuple3 [Tt , Td , S ]] = callByteArrayOp(modelProto.toByteArray, inputs)
133
- sessionCache.computeIfAbsent(
134
- opName + inputs.toArray.toList.toString + attrs.toString,
135
- _ => modelToPersist(modelProto, result.toString.hashCode.toString)
136
- )
137
130
result
138
131
}
139
132
@@ -145,33 +138,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
145
138
mod.clearGraph.withGraph(graphToPersist)
146
139
}
147
140
148
- // WARNING: not referentially transparent
149
- // Limitation: same reference cannot appear multiple times in a single op internally to the fused graph
150
- def fuseOps : ModelProto = {
151
- val cacheValues = sessionCache.values.asScala.toList
152
- if cacheValues.size == 0 then return ModelProto ()
153
- val nodes = cacheValues.map(_.getGraph.node).fold(Seq [NodeProto ]())((x, y) => x ++ y)
154
- val nodeOutputs = cacheValues
155
- .map(_.getGraph.output)
156
- .fold(Seq [ValueInfoProto ]())((x, y) => x ++ y)
157
- .map(_.getName)
158
- val inputs = cacheValues
159
- .map(_.getGraph)
160
- .filter(! _.node(0 ).opType.equals(Some (" Constant" )))
161
- .map(_.input)
162
- .fold(Seq [ValueInfoProto ]())((x, y) => x ++ y)
163
- .filter(z => ! nodeOutputs.contains(z.getName))
164
- .distinct
165
- val outputs = cacheValues(cacheValues.size - 1 ).getGraph.output
166
- val modelProto = (cacheValues.head.clearGraph).withGraph(
167
- (new GraphProto ).withNode(nodes).withInput(inputs).withOutput(outputs)
168
- )
169
- sessionCache.clear
170
- modelProto
171
- }
172
-
173
141
override def close (): Unit = {
174
142
env.close
175
- // super.close
176
143
}
177
144
}
0 commit comments