Skip to content

Commit 5bde84e

Browse files
committed
Add bool tensor type; Relax requirement, allowing tensors of dimensionality 0; Fix dotty-only crash on missing optional inputs
1 parent 3f9373a commit 5bde84e

File tree

6 files changed

+88
-35
lines changed

6 files changed

+88
-35
lines changed

backends/src/main/scala/ORTModelBackend.scala

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,23 @@ class ORTModelBackend(onnxBytes: Array[Byte])
1919
val num_input_nodes = session.GetInputCount();
2020
val input_node_names = new PointerPointer[BytePointer](num_input_nodes);
2121

22-
// System.out.println("Number of inputs = " + num_input_nodes);
22+
2323

2424
val inputNodeDims = (0 until num_input_nodes.toInt).map { i =>
2525
// print input node names
2626
val input_name = session.GetInputName(i, allocator.asOrtAllocator())
27-
// println("Input " + i + " : name=" + input_name.getString())
27+
2828
input_node_names.put(i, input_name)
2929

3030
// print input node types
3131
val type_info = session.GetInputTypeInfo(i)
3232
val tensor_info = type_info.GetTensorTypeAndShapeInfo()
3333

3434
// val type = tensor_info.GetElementType()
35-
// println("Input " + i + " : type=" + type)
35+
3636

3737
// print input shapes/dims
3838
tensor_info.GetShape()
39-
//println("Input " + i + " : num_dims=" + input_node_dims.capacity())
4039

4140
}
4241

@@ -74,12 +73,8 @@ class ORTModelBackend(onnxBytes: Array[Byte])
7473
session,
7574
inputTensors,
7675
allNodeNamesAndDims._1,
77-
allNodeNamesAndDims._2,
7876
allNodeNamesAndDims._3
7977
)
80-
// val outputPointer = out.get(0).GetTensorMutableDataFloat().capacity(inputs.GetTensorTypeAndShapeInfo().GetElementCount());
81-
82-
// println(outputPointer.get(0).IsTensor())
8378

8479
Tuple1(output.asInstanceOf[T])
8580
}

backends/src/main/scala/ORTOperatorBackend.scala

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ trait ORTOperatorBackend
3535
def runModel(
3636
sess: Session,
3737
input_tensor_values: Array[Value],
38-
inputNames: PointerPointer[BytePointer],
39-
nodeDims: Array[LongPointer],
38+
inputNames: PointerPointer[BytePointer],
4039
outputNames: PointerPointer[BytePointer]
4140
) = {
4241

@@ -142,6 +141,7 @@ trait ORTOperatorBackend
142141
case f: Array[Float] => getTensorFloat(tens)
143142
case i: Array[Int] => getTensorInt(tens)
144143
case l: Array[Long] => getTensorLong(tens)
144+
case b: Array[Boolean] => getTensorBoolean(tens)
145145
}
146146
value
147147
}
@@ -159,6 +159,7 @@ trait ORTOperatorBackend
159159
case f: Array[Float] => getTensorFloat(tens)
160160
case i: Array[Int] => getTensorInt(tens)
161161
case l: Array[Long] => getTensorLong(tens)
162+
case b: Array[Boolean] => getTensorBoolean(tens)
162163
}
163164
value
164165
}
@@ -178,7 +179,7 @@ trait ORTOperatorBackend
178179
}
179180

180181
val size: Long = dims.capacity
181-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_*_)
182+
val inputTensorSize = tens._1.size
182183

183184
val inputTensor: Value = Value.CreateTensorByte(
184185
memory_info.asOrtMemoryInfo,
@@ -200,7 +201,7 @@ trait ORTOperatorBackend
200201
}
201202

202203
val size: Long = dims.capacity
203-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_*_)
204+
val inputTensorSize = tens._1.size
204205

205206
val inputTensor: Value = Value.CreateTensorShort(
206207
memory_info.asOrtMemoryInfo,
@@ -222,7 +223,7 @@ trait ORTOperatorBackend
222223
}
223224

224225
val size: Long = dims.capacity
225-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_*_)
226+
val inputTensorSize = tens._1.size
226227

227228
val inputTensor: Value = Value.CreateTensorDouble(
228229
memory_info.asOrtMemoryInfo,
@@ -244,7 +245,7 @@ trait ORTOperatorBackend
244245
}
245246

246247
val size: Long = dims.capacity
247-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_*_)
248+
val inputTensorSize = tens._1.size
248249

249250
val inputTensor: Value = Value.CreateTensorInt(
250251
memory_info.asOrtMemoryInfo,
@@ -266,7 +267,7 @@ trait ORTOperatorBackend
266267
}
267268

