Skip to content

Commit 5e80264

Browse files
committed
Format
1 parent 053c60f commit 5e80264

File tree

6 files changed

+725
-823
lines changed

6 files changed

+725
-823
lines changed

backends/src/main/scala/ORTModelBackend.scala

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import org.bytedeco.onnxruntime.global.onnxruntime._
99

1010
import org.emergentorder.onnx._
1111

12-
1312
//TODO: Clean up, remove asInstaceOf, etc.
1413
class ORTModelBackend(onnxBytes: Array[Byte])
1514
extends Model(onnxBytes)
@@ -63,29 +62,29 @@ class ORTModelBackend(onnxBytes: Array[Byte])
6362

6463
inputs match {
6564
case Some(x) => {
66-
65+
6766
val size = x.size
68-
val inputTensors = (0 until size).map{i =>
69-
val tens = x.apply(i)
67+
val inputTensors = (0 until size).map { i =>
68+
val tens = x.apply(i)
7069
val inputTensor: Value = getTensor(tens)
7170
inputTensor
7271
}.toArray
7372

74-
val output = runModel(
75-
session,
76-
inputTensors,
77-
allNodeNamesAndDims._1,
78-
allNodeNamesAndDims._2,
79-
allNodeNamesAndDims._3
80-
)
73+
val output = runModel(
74+
session,
75+
inputTensors,
76+
allNodeNamesAndDims._1,
77+
allNodeNamesAndDims._2,
78+
allNodeNamesAndDims._3
79+
)
8180
// val outputPointer = out.get(0).GetTensorMutableDataFloat().capacity(inputs.GetTensorTypeAndShapeInfo().GetElementCount());
8281

8382
// println(outputPointer.get(0).IsTensor())
8483

85-
Tuple1(output.asInstanceOf[T])
86-
}
84+
Tuple1(output.asInstanceOf[T])
85+
}
8786
case None => Tuple1(TensorFactory.getTensor(Array(), Array[Int]()).asInstanceOf[T])
88-
87+
8988
}
9089
}
9190

