@@ -71,15 +71,14 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
71
71
72
72
// TODO: Denotations
73
73
val result : Tensor [T , Tuple3 [Tt , Td , S ]] = tensArr
74
- .map (x =>
74
+ .flatMap (x =>
75
75
Tensor (
76
76
x,
77
77
tensorTypeDenotationFromType,
78
78
tensorShapeDenotationFromType,
79
79
shapeFromType
80
80
)
81
81
)
82
- .unsafeRunSync()
83
82
// result.flatMap(IO.println("Invoking run").as(_))
84
83
result
85
84
}
@@ -91,52 +90,37 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
91
90
Td <: TensorShapeDenotation ,
92
91
S <: Shape
93
92
](
94
- opModel : Array [Byte ],
95
93
inputs : Tuple ,
96
- input_node_names : IO [List [String ]]
94
+ input_node_names : List [String ],
95
+ opName : String ,
96
+ attrs : Map [String , Any ]
97
97
)(using
98
98
s : ShapeOf [S ],
99
99
tt : ValueOf [Tt ],
100
100
td : TensorShapeDenotationOf [Td ]
101
101
): Tensor [T , Tuple3 [Tt , Td , S ]] = {
102
- /*
103
- val input_node_names = inputs.toArray.zipWithIndex.map { (e, i) =>
104
- val incr: String = if inputs.toArray.distinct.size == inputs.size then "" else i.toString
105
- val tensE = e.asInstanceOf[Tensor[T, Tuple3[Tt, Td, S]]]
106
- tensE.map{x =>
107
- val t = ((x.toString + incr).hashCode).toString
108
- println("ANESMMMS " + t + " " + i)
109
- t
110
- }
111
- }.toList.sequence
112
- */
113
-
114
102
// TODO: more outputs
115
- val output_node_names = List (input_node_names .toString)
103
+ val output_node_names = List (inputs.size .toString)
116
104
117
105
// Spurious warning here, see: https://github.com/lampepfl/dotty/issues/10318
118
106
// TODO: don't mix up Options and Tensors here
119
107
@ annotation.nowarn
120
- def inputTensors : IO [Array [OnnxTensor ]] = {
108
+ val inputTensors : IO [Array [OnnxTensor ]] = {
121
109
122
110
inputs.toArray
123
111
.flatMap { elem =>
124
112
elem match {
125
113
case opt : Option [Tensor [T , Tuple3 [Tt , Td , S ]]] =>
126
114
opt match {
127
115
case Some (x) =>
128
- Some (x.data.flatMap { y =>
129
- x.shape.map { z =>
130
- getOnnxTensor(y, z, env)
131
- }
116
+ Some (x.map { y =>
117
+ getOnnxTensor(y._1, y._2._3.toSeq.toArray, env)
132
118
})
133
- case None => None
134
119
}
120
+ case None => None
135
121
case tens : Tensor [T , Tuple3 [Tt , Td , S ]] =>
136
- Some (tens.data.flatMap { x =>
137
- tens.shape.map { y =>
138
- getOnnxTensor(x, y, env)
139
- }
122
+ Some (tens.map { x =>
123
+ getOnnxTensor(x._1, x._2._3.toSeq.toArray, env)
140
124
})
141
125
}
142
126
}
@@ -145,30 +129,40 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
145
129
.map(_.toArray)
146
130
}
147
131
148
- def res : Tensor [T , Tuple3 [Tt , Td , S ]] = {
149
- // val resource = cats.effect.Resource.make(IO{getSession(opModel)})(sess => IO{sess.close})
150
- // resource.use( sess =>
132
+ def res (
133
+ opModelBytes : Array [Byte ],
134
+ inputTensorss : IO [Array [OnnxTensor ]]
135
+ ): Tensor [T , Tuple3 [Tt , Td , S ]] = {
151
136
cats.effect.Resource
152
- .make(inputTensors )(inTens => IO { inTens.map(_.close) })
137
+ .make(inputTensorss )(inTens => IO { inTens.map(_.close) })
153
138
.use(inTens =>
154
- input_node_names.flatMap { y =>
155
- cats.effect.Resource
156
- .make(IO .blocking(getSession(opModel)))(sess => IO { sess.close })
157
- .use(sess =>
158
- IO {
159
- runModel(
160
- sess,
161
- inTens,
162
- y,
163
- output_node_names
164
- )
165
- }
139
+ cats.effect.Resource
140
+ .make(IO .blocking(getSession(opModelBytes)))(sess => IO { sess.close })
141
+ .use(sess =>
142
+ runModel(
143
+ sess,
144
+ inTens,
145
+ input_node_names,
146
+ output_node_names
166
147
)
167
- }
148
+ )
168
149
)
169
- }.unsafeRunSync()
150
+ }
151
+
152
+ val resFinal = for {
153
+ tens <- inputTensors.memoize
154
+ t <- tens
155
+ } yield res(
156
+ opToModelProto(
157
+ opName,
158
+ (t.map(_.getInfo.onnxType.value) zip t.map(_.getInfo.getShape.map(_.toInt))),
159
+ attrs
160
+ ).toByteArray,
161
+ tens
162
+ )
163
+
170
164
// res.flatMap(IO.println("Post run").as(_))
171
- res
165
+ resFinal.flatten
172
166
}
173
167
174
168
def callOp [T <: Supported , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape ](
@@ -182,25 +176,16 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
182
176
td : TensorShapeDenotationOf [Td ],
183
177
s : ShapeOf [S ]
184
178
): Tensor [T , Tuple3 [Tt , Td , S ]] = {
185
- // TODO: prevent passing input to opToONNXBytes
186
-
187
- val modelProto = opToModelProto(opName, inputs, attrs)
188
-
189
- // val mp = opToModelProto(opName, inputs, attrs)
179
+ val inputNodeNames = (0 until inputs.size).toList.map(_.toString)
190
180
191
- val result : IO [Tensor [T , Tuple3 [Tt , Td , S ]]] =
192
- for {
193
- mp <- modelProto // modelProto.flatMap(IO.println("OpName => " + opName).as(_))
194
- } yield callByteArrayOp(
195
- mp.toByteArray,
181
+ val result : Tensor [T , Tuple3 [Tt , Td , S ]] =
182
+ callByteArrayOp(
196
183
inputs,
197
- IO .pure {
198
- mp.graph.map(_.input.map(_.name.getOrElse( " " ))).getOrElse( List [ String ]()).toList
199
- }
184
+ inputNodeNames,
185
+ opName,
186
+ attrs
200
187
)
201
- val r =
202
- result.unsafeRunSync() // If don't use unsafe here, we get redundant callOp invocations. If we memoize w/ unsafe, we leak memory.
203
- r // This approach makes callOp sync/eager again.
188
+ result.flatMap(IO .println(" Real call opName => " + opName).as(_))
204
189
}
205
190
206
191
def modelToPersist (mod : ModelProto , outName : String ) = {
0 commit comments