Skip to content

Commit dda57f2

Browse files
committed
Handle tensor data types correctly on both JVM and JS backends
1 parent f3e7fb4 commit dda57f2

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

backends/.js/src/main/scala/ORTOperatorBackend.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,19 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
120120
t <- tens
121121
} yield opToModelProto(
122122
opName,
123-
(t.map(_.asInstanceOf[tensorMod.Tensor].`type`.valueOf.asInstanceOf[Float].round)
123+
(t.map(x => x.asInstanceOf[tensorMod.Tensor].`type`.valueOf.toString match {
124+
//Can't access the enum int values here
125+
//But it's fine, doesn't match the ONNX spec anyway
126+
case "int8" => 3
127+
case "int16" => 5
128+
case "float64" => 11
129+
case "float32" => 1
130+
case "int32" => 6
131+
case "int64" => 7
132+
case "bool" => 9
133+
case y => y.toInt
134+
}
135+
)
124136
zip t.map(_.dims.map(_.toInt).toArray)),
125137
attrs
126138
).toByteArray

backends/.jvm/src/main/scala/ORTOperatorBackend.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,20 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
155155
} yield res(
156156
opToModelProto(
157157
opName,
158-
(t.map(_.getInfo.onnxType.value) zip { t.map(_.getInfo.getShape.map(_.toInt) match {
158+
(t.map(_.getInfo.onnxType.value match {
159+
//ORT has two different enums for this for the Java and C APIs
160+
//Neither matches the ONNX spec
161+
case 2 => 3
162+
case 4 => 5
163+
case 10 => 1
164+
case 8 => 7
165+
case 13 => 9
166+
case n => n
167+
}
168+
)
169+
170+
zip
171+
{ t.map(_.getInfo.getShape.map(_.toInt) match {
159172
//ORT shape inference diverges from the ONNX spec in requiring a scalar here instead of a tensor with shape,
160173
//causing a crash without this fix
161174
case Array(1) => if(opName.equals("Dropout")) Array[Int]() else Array(1)

core/src/main/scala/OpToONNXBytesConverter.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,13 @@ trait OpToONNXBytesConverter {
181181
{
182182

183183
val elemType = elemTypeIn match {
184-
case 2 => INT8.index
185-
case 4 => INT16.index
184+
case 3 => INT8.index
185+
case 5 => INT16.index
186186
case 11 => DOUBLE.index
187-
case 10 => FLOAT.index
187+
case 1 => FLOAT.index
188188
case 6 => INT32.index
189-
case 8 => INT64.index
190-
case 13 => BOOL.index
189+
case 7 => INT64.index
190+
case 9 => BOOL.index
191191
case _ => INT64.index // In case of Scala.js BigInt
192192
}
193193
// tens.shape.map { y =>

0 commit comments

Comments
 (0)