268269
val size: Long = dims.capacity
269-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_*_)
270+
val inputTensorSize = tens._1.size
270271

271272
val inputTensor: Value = Value.CreateTensorLong(
272273
memory_info.asOrtMemoryInfo,
@@ -293,7 +294,7 @@ trait ORTOperatorBackend
293294

294295

295296
val size: Long = dims.capacity
296-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_*_)
297+
val inputTensorSize = tens._1.size
297298

298299
val inputTensor: Value = Value.CreateTensorFloat(
299300
memory_info.asOrtMemoryInfo,
@@ -305,6 +306,31 @@ trait ORTOperatorBackend
305306
inputTensor
306307
}
307308

309+
def getTensorBoolean(tens: Tensor[Boolean]): Value = {
310+
311+
val inputArray = tens._1
312+
313+
val inputPointer = new BoolPointer(new BooleanPointer(inputArray: _*))
314+
315+
// input_node_names.put(i,new BytePointer(i.toString))
316+
317+
val dims = new LongPointer(tens._2.size)
318+
(0 until tens._2.size).map { i =>
319+
dims.put(i, tens._2(i))
320+
}
321+
322+
val size: Long = dims.capacity
323+
val inputTensorSize = tens._1.size
324+
325+
val inputTensor: Value = Value.CreateTensorBool(
326+
memory_info.asOrtMemoryInfo,
327+
inputPointer,
328+
inputTensorSize,
329+
dims,
330+
size
331+
)
332+
inputTensor
333+
}
308334

309335
def callByteArrayOp[
310336
T: ClassTag
@@ -324,22 +350,24 @@ trait ORTOperatorBackend
324350
val output_node_names = new PointerPointer[BytePointer](1)
325351

326352

327-
val inputDimsAndValues: Array[Tuple2[LongPointer, Value]] = (0 until x.size).map{i =>
353+
val inputDimsAndValues: Array[Value] = (0 until x.size).map{i =>
328354

355+
val tens = x(i)
329356

330-
input_node_names.put(i,new BytePointer(i.toString))
331-
332-
(null, getTensor(x(i)))
333-
}.toArray
357+
tens match {
358+
case None => None
359+
case _ =>
360+
input_node_names.put(i,new BytePointer(i.toString))
361+
Some(getTensor(tens))
362+
}
363+
}.toArray.flatten
334364

335365
output_node_names.put(0l,new BytePointer("outName"))
336366

337-
//println(tens._2(0))
338367
val output = runModel(
339368
sess,
340-
inputDimsAndValues.map(_._2),
369+
inputDimsAndValues,
341370
input_node_names,
342-
inputDimsAndValues.map(_._1),
343371
output_node_names
344372
)
345373

backends/src/main/scala/ORTOperatorBackend213.scala

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
146146
case f: Array[Float] => getTensorFloat(tens)
147147
case i: Array[Int] => getTensorInt(tens)
148148
case l: Array[Long] => getTensorLong(tens)
149+
case b: Array[Boolean] => getTensorBoolean(tens)
149150
}
150151
value
151152
}
@@ -163,6 +164,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
163164
case f: Array[Float] => getTensorFloat(tens)
164165
case i: Array[Int] => getTensorInt(tens)
165166
case l: Array[Long] => getTensorLong(tens)
167+
case b: Array[Boolean] => getTensorBoolean(tens)
166168
}
167169
value
168170
}
@@ -182,7 +184,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
182184
}
183185

184186
val size: Long = dims.capacity
185-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_ * _)
187+
val inputTensorSize = tens._1.size
186188

187189
val inputTensor: Value = Value.CreateTensorByte(
188190
memory_info.asOrtMemoryInfo,
@@ -204,7 +206,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
204206
}
205207

206208
val size: Long = dims.capacity
207-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_ * _)
209+
val inputTensorSize = tens._1.size
208210

209211
val inputTensor: Value = Value.CreateTensorShort(
210212
memory_info.asOrtMemoryInfo,
@@ -226,7 +228,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
226228
}
227229

228230
val size: Long = dims.capacity
229-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_ * _)
231+
val inputTensorSize = tens._1.size
230232

231233
val inputTensor: Value = Value.CreateTensorDouble(
232234
memory_info.asOrtMemoryInfo,
@@ -248,7 +250,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
248250
}
249251

250252
val size: Long = dims.capacity
251-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_ * _)
253+
val inputTensorSize = tens._1.size
252254

253255
val inputTensor: Value = Value.CreateTensorInt(
254256
memory_info.asOrtMemoryInfo,
@@ -270,7 +272,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
270272
}
271273

