Skip to content

Commit a43f514

Browse files
committed
Bump ONNX-Scala & ONNX opset and IR versions; support Boolean outputs
1 parent 6f116c8 commit a43f514

File tree

7 files changed

+24
-10
lines changed

7 files changed

+24
-10
lines changed

backends/src/main/scala/ORTOperatorBackend.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,13 @@ trait ORTOperatorBackend
117117
buff.get(x)
118118
}.toArray
119119
}
120+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL =>{
121+
val point = value.GetTensorMutableDataBool.capacity(size)
122+
val booleanPoint = new BooleanPointer(point.asByteBuffer) //C++ bool size is not defined, could cause problems on some platforms
123+
(0 until booleanPoint.capacity().toInt).map { x =>
124+
booleanPoint.get(x)
125+
}.toArray
126+
}
120127
}
121128
TensorFactory.getTensor(arr, shape)
122129
}

backends/src/main/scala/ORTOperatorBackend213.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,13 @@ trait ORTOperatorBackend
116116
buff.get(x)
117117
}.toArray
118118
}
119+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL =>{
120+
val point = value.GetTensorMutableDataBool.capacity(size)
121+
val booleanPoint = new BooleanPointer(point.asByteBuffer) //C++ bool size is not defined, could cause problems on some platforms
122+
(0 until booleanPoint.capacity().toInt).map { x =>
123+
booleanPoint.get(x)
124+
}.toArray
125+
}
119126
}
120127
TensorFactory.getTensor(arr, shape)
121128
}

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ scalaVersion := scala213Version
1010
lazy val commonSettings = Seq(
1111
// scalaJSUseMainModuleInitializer := true, //Test only
1212
organization := "org.emergentorder.onnx",
13-
version := "0.3.0",
13+
version := "0.4.0",
1414
scalaVersion := scala213Version,
1515
resolvers += Resolver.mavenLocal,
1616
resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",

core/src/main/scala/ONNX213.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3637,7 +3637,7 @@ package object onnx {
36373637
](name, "Equal", allInputs, map))
36383638
}
36393639

3640-
def Equal11[@sp T: Numeric: ClassTag, @sp T1: Numeric: ClassTag](
3640+
def Equal11[@sp T: Numeric: ClassTag, @sp T1: ClassTag](
36413641
name: String,
36423642
A: Option[Tensor[T]],
36433643
B: Option[Tensor[T]]
@@ -5123,7 +5123,7 @@ package object onnx {
51235123
](name, "Greater", allInputs, map))
51245124
}
51255125

5126-
def Greater9[@sp T: Numeric: ClassTag, @sp T1: Numeric: ClassTag](
5126+
def Greater9[@sp T: Numeric: ClassTag, @sp T1: ClassTag](
51275127
name: String,
51285128
A: Option[Tensor[T]],
51295129
B: Option[Tensor[T]]
@@ -6188,7 +6188,7 @@ package object onnx {
61886188
](name, "Less", allInputs, map))
61896189
}
61906190

6191-
def Less9[@sp T: Numeric: ClassTag, @sp T1: Numeric: ClassTag](
6191+
def Less9[@sp T: Numeric: ClassTag, @sp T1: ClassTag](
61926192
name: String,
61936193
A: Option[Tensor[T]],
61946194
B: Option[Tensor[T]]

core/src/main/scala/ONNXHelper.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import java.io.File
2828
import scala.reflect.io.Streamable
2929
import org.bytedeco.javacpp._
3030
import org.bytedeco.onnx._
31-
import org.bytedeco.onnx.global.onnx.ParseProtoFromBytes
31+
import org.bytedeco.onnx.global.onnx.{ParseProtoFromBytes,check_model}
3232

3333
class ONNXHelper(val byteArray: Array[Byte]) extends AutoCloseable {
3434

@@ -45,7 +45,7 @@ class ONNXHelper(val byteArray: Array[Byte]) extends AutoCloseable {
4545
bytes,
4646
byteArray.length.toLong
4747
)
48-
// bytes.close
48+
check_model(r)
4949
r
5050
}
5151

core/src/main/scala/OpToONNXBytesConverter.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,10 @@ trait OpToONNXBytesConverter extends AutoCloseable {
168168

169169
// origNode.close
170170
model.set_allocated_graph(graph)
171-
model.set_ir_version(3)
171+
model.set_ir_version(6)
172172

173173
model.add_opset_import
174-
model.opset_import(0).set_version(8)
174+
model.opset_import(0).set_version(12)
175175

176176
val outputValueInfo = graph.add_output
177177

core/src/main/scala/OpToONNXBytesConverter213.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,10 @@ trait OpToONNXBytesConverter extends AutoCloseable {
224224

225225
origNode.close
226226
model.set_allocated_graph(graph)
227-
model.set_ir_version(3)
227+
model.set_ir_version(6)
228228

229229
model.add_opset_import
230-
model.opset_import(0).set_version(8)
230+
model.opset_import(0).set_version(12)
231231

232232
val outputValueInfo = graph.add_output
233233

0 commit comments

Comments
 (0)