@@ -59,24 +59,27 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
59
59
val tensorTypeDenotationFromType = tt.value
60
60
val tensorShapeDenotationFromType = td.value
61
61
62
- val tensArr : IO [Array [T ]] = cats.effect.Resource .make(IO .blocking{sess.run(inputs)})(outTens => IO {outTens.close}).use(outTens =>
63
- {
64
- val firstOut = outTens.get(0 ).asInstanceOf [OnnxTensor ]
65
- val shape = firstOut.getInfo.getShape.map(_.toInt)
62
+ val tensArr : IO [Array [T ]] = cats.effect.Resource
63
+ .make(IO .blocking { sess.run(inputs) })(outTens => IO { outTens.close })
64
+ .use(outTens => {
65
+ val firstOut = outTens.get(0 ).asInstanceOf [OnnxTensor ]
66
+ val shape = firstOut.getInfo.getShape.map(_.toInt)
66
67
67
- require(shape sameElements shapeFromType.toSeq)
68
- IO .blocking{getArrayFromOnnxTensor(firstOut)}
69
- }
70
- )
68
+ require(shape sameElements shapeFromType.toSeq)
69
+ IO .blocking { getArrayFromOnnxTensor(firstOut) }
70
+ })
71
71
72
72
// TODO: Denotations
73
- val result : Tensor [T , Tuple3 [Tt , Td , S ]] = tensArr.map(x => Tensor (
74
- x,
75
- tensorTypeDenotationFromType,
76
- tensorShapeDenotationFromType,
77
- shapeFromType
78
- )
79
- ).unsafeRunSync()
73
+ val result : Tensor [T , Tuple3 [Tt , Td , S ]] = tensArr
74
+ .map(x =>
75
+ Tensor (
76
+ x,
77
+ tensorTypeDenotationFromType,
78
+ tensorShapeDenotationFromType,
79
+ shapeFromType
80
+ )
81
+ )
82
+ .unsafeRunSync()
80
83
// result.flatMap(IO.println("Invoking run").as(_))
81
84
result
82
85
}
@@ -145,21 +148,24 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
145
148
def res : Tensor [T , Tuple3 [Tt , Td , S ]] = {
146
149
// val resource = cats.effect.Resource.make(IO{getSession(opModel)})(sess => IO{sess.close})
147
150
// resource.use( sess =>
148
- cats.effect.Resource .make(inputTensors)(inTens => IO {inTens.map(_.close)}).use(inTens =>
149
- input_node_names.flatMap { y =>
150
- cats.effect.Resource
151
- .make(IO .blocking(getSession(opModel)))(sess => IO { sess.close })
152
- .use(sess =>
153
- IO {runModel(
154
- sess,
155
- inTens,
156
- y,
157
- output_node_names
151
+ cats.effect.Resource
152
+ .make(inputTensors)(inTens => IO { inTens.map(_.close) })
153
+ .use(inTens =>
154
+ input_node_names.flatMap { y =>
155
+ cats.effect.Resource
156
+ .make(IO .blocking(getSession(opModel)))(sess => IO { sess.close })
157
+ .use(sess =>
158
+ IO {
159
+ runModel(
160
+ sess,
161
+ inTens,
162
+ y,
163
+ output_node_names
164
+ )
165
+ }
158
166
)
159
- }
160
- )
161
- }
162
- )
167
+ }
168
+ )
163
169
}.unsafeRunSync()
164
170
// res.flatMap(IO.println("Post run").as(_))
165
171
res
@@ -184,16 +190,17 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
184
190
185
191
val result : IO [Tensor [T , Tuple3 [Tt , Td , S ]]] =
186
192
for {
187
- mp <- modelProto // modelProto.flatMap(IO.println("OpName => " + opName).as(_))
193
+ mp <- modelProto // modelProto.flatMap(IO.println("OpName => " + opName).as(_))
188
194
} yield callByteArrayOp(
189
- mp.toByteArray,
190
- inputs,
191
- IO .pure {
192
- mp.graph.map(_.input.map(_.name.getOrElse(" " ))).getOrElse(List [String ]()).toList
193
- }
194
- )
195
- val r = result.unsafeRunSync() // If don't use unsafe here, we get redundant callOp invocations. If we memoize w/ unsafe, we leak memory.
196
- r // This approach makes callOp sync/eager again.
195
+ mp.toByteArray,
196
+ inputs,
197
+ IO .pure {
198
+ mp.graph.map(_.input.map(_.name.getOrElse(" " ))).getOrElse(List [String ]()).toList
199
+ }
200
+ )
201
+ val r =
202
+ result.unsafeRunSync() // If don't use unsafe here, we get redundant callOp invocations. If we memoize w/ unsafe, we leak memory.
203
+ r // This approach makes callOp sync/eager again.
197
204
}
198
205
199
206
def modelToPersist (mod : ModelProto , outName : String ) = {
@@ -204,6 +211,5 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
204
211
mod.clearGraph.withGraph(graphToPersist)
205
212
}
206
213
207
- override def close (): Unit = {
208
- }
214
+ override def close (): Unit = {}
209
215
}
0 commit comments