Skip to content

Commit aceed2e

Browse files
committed
Remove the use of unsafeRunSync; Fix misuse of IO in core and JVM backend which was causing redundant op invocations
1 parent 5bb57df commit aceed2e

File tree

6 files changed

+206
-271
lines changed

6 files changed

+206
-271
lines changed

backends/.js/src/main/scala/ORTOperatorBackend.scala

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package org.emergentorder.onnx.backends
22

33
import scala.concurrent.duration._
4-
//import typings.onnxruntimeWeb.tensorMod._
5-
//import typings.onnxruntimeWeb.tensorMod.Tensor.FloatType
4+
//import typings.onnxruntimeWeb.tensorMod
5+
import org.emergentorder.onnx.onnxruntimeWeb.tensorMod
66
//import typings.onnxruntimeWeb.tensorMod.Tensor.DataType
77
//import typings.onnxjs.libTensorMod.Tensor.DataTypeMap.DataTypeMapOps
88
import org.emergentorder.onnx.onnxruntimeWeb.mod.{InferenceSession => OrtSession}
@@ -38,12 +38,17 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
3838
val bytesArrayBuffer = bytes.toTypedArray.buffer
3939
val session: IO[
4040
InferenceSession
41-
] = IO.fromFuture(IO { OrtSession.create(bytesArrayBuffer, {
42-
val opts = InferenceSession.SessionOptions()
43-
opts.executionProviders = scala.scalajs.js.Array("wasm")
44-
opts
45-
}
46-
).toFuture })
41+
] = IO.fromFuture(IO {
42+
OrtSession
43+
.create(
44+
bytesArrayBuffer, {
45+
val opts = InferenceSession.SessionOptions()
46+
opts.executionProviders = scala.scalajs.js.Array("cpu")
47+
opts
48+
}
49+
)
50+
.toFuture
51+
})
4752
session
4853
}
4954

@@ -54,9 +59,10 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
5459
Td <: TensorShapeDenotation,
5560
S <: Shape
5661
](
57-
opModel: Array[Byte],
5862
inputs: Tuple,
59-
input_node_names: IO[List[String]]
63+
input_node_names: IO[List[String]],
64+
opName: String,
65+
attrs: Map[String, Any]
6066
)(using
6167
s: ShapeOf[S],
6268
tt: ValueOf[Tt],
@@ -108,13 +114,20 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
108114
.map(_.toArray)
109115
}
110116

