File tree Expand file tree Collapse file tree 3 files changed +32
-7
lines changed Expand file tree Collapse file tree 3 files changed +32
-7
lines changed Original file line number Diff line number Diff line change @@ -120,7 +120,19 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
120
120
t <- tens
121
121
} yield opToModelProto(
122
122
opName,
123
- (t.map(_.asInstanceOf [tensorMod.Tensor ].`type`.valueOf.asInstanceOf [Float ].round)
123
+ (t.map(x => x.asInstanceOf [tensorMod.Tensor ].`type`.valueOf.toString match {
124
+ // Can't access the enum int values here
125
+ // But it's fine, doesn't match the ONNX spec anyway
126
+ case " int8" => 3
127
+ case " int16" => 5
128
+ case " float64" => 11
129
+ case " float32" => 1
130
+ case " int32" => 6
131
+ case " int64" => 7
132
+ case " bool" => 9
133
+ case y => y.toInt
134
+ }
135
+ )
124
136
zip t.map(_.dims.map(_.toInt).toArray)),
125
137
attrs
126
138
).toByteArray
Original file line number Diff line number Diff line change @@ -155,7 +155,20 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
155
155
} yield res(
156
156
opToModelProto(
157
157
opName,
158
- (t.map(_.getInfo.onnxType.value) zip { t.map(_.getInfo.getShape.map(_.toInt) match {
158
+ (t.map(_.getInfo.onnxType.value match {
159
+ // ORT has two different enums for this for the Java and C APIs
160
+ // Neither matches the ONNX spec
161
+ case 2 => 3
162
+ case 4 => 5
163
+ case 10 => 1
164
+ case 8 => 7
165
+ case 13 => 9
166
+ case n => n
167
+ }
168
+ )
169
+
170
+ zip
171
+ { t.map(_.getInfo.getShape.map(_.toInt) match {
159
172
// ORT shape inference diverges from the ONNX spec in requiring a scalar here instead of a tensor with shape,
160
173
// causing a crash without this fix
161
174
case Array (1 ) => if (opName.equals(" Dropout" )) Array [Int ]() else Array (1 )
Original file line number Diff line number Diff line change @@ -181,13 +181,13 @@ trait OpToONNXBytesConverter {
181
181
{
182
182
183
183
val elemType = elemTypeIn match {
184
- case 2 => INT8 .index
185
- case 4 => INT16 .index
184
+ case 3 => INT8 .index
185
+ case 5 => INT16 .index
186
186
case 11 => DOUBLE .index
187
- case 10 => FLOAT .index
187
+ case 1 => FLOAT .index
188
188
case 6 => INT32 .index
189
- case 8 => INT64 .index
190
- case 13 => BOOL .index
189
+ case 7 => INT64 .index
190
+ case 9 => BOOL .index
191
191
case _ => INT64 .index // In case of Scala.js BigInt
192
192
}
193
193
// tens.shape.map { y =>
You can’t perform that action at this time.
0 commit comments