272274
val size: Long = dims.capacity
273-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_ * _)
275+
val inputTensorSize = tens._1.size
274276

275277
val inputTensor: Value = Value.CreateTensorLong(
276278
memory_info.asOrtMemoryInfo,
@@ -296,7 +298,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
296298
}
297299

298300
val size: Long = dims.capacity
299-
val inputTensorSize = (0 until size.toInt).map(j => dims.get(j)).reduce(_ * _)
301+
val inputTensorSize = tens._1.size
300302

301303
val inputTensor: Value = Value.CreateTensorFloat(
302304
memory_info.asOrtMemoryInfo,
@@ -308,6 +310,32 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
308310
inputTensor
309311
}
310312

313+
def getTensorBoolean(tens: Tensor[Boolean]): Value = {
314+
315+
val inputArray = tens._1
316+
317+
val inputPointer = new BoolPointer(new BooleanPointer(inputArray: _*))
318+
319+
// input_node_names.put(i,new BytePointer(i.toString))
320+
321+
val dims = new LongPointer(tens._2.size)
322+
(0 until tens._2.size).map { i =>
323+
dims.put(i, tens._2(i))
324+
}
325+
326+
val size: Long = dims.capacity
327+
val inputTensorSize = tens._1.size
328+
329+
val inputTensor: Value = Value.CreateTensorBool(
330+
memory_info.asOrtMemoryInfo,
331+
inputPointer,
332+
inputTensorSize,
333+
dims,
334+
size
335+
)
336+
inputTensor
337+
}
338+
311339
def callByteArrayOp[
312340
T: ClassTag,
313341
T1: ClassTag,

core/src/main/scala/ONNX213.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ package object onnx {
9191

9292
def getTensor[T](data: Array[T], t: Array[Int]): Tensor[T] = {
9393
val shape: Array[XInt] = t.map(z => z: XInt)
94-
require(data.size == shape.foldLeft(1)(_ * _))
94+
require((data.size == 1 && (shape sameElements Array[XInt]())) || (data.size == shape.foldLeft(1)(_ * _)))
9595
(data, t, AxesFactory.getAxes(shape, Array.fill(shape.size) { new Dim {} }))
9696
}
9797
def getTypesafeTensor[T, A <: Axes](data: Array[T], axes: A): TypesafeTensor[T, A] = {

core/src/main/scala/OpToONNXBytesConverter.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ trait OpToONNXBytesConverter extends AutoCloseable {
7878
case tensor: Some[_] => {
7979
node.add_input(inputName)
8080
}
81-
8281
case tensor: Tensor[_] => {
8382
node.add_input(inputName)
8483
}
@@ -96,7 +95,8 @@ trait OpToONNXBytesConverter extends AutoCloseable {
9695
}
9796
}
9897
*/
99-
case _ => ??? //TODO: Handle non-tensors / don't assume tensor here
98+
case None =>
99+
case _ => ??? //TODO: Handle non-tensors / don't assume tensor here
100100

101101
}
102102

@@ -115,7 +115,6 @@ trait OpToONNXBytesConverter extends AutoCloseable {
115115
}
116116

117117
protected def addInputToGraph[A](input: A, inputName: String, graph: GraphProto): Unit = {
118-
119118
input match {
120119
case tens: Tensor[_] => {
121120
val elemType = tens._1 match {
@@ -125,6 +124,7 @@ trait OpToONNXBytesConverter extends AutoCloseable {
125124
case f: Array[Float] => TensorProto.FLOAT
126125
case i: Array[Int] => TensorProto.INT32
127126
case l: Array[Long] => TensorProto.INT64
127+
case b: Array[Boolean] => TensorProto.BOOL
128128
}
129129

130130
val inputValueInfo = graph.add_input
@@ -186,8 +186,9 @@ trait OpToONNXBytesConverter extends AutoCloseable {
186186
case opt: Option[_] =>
187187
opt match {
188188
case Some(in) => addInputToGraph(in, i.toString, graph)
189+
case None =>
189190
}
190-
case _ => addInputToGraph(x(i), i.toString, graph)
191+
case _ => {addInputToGraph(x(i), i.toString, graph)}
191192

192193
}
193194
}

core/src/main/scala/OpToONNXBytesConverter213.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ trait OpToONNXBytesConverter extends AutoCloseable {
135135
case f: Array[Short] => TensorProto.INT16
136136
case i: Array[Int] => TensorProto.INT32
137137
case l: Array[Long] => TensorProto.INT64
138+
case b: Array[Boolean] => TensorProto.BOOL
138139
}
139140

140141
val inputValueInfo = graph.add_input

0 commit comments

Comments
 (0)