117+
val opModel = for {
118+
tens <- inputTensors.memoize
119+
t <- tens
120+
} yield opToModelProto(
121+
opName,
122+
(t.map(_.asInstanceOf[tensorMod.Tensor].`type`.valueOf.asInstanceOf[Float].round)
123+
zip t.map(_.dims.map(_.toInt).toArray)),
124+
attrs
125+
).toByteArray
126+
111127
val res: Tensor[T, Tuple3[Tt, Td, S]] = {
112-
// val resource = cats.effect.Resource.make(IO{getSession(opModel)})(sess => IO{sess.close})
113-
// resource.use( sess =>
114-
inputTensors.flatMap { x =>
115-
// input_node_names.flatMap{y =>
128+
inputTensors.flatMap { x =>
116129
cats.effect.Resource
117-
.make(IO(getSession(opModel)))(sess => IO {})
130+
.make(opModel.map(getSession(_)))(sess => IO {})
118131
.use(sess =>
119132
runModel(
120133
sess,
@@ -127,8 +140,8 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
127140
}
128141

129142
}
130-
// res.flatMap(IO.println("Post run").as(_))
131-
res
143+
res.flatMap(IO.println("opNAme = " + opName).as(_))
144+
//res
132145
}
133146

134147
def callOp[T <: Supported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape](
@@ -142,26 +155,15 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
142155
td: TensorShapeDenotationOf[Td],
143156
s: ShapeOf[S]
144157
): Tensor[T, Tuple3[Tt, Td, S]] = {
145-
// TODO: prevent passing input to opToONNXBytes
146-
147-
// println("ATTR " + attrs)
148-
val modelProto = opToModelProto(opName, inputs, attrs)
149-
150-
val result: IO[Tensor[T, Tuple3[Tt, Td, S]]] =
151-
for {
152-
mp <- modelProto.flatMap(IO.println("OpName => " + opName).as(_))
153-
} yield {
154-
// println(mp)
155-
callByteArrayOp(
156-
mp.toByteArray,
158+
val inputNodeNames = (0 until inputs.size).toList.map(_.toString)
159+
val result: Tensor[T, Tuple3[Tt, Td, S]] =
160+
callByteArrayOp(
157161
inputs,
158-
IO.pure {
159-
mp.graph.map(_.input.map(_.name.getOrElse(""))).getOrElse(List[String]()).toList
160-
}
162+
IO{inputNodeNames},
163+
opName,
164+
attrs
161165
)
162-
}
163-
164-
result.flatten
166+
result
165167
}
166168

167169
def runModel[

backends/.js/src/test/scala/SqueezeNetTest.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class ONNXScalaSpec extends AsyncFreeSpec with AsyncIOSpec with Matchers {
2828
] = IO.fromFuture(IO { OrtSession.create("squeezenet1.0-12.onnx",
2929
{
3030
val opts = InferenceSession.SessionOptions()
31-
opts.executionProviders = scala.scalajs.js.Array("wasm")
31+
opts.executionProviders = scala.scalajs.js.Array("cpu")
3232
opts
3333
}
3434
).toFuture })

backends/.jvm/src/main/scala/ORTModelBackend.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,14 @@ class ORTModelBackend(onnxBytes: Array[Byte])
7474
val output = cats.effect.Resource
7575
.make(inputTensors)(inTens => IO { inTens.map(_.close) })
7676
.use(inTens =>
77-
IO {
78-
runModel[T, Tt, Td, S](
79-
session,
80-
inTens,
81-
allNodeNamesAndDims._1,
82-
allNodeNamesAndDims._3
83-
)
84-
}
77+
runModel[T, Tt, Td, S](
78+
session,
79+
inTens,
80+
allNodeNamesAndDims._1,
81+
allNodeNamesAndDims._3
82+
)
8583
)
86-
87-
output.unsafeRunSync()
84+
output
8885
}
8986

9087
override def close(): Unit = {}

backends/.jvm/src/main/scala/ORTOperatorBackend.scala

Lines changed: 47 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,14 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
7171

7272
// TODO: Denotations
7373
val result: Tensor[T, Tuple3[Tt, Td, S]] = tensArr
74-
.map(x =>
74+
.flatMap(x =>
7575
Tensor(
7676
x,
7777
tensorTypeDenotationFromType,
7878
tensorShapeDenotationFromType,
7979
shapeFromType
8080
)
8181
)
82-
.unsafeRunSync()
8382
// result.flatMap(IO.println("Invoking run").as(_))
8483
result
8584
}
@@ -91,52 +90,37 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
9190
Td <: TensorShapeDenotation,
9291
S <: Shape
9392
](
94-
opModel: Array[Byte],
9593
inputs: Tuple,
96-
input_node_names: IO[List[String]]
94+
input_node_names: List[String],
95+
opName: String,
96+
attrs: Map[String, Any]
9797
)(using
9898
s: ShapeOf[S],
9999
tt: ValueOf[Tt],
100100
td: TensorShapeDenotationOf[Td]
101101
): Tensor[T, Tuple3[Tt, Td, S]] = {
102-
/*
103-
val input_node_names = inputs.toArray.zipWithIndex.map { (e, i) =>
104-
val incr: String = if inputs.toArray.distinct.size == inputs.size then "" else i.toString
105-
val tensE = e.asInstanceOf[Tensor[T, Tuple3[Tt, Td, S]]]
106-
tensE.map{x =>
107-
val t = ((x.toString + incr).hashCode).toString
108-
println("ANESMMMS " + t + " " + i)
109-
t
110-
}
111-
}.toList.sequence
112-
*/
113-
114102
// TODO: more outputs
115-
val output_node_names = List(input_node_names.toString)
103+
val output_node_names = List(inputs.size.toString)
116104

117105
// Spurious warning here, see: https://github.com/lampepfl/dotty/issues/10318
118106
// TODO: don't mix up Options and Tensors here
119107
@annotation.nowarn
120-
def inputTensors: IO[Array[OnnxTensor]] = {
108+
val inputTensors: IO[Array[OnnxTensor]] = {
121109

122110
inputs.toArray
123111
.flatMap { elem =>
124112
elem match {
125113
case opt: Option[Tensor[T, Tuple3[Tt, Td, S]]] =>
126114
opt match {
127115
case Some(x) =>
128-
Some(x.data.flatMap { y =>
129-
x.shape.map { z =>
130-
getOnnxTensor(y, z, env)
131-
}
116+
Some(x.map { y =>
117+
getOnnxTensor(y._1, y._2._3.toSeq.toArray, env)
132118
})
133-
case None => None
134119
}
120+
case None => None
135121
case tens: Tensor[T, Tuple3[Tt, Td, S]] =>
136-
Some(tens.data.flatMap { x =>
137-
tens.shape.map { y =>
138-
getOnnxTensor(x, y, env)
139-
}
122+
Some(tens.map { x =>
123+
getOnnxTensor(x._1, x._2._3.toSeq.toArray, env)
140124
})
141125
}
142126
}
@@ -145,30 +129,40 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
145129
.map(_.toArray)
146130
}
147131

148-
def res: Tensor[T, Tuple3[Tt, Td, S]] = {
149-
// val resource = cats.effect.Resource.make(IO{getSession(opModel)})(sess => IO{sess.close})
150-
// resource.use( sess =>
132+
def res(
133+
opModelBytes: Array[Byte],
134+
inputTensorss: IO[Array[OnnxTensor]]
135+
): Tensor[T, Tuple3[Tt, Td, S]] = {
151136
cats.effect.Resource
152-
.make(inputTensors)(inTens => IO { inTens.map(_.close) })
137+
.make(inputTensorss)(inTens => IO { inTens.map(_.close) })
153138
.use(inTens =>
154-
input_node_names.flatMap { y =>
155-
cats.effect.Resource
156-
.make(IO.blocking(getSession(opModel)))(sess => IO { sess.close })
157-
.use(sess =>
158-
IO {
159-
runModel(
160-
sess,
161-
inTens,
162-
y,
163-
output_node_names
164-
)
165-
}
139+
cats.effect.Resource
140+
.make(IO.blocking(getSession(opModelBytes)))(sess => IO { sess.close })
141+
.use(sess =>
142+
runModel(
143+
sess,
144+
inTens,
145+
input_node_names,
146+
output_node_names
166147
)
167-
}
148+
)
168149
)
169-
}.unsafeRunSync()
150+
}
151+
152+
val resFinal = for {
153+
tens <- inputTensors.memoize
154+
t <- tens
155+
} yield res(
156+
opToModelProto(
157+
opName,
158+
(t.map(_.getInfo.onnxType.value) zip t.map(_.getInfo.getShape.map(_.toInt))),
159+
attrs
160+
).toByteArray,
161+
tens
162+
)
163+
170164
// res.flatMap(IO.println("Post run").as(_))
171-
res
165+
resFinal.flatten
172166
}
173167

174168
def callOp[T <: Supported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape](
@@ -182,25 +176,16 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
182176
td: TensorShapeDenotationOf[Td],
183177
s: ShapeOf[S]
184178
): Tensor[T, Tuple3[Tt, Td, S]] = {
185-
// TODO: prevent passing input to opToONNXBytes
186-
187-
val modelProto = opToModelProto(opName, inputs, attrs)
188-
189-
// val mp = opToModelProto(opName, inputs, attrs)
179+
val inputNodeNames = (0 until inputs.size).toList.map(_.toString)
190180

191-
val result: IO[Tensor[T, Tuple3[Tt, Td, S]]] =
192-
for {
193-
mp <- modelProto // modelProto.flatMap(IO.println("OpName => " + opName).as(_))
194-
} yield callByteArrayOp(
195-
mp.toByteArray,
181+
val result: Tensor[T, Tuple3[Tt, Td, S]] =
182+
callByteArrayOp(
196183
inputs,
197-
IO.pure {
198-
mp.graph.map(_.input.map(_.name.getOrElse(""))).getOrElse(List[String]()).toList
199-
}
184+
inputNodeNames,
185+
opName,
186+
attrs
200187
)
201-
val r =
202-
result.unsafeRunSync() // If don't use unsafe here, we get redundant callOp invocations. If we memoize w/ unsafe, we leak memory.
203-
r // This approach makes callOp sync/eager again.
188+
result.flatMap(IO.println("Real call opName => " + opName).as(_))
204189
}
205190

206191
def modelToPersist(mod: ModelProto, outName: String) = {

0 commit comments

Comments (0)