@@ -61,7 +61,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
61
61
S <: Shape
62
62
](
63
63
inputs : Tuple ,
64
- input_node_names : IO [ List [String ] ],
64
+ input_node_names : List [String ],
65
65
opName : String ,
66
66
attrs : Map [String , Any ]
67
67
)(using
@@ -82,7 +82,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
82
82
*/
83
83
84
84
// TODO: more outputs
85
- val output_node_names = input_node_names.map(x => { List (x .toString) } )
85
+ val output_node_names = List ( input_node_names .toString)
86
86
87
87
// Spurious warning here, see: https://github.com/lampepfl/dotty/issues/10318
88
88
// TODO: don't mix up Options and Tensors here
@@ -95,30 +95,46 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
95
95
case opt : Option [Tensor [T , Tuple3 [Tt , Td , S ]]] =>
96
96
opt match {
97
97
case Some (x) =>
98
- Some (x.data.flatMap { y =>
99
- x.shape.map { z =>
100
- getOnnxTensor(y, z)
98
+ Some (x.map { y =>
99
+ getOnnxTensor(y._1, y._2._3.toSeq.toArray)
101
100
}
102
- } )
101
+ )
103
102
case None => None
104
103
}
105
104
case tens : Tensor [T , Tuple3 [Tt , Td , S ]] =>
106
- Some (tens.data.flatMap { x =>
107
- tens.shape.map { y =>
108
- getOnnxTensor(x, y)
105
+ Some (tens.map { x =>
106
+ getOnnxTensor(x._1, x._2._3.toSeq.toArray)
109
107
}
110
- } )
108
+ )
111
109
}
112
110
}
113
111
.toList
114
112
.sequence
115
113
.map(_.toArray)
116
114
}
117
115
118
- val opModel = for {
116
+ def res (opModelBytes : Array [Byte ], inputTensorss : IO [Array [OnnxTensor [T ]]]) : Tensor [T , Tuple3 [Tt , Td , S ]] = {
117
+ cats.effect.Resource .make(inputTensorss)(inTens => IO {}).
118
+ use(x =>
119
+ cats.effect.Resource
120
+ .make(IO .blocking(getSession(opModelBytes)))(sess => IO {})
121
+ .use(sess =>
122
+ runModel(
123
+ sess,
124
+ x,
125
+ input_node_names,
126
+ output_node_names
127
+ )
128
+ )
129
+ // }
130
+ )
131
+
132
+ }
133
+
134
+ val finalRes = for {
119
135
tens <- inputTensors.memoize
120
136
t <- tens
121
- } yield opToModelProto(
137
+ } yield res( opToModelProto(
122
138
opName,
123
139
(t.map(x => x.asInstanceOf [tensorMod.Tensor ].`type`.valueOf.toString match {
124
140
// Can't access the enum int values here
@@ -135,25 +151,11 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
135
151
)
136
152
zip t.map(_.dims.map(_.toInt).toArray)),
137
153
attrs
138
- ).toByteArray
139
-
140
- val res : Tensor [T , Tuple3 [Tt , Td , S ]] = {
141
- inputTensors.flatMap { x =>
142
- cats.effect.Resource
143
- .make(opModel.map(getSession(_)))(sess => IO {})
144
- .use(sess =>
145
- runModel(
146
- sess,
147
- x,
148
- input_node_names,
149
- output_node_names
150
- )
151
- )
152
- // }
153
- }
154
+ ).toByteArray,
155
+ tens
156
+ )
154
157
155
- }
156
- res.flatMap(IO .println(" opNAme = " + opName).as(_))
158
+ finalRes.flatten
157
159
// res
158
160
}
159
161
@@ -172,11 +174,11 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
172
174
val result : Tensor [T , Tuple3 [Tt , Td , S ]] =
173
175
callByteArrayOp(
174
176
inputs,
175
- IO { inputNodeNames} ,
177
+ inputNodeNames,
176
178
opName,
177
179
attrs
178
180
)
179
- result
181
+ result // .flatMap(x => IO.println("opName = " + opName).as(x))
180
182
}
181
183
182
184
def runModel [
@@ -189,34 +191,34 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
189
191
org.emergentorder.onnx.onnxruntimeCommon.inferenceSessionMod.InferenceSession
190
192
],
191
193
input_tensor_values : Array [OnnxTensor [T ]],
192
- inputNames : IO [ List [String ] ],
193
- outputNames : IO [ List [String ] ]
194
+ inputNames : List [String ],
195
+ outputNames : List [String ]
194
196
)(using
195
197
tt : ValueOf [Tt ],
196
198
td : TensorShapeDenotationOf [Td ],
197
199
s : ShapeOf [S ]
198
200
): Tensor [T , Tuple3 [Tt , Td , S ]] = {
199
201
200
- val feeds : IO [ js.Dictionary [OnnxTensor [T ]]] = inputNames.map(x => {
201
- val zipped = x .toArray zip input_tensor_values
202
+ val feeds : js.Dictionary [OnnxTensor [T ]] = {
203
+ val zipped = inputNames .toArray zip input_tensor_values
202
204
js.Dictionary (zipped.map(z => z._1 -> z._2): _* )
203
- })
205
+ }
204
206
205
207
val output_tensors : IO [org.emergentorder.onnx.onnxruntimeCommon.tensorMod.Tensor ] =
206
208
IO .fromFuture {
207
209
sess
208
210
.flatMap { realSess =>
209
- feeds.flatMap { realFeeds =>
211
+ // feeds.flatMap { realFeeds =>
210
212
val res = IO .eval(cats.Eval .later {
211
213
realSess
212
214
.run(
213
- realFeeds .asInstanceOf [
215
+ feeds .asInstanceOf [
214
216
org.emergentorder.onnx.onnxruntimeCommon.inferenceSessionMod.InferenceSession .FeedsType
215
217
]
216
218
)
217
219
.toFuture
218
220
})
219
- outputNames.flatMap { names =>
221
+ // outputNames.flatMap { names =>
220
222
res.map { result =>
221
223
result.map { rr =>
222
224
// println(realSess.outputNames.toList)
@@ -228,8 +230,8 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
228
230
.get
229
231
}
230
232
}
231
- }
232
- }
233
+ // }
234
+ // }
233
235
}
234
236
}
235
237
@@ -359,8 +361,8 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
359
361
](
360
362
session,
361
363
inputs,
362
- IO .pure { List (" data_0" ) } ,
363
- IO .pure { List (" squeezenet0_flatten0_reshape0" ) }
364
+ List (" data_0" ) ,
365
+ List (" squeezenet0_flatten0_reshape0" )
364
366
)
365
367
366
368
// res.foreach(tens => tens.data.foreach(println))
0 commit comments