backends/src/main/scala/ORTOperatorBackendAll.scala

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class ORTOperatorBackendAll
6464
with GatherV1
6565
with GatherV11
6666
with GatherElementsV11 //fails in scoreboard //passes in ORT
67-
with GatherNDV11 //fails in scoreboard //passes in nGraph dev branch & ORT
67+
with GatherNDV11 //fails in scoreboard //passes in nGraph dev branch & ORT
6868
with GemmV9
6969
with GemmV11
7070
with GlobalAveragePoolV1
@@ -82,7 +82,7 @@ class ORTOperatorBackendAll
8282
with InstanceNormalizationV6
8383
// with InverseV12 //New in 1.7.0
8484
with IsInfV10 //fails in scoreboard //passes in ORT
85-
with IsNaNV9 //fails in scoreboard //passes in ORT
85+
with IsNaNV9 //fails in scoreboard //passes in ORT
8686
with LRNV1
8787
with LSTMV7
8888
// with LabelEncoderV2 //ONNX ML, not tested for in scoreboard
@@ -113,7 +113,7 @@ class ORTOperatorBackendAll
113113
with NegV6
114114
// with NegativeLogLikelihoodLossV12 //new in 1.7.0
115115
with NonMaxSuppressionV11 //fails in scoreboard //passes in ORT
116-
with NonZeroV9 //fails in scoreboard //passes in ORT
116+
with NonZeroV9 //fails in scoreboard //passes in ORT
117117
// with NormalizerV1 //ONNX ML, not tested in scoreboard
118118
with NotV1
119119
with OneHotV11 //fails in scoreboard //passes in ORT
@@ -122,7 +122,7 @@ class ORTOperatorBackendAll
122122
with PReluV9
123123
with PadV11
124124
with PowV7
125-
with QLinearConvV10 //fails in scoreboard //passes in ORT
125+
with QLinearConvV10 //fails in scoreboard //passes in ORT
126126
with QLinearMatMulV10 //fails in scoreboard //passes in ORT
127127
with QuantizeLinearV10
128128
with RNNV7 //fails in scoreboard //passes in ORT
@@ -148,21 +148,21 @@ class ORTOperatorBackendAll
148148
with ResizeV11 //fails in scoreboard //passes in ORT
149149
with ReverseSequenceV10
150150
with RoiAlignV10 //fails in scoreboard //passes in ORT
151-
with RoundV11 //fails in scoreboard //passes in nGraph dev branch & ORT
151+
with RoundV11 //fails in scoreboard //passes in nGraph dev branch & ORT
152152
// with SVMClassifierV1 //ONNXML, not tested for in scoreboard
153153
// with SVMRegressorV1 //ONNXML, not tested for in scoreboard
154154
// with ScalerV1 //ONNXML, not tested for in scoreboard
155-
with ScanV11 //fails in scoreboard //passes in ORT
156-
with ScatterV11 //fails in scoreboard //passes in ORT
155+
with ScanV11 //fails in scoreboard //passes in ORT
156+
with ScatterV11 //fails in scoreboard //passes in ORT
157157
with ScatterElementsV11 //fails in scoreboard //passes in ORT
158-
with ScatterNDV11 //fails in scoreboard //passes in nGraph dev branch & ORT
158+
with ScatterNDV11 //fails in scoreboard //passes in nGraph dev branch & ORT
159159
with SeluV6
160-
with SequenceAtV11 //fails in scoreboard //passes in ORT
160+
with SequenceAtV11 //fails in scoreboard //passes in ORT
161161
with SequenceConstructV11 //fails in scoreboard //passes in ORT
162-
with SequenceEmptyV11 //fails in scoreboard //passes in ORT
163-
with SequenceEraseV11 //fails in scoreboard //passes in ORT
164-
with SequenceInsertV11 //fails in scoreboard //passes in ORT
165-
with SequenceLengthV11 //fails in scoreboard //passes in ORT
162+
with SequenceEmptyV11 //fails in scoreboard //passes in ORT
163+
with SequenceEraseV11 //fails in scoreboard //passes in ORT
164+
with SequenceInsertV11 //fails in scoreboard //passes in ORT
165+
with SequenceLengthV11 //fails in scoreboard //passes in ORT
166166
with ShapeV1
167167
with ShrinkV9
168168
with SigmoidV6
@@ -191,13 +191,13 @@ class ORTOperatorBackendAll
191191
with TanhV6
192192
with TfIdfVectorizerV9 //fails in scoreboard //passes in ORT
193193
with ThresholdedReluV10
194-
with TileV6 //fails in scoreboard //passes in ORT
194+
with TileV6 //fails in scoreboard //passes in ORT
195195
with TopKV11 //fails in scoreboard //passes in ORT
196196
with TransposeV1
197197
// with TreeEnsembleClassifierV1 //ONNX ML, not tested for in scoreboard
198198
// with TreeEnsembleRegressorV1 //ONNX ML, not tested for in scoreboard
199199
with UniqueV11 //fails in scoreboard //passes in ORT
200-
with UnsqueezeV1
200+
with UnsqueezeV1
201201
with UnsqueezeV11
202202
with UpsampleV10 //fails in scoreboard //passes in ORT
203203
with WhereV9

core/src/main/scala/ONNX.scala

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import scala.reflect.ClassTag
1414
import org.bytedeco.onnx.ModelProto
1515
package object onnx {
1616

17+
//TODO: Remove requirement to be Numeric for ops with non-numeric outputs / inputs
1718
//TODO: Encode node names as types
1819
//TODO: fix encoding of type constraints, use Tensor as part of definition of types
1920
sealed trait Operator {
@@ -551,7 +552,10 @@ package object onnx {
551552
}
552553

553554
trait CastMapV1 extends Operator {
554-
def CastMapV1[@sp T1 <: Map[Long, String] | Map[Long, Float]: Numeric: ClassTag, @sp T2 <: String | Float | Long: Numeric: ClassTag](
555+
def CastMapV1[@sp T1 <: Map[Long, String] | Map[
556+
Long,
557+
Float
558+
]: Numeric: ClassTag, @sp T2 <: String | Float | Long: Numeric: ClassTag](
555559
name: String,
556560
cast_to: Option[(String)] = None,
557561
map_form: Option[(String)] = None,
@@ -755,13 +759,17 @@ package object onnx {
755759
}
756760

757761
trait ConcatFromSequenceV11 extends Operator {
758-
def ConcatFromSequenceV11[@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[Tensor[UInt]] | Seq[
762+
def ConcatFromSequenceV11[@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[
763+
Tensor[UInt]
764+
] | Seq[
759765
Tensor[ULong]
760766
] | Seq[Tensor[Byte]] | Seq[Tensor[Short]] | Seq[Tensor[Int]] | Seq[Tensor[Long]] | Seq[
761767
Tensor[Float16]
762768
] | Seq[Tensor[Float]] | Seq[Tensor[Double]] | Seq[Tensor[String]] | Seq[Tensor[Boolean]] | Seq[
763769
Tensor[Complex[Float]]
764-
] | Seq[Tensor[Complex[Double]]]: Numeric: ClassTag, @sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
770+
] | Seq[
771+
Tensor[Complex[Double]]
772+
]: Numeric: ClassTag, @sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
765773
Float
766774
] | Complex[Double]: Numeric: ClassTag](
767775
name: String,
@@ -1132,7 +1140,10 @@ package object onnx {
11321140
def DictVectorizerV1[@sp T1 <: Map[String, Long] | Map[Long, String] | Map[Long, Float] | Map[
11331141
Long,
11341142
Double
1135-
] | Map[String, Float] | Map[String, Double]: Numeric: ClassTag, @sp T2 <: Long | Float | Double | String: Numeric: ClassTag](
1143+
] | Map[String, Float] | Map[
1144+
String,
1145+
Double
1146+
]: Numeric: ClassTag, @sp T2 <: Long | Float | Double | String: Numeric: ClassTag](
11361147
name: String,
11371148
int64_vocabulary: Option[(Array[Int])] = None,
11381149
string_vocabulary: Option[(Array[String])] = None,
@@ -1304,7 +1315,7 @@ package object onnx {
13041315
trait EqualV11 extends Operator {
13051316
def EqualV11[
13061317
@sp T <: Boolean | UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double: Numeric: ClassTag,
1307-
@sp T1 <: Boolean: Numeric: ClassTag
1318+
@sp T1 <: Boolean: ClassTag
13081319
](name: String, A: Tensor[T], B: Tensor[T]): Tuple1[Tensor[T1]] = {
13091320
val map: Map[String, Any] = Map()
13101321
val allInputs = Some(Tuple2(A, B))
@@ -4328,7 +4339,9 @@ package object onnx {
43284339
Tensor[Float16]
43294340
] | Seq[Tensor[Float]] | Seq[Tensor[Double]] | Seq[Tensor[String]] | Seq[Tensor[Boolean]] | Seq[
43304341
Tensor[Complex[Float]]
4331-
] | Seq[Tensor[Complex[Double]]]: Numeric: ClassTag, @sp I <: Int | Long: Numeric: ClassTag, @sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
4342+
] | Seq[
4343+
Tensor[Complex[Double]]
4344+
]: Numeric: ClassTag, @sp I <: Int | Long: Numeric: ClassTag, @sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
43324345
Float
43334346
] | Complex[Double]: Numeric: ClassTag](
43344347
name: String,
@@ -4346,7 +4359,9 @@ package object onnx {
43464359
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
43474360
Float
43484361
] | Complex[Double]: Numeric: ClassTag,
4349-
@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[Tensor[UInt]] | Seq[Tensor[ULong]] | Seq[
4362+
@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[Tensor[UInt]] | Seq[
4363+
Tensor[ULong]
4364+
] | Seq[
43504365
Tensor[Byte]
43514366
] | Seq[Tensor[Short]] | Seq[Tensor[Int]] | Seq[Tensor[Long]] | Seq[Tensor[Float16]] | Seq[
43524367
Tensor[Float]
@@ -4361,7 +4376,9 @@ package object onnx {
43614376
}
43624377

43634378
trait SequenceEmptyV11 extends Operator {
4364-
def SequenceEmptyV11[@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[Tensor[UInt]] | Seq[
4379+
def SequenceEmptyV11[@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[
4380+
Tensor[UInt]
4381+
] | Seq[
43654382
Tensor[ULong]
43664383
] | Seq[Tensor[Byte]] | Seq[Tensor[Short]] | Seq[Tensor[Int]] | Seq[Tensor[Long]] | Seq[
43674384
Tensor[Float16]
@@ -4378,7 +4395,9 @@ package object onnx {
43784395
}
43794396

43804397
trait SequenceEraseV11 extends Operator {
4381-
def SequenceEraseV11[@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[Tensor[UInt]] | Seq[
4398+
def SequenceEraseV11[@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[
4399+
Tensor[UInt]
4400+
] | Seq[
43824401
Tensor[ULong]
43834402
] | Seq[Tensor[Byte]] | Seq[Tensor[Short]] | Seq[Tensor[Int]] | Seq[Tensor[Long]] | Seq[
43844403
Tensor[Float16]
@@ -4396,13 +4415,17 @@ package object onnx {
43964415
}
43974416

43984417
trait SequenceInsertV11 extends Operator {
4399-
def SequenceInsertV11[@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[Tensor[UInt]] | Seq[
4418+
def SequenceInsertV11[@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[
4419+
Tensor[UInt]
4420+
] | Seq[
44004421
Tensor[ULong]
44014422
] | Seq[Tensor[Byte]] | Seq[Tensor[Short]] | Seq[Tensor[Int]] | Seq[Tensor[Long]] | Seq[
44024423
Tensor[Float16]
44034424
] | Seq[Tensor[Float]] | Seq[Tensor[Double]] | Seq[Tensor[String]] | Seq[Tensor[Boolean]] | Seq[
44044425
Tensor[Complex[Float]]
4405-
] | Seq[Tensor[Complex[Double]]]: Numeric: ClassTag, @sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
4426+
] | Seq[
4427+
Tensor[Complex[Double]]
4428+
]: Numeric: ClassTag, @sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
44064429
Float
44074430
] | Complex[Double]: Numeric: ClassTag, @sp I <: Int | Long: Numeric: ClassTag](
44084431
name: String,
@@ -4417,7 +4440,9 @@ package object onnx {
44174440
}
44184441

44194442
trait SequenceLengthV11 extends Operator {
4420-
def SequenceLengthV11[@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[Tensor[UInt]] | Seq[
4443+
def SequenceLengthV11[@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[
4444+
Tensor[UInt]
4445+
] | Seq[
44214446
Tensor[ULong]
44224447
] | Seq[Tensor[Byte]] | Seq[Tensor[Short]] | Seq[Tensor[Int]] | Seq[Tensor[Long]] | Seq[
44234448
Tensor[Float16]
@@ -4668,7 +4693,9 @@ package object onnx {
46684693
Float
46694694
] | Complex[Double]: Numeric: ClassTag,
46704695
@sp I <: Int | Long: Numeric: ClassTag,
4671-
@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[Tensor[UInt]] | Seq[Tensor[ULong]] | Seq[
4696+
@sp S <: Seq[Tensor[UByte]] | Seq[Tensor[UShort]] | Seq[Tensor[UInt]] | Seq[
4697+
Tensor[ULong]
4698+
] | Seq[
46724699
Tensor[Byte]
46734700
] | Seq[Tensor[Short]] | Seq[Tensor[Int]] | Seq[Tensor[Long]] | Seq[Tensor[Float16]] | Seq[
46744701
Tensor[Float]
@@ -5277,7 +5304,7 @@ package object onnx {
52775304

52785305
trait WhereV9 extends Operator {
52795306
def WhereV9[
5280-
@sp B <: Boolean: Numeric: ClassTag,
5307+
@sp B <: Boolean: ClassTag,
52815308
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
52825309
Float
52835310
] | Complex[Double]: Numeric: ClassTag

0 commit comments

Comments
 (0)