Skip to content

Commit 74e91f2

Browse files
committed
Fix misuse of IO in JS backend which was causing redundant op invocations
1 parent dda57f2 commit 74e91f2

File tree

2 files changed

+52
-48
lines changed

2 files changed

+52
-48
lines changed

backends/.js/src/main/scala/ORTOperatorBackend.scala

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
6161
S <: Shape
6262
](
6363
inputs: Tuple,
64-
input_node_names: IO[List[String]],
64+
input_node_names: List[String],
6565
opName: String,
6666
attrs: Map[String, Any]
6767
)(using
@@ -82,7 +82,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
8282
*/
8383

8484
// TODO: more outputs
85-
val output_node_names = input_node_names.map(x => { List(x.toString) })
85+
val output_node_names = List( input_node_names.toString)
8686

8787
// Spurious warning here, see: https://github.com/lampepfl/dotty/issues/10318
8888
// TODO: don't mix up Options and Tensors here
@@ -95,30 +95,46 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
9595
case opt: Option[Tensor[T, Tuple3[Tt, Td, S]]] =>
9696
opt match {
9797
case Some(x) =>
98-
Some(x.data.flatMap { y =>
99-
x.shape.map { z =>
100-
getOnnxTensor(y, z)
98+
Some(x.map { y =>
99+
getOnnxTensor(y._1, y._2._3.toSeq.toArray)
101100
}
102-
})
101+
)
103102
case None => None
104103
}
105104
case tens: Tensor[T, Tuple3[Tt, Td, S]] =>
106-
Some(tens.data.flatMap { x =>
107-
tens.shape.map { y =>
108-
getOnnxTensor(x, y)
105+
Some(tens.map { x =>
106+
getOnnxTensor(x._1, x._2._3.toSeq.toArray)
109107
}
110-
})
108+
)
111109
}
112110
}
113111
.toList
114112
.sequence
115113
.map(_.toArray)
116114
}
117115

118-
val opModel = for {
116+
def res(opModelBytes: Array[Byte], inputTensorss: IO[Array[OnnxTensor[T]]]) : Tensor[T, Tuple3[Tt, Td, S]] = {
117+
cats.effect.Resource.make(inputTensorss)(inTens => IO {}).
118+
use(x =>
119+
cats.effect.Resource
120+
.make(IO.blocking(getSession(opModelBytes)))(sess => IO {})
121+
.use(sess =>
122+
runModel(
123+
sess,
124+
x,
125+
input_node_names,
126+
output_node_names
127+
)
128+
)
129+
// }
130+
)
131+
132+
}
133+
134+
val finalRes = for {
119135
tens <- inputTensors.memoize
120136
t <- tens
121-
} yield opToModelProto(
137+
} yield res(opToModelProto(
122138
opName,
123139
(t.map(x => x.asInstanceOf[tensorMod.Tensor].`type`.valueOf.toString match {
124140
//Can't access the enum int values here
@@ -135,25 +151,11 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
135151
)
136152
zip t.map(_.dims.map(_.toInt).toArray)),
137153
attrs
138-
).toByteArray
139-
140-
val res: Tensor[T, Tuple3[Tt, Td, S]] = {
141-
inputTensors.flatMap { x =>
142-
cats.effect.Resource
143-
.make(opModel.map(getSession(_)))(sess => IO {})
144-
.use(sess =>
145-
runModel(
146-
sess,
147-
x,
148-
input_node_names,
149-
output_node_names
150-
)
151-
)
152-
// }
153-
}
154+
).toByteArray,
155+
tens
156+
)
154157

155-
}
156-
res.flatMap(IO.println("opNAme = " + opName).as(_))
158+
finalRes.flatten
157159
//res
158160
}
159161

@@ -172,11 +174,11 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
172174
val result: Tensor[T, Tuple3[Tt, Td, S]] =
173175
callByteArrayOp(
174176
inputs,
175-
IO{inputNodeNames},
177+
inputNodeNames,
176178
opName,
177179
attrs
178180
)
179-
result
181+
result //.flatMap(x => IO.println("opName = " + opName).as(x))
180182
}
181183

182184
def runModel[
@@ -189,34 +191,34 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
189191
org.emergentorder.onnx.onnxruntimeCommon.inferenceSessionMod.InferenceSession
190192
],
191193
input_tensor_values: Array[OnnxTensor[T]],
192-
inputNames: IO[List[String]],
193-
outputNames: IO[List[String]]
194+
inputNames: List[String],
195+
outputNames: List[String]
194196
)(using
195197
tt: ValueOf[Tt],
196198
td: TensorShapeDenotationOf[Td],
197199
s: ShapeOf[S]
198200
): Tensor[T, Tuple3[Tt, Td, S]] = {
199201

200-
val feeds: IO[js.Dictionary[OnnxTensor[T]]] = inputNames.map(x => {
201-
val zipped = x.toArray zip input_tensor_values
202+
val feeds: js.Dictionary[OnnxTensor[T]] = {
203+
val zipped = inputNames.toArray zip input_tensor_values
202204
js.Dictionary(zipped.map(z => z._1 -> z._2): _*)
203-
})
205+
}
204206

205207
val output_tensors: IO[org.emergentorder.onnx.onnxruntimeCommon.tensorMod.Tensor] =
206208
IO.fromFuture {
207209
sess
208210
.flatMap { realSess =>
209-
feeds.flatMap { realFeeds =>
211+
// feeds.flatMap { realFeeds =>
210212
val res = IO.eval(cats.Eval.later {
211213
realSess
212214
.run(
213-
realFeeds.asInstanceOf[
215+
feeds.asInstanceOf[
214216
org.emergentorder.onnx.onnxruntimeCommon.inferenceSessionMod.InferenceSession.FeedsType
215217
]
216218
)
217219
.toFuture
218220
})
219-
outputNames.flatMap { names =>
221+
// outputNames.flatMap { names =>
220222
res.map { result =>
221223
result.map { rr =>
222224
// println(realSess.outputNames.toList)
@@ -228,8 +230,8 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
228230
.get
229231
}
230232
}
231-
}
232-
}
233+
// }
234+
// }
233235
}
234236
}
235237

@@ -359,8 +361,8 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
359361
](
360362
session,
361363
inputs,
362-
IO.pure { List("data_0") },
363-
IO.pure { List("squeezenet0_flatten0_reshape0") }
364+
List("data_0") ,
365+
List("squeezenet0_flatten0_reshape0")
364366
)
365367

366368
// res.foreach(tens => tens.data.foreach(println))

backends/.js/src/main/scala/ORTWebModelBackend.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,18 @@ class ORTWebModelBackend(session: IO[InferenceSession]) extends Model() with ORT
6868
.map(_.toArray)
6969

7070
val output = inputTensors.flatMap { tns =>
71-
IO {
71+
inputNames.flatMap{inNames =>
72+
outputNames.flatMap{outNames =>
7273
runModel[T, Tt, Td, S](
7374
session,
7475
tns,
75-
inputNames,
76-
outputNames
76+
inNames,
77+
outNames
7778
)
79+
}
7880
}
7981
}
80-
output.flatten
82+
output//.flatten
8183
}
8284

8385
}

0 commit comments

Comments (0)