Skip to content

Commit f3e7fb4

Browse files
committed
Fix two small bugs causing test failures in NDScala
1 parent 67a356e commit f3e7fb4

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

backends/.jvm/src/main/scala/ORTOperatorBackend.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
116116
Some(x.map { y =>
117117
getOnnxTensor(y._1, y._2._3.toSeq.toArray, env)
118118
})
119+
case None => None
119120
}
120-
case None => None
121121
case tens: Tensor[T, Tuple3[Tt, Td, S]] =>
122122
Some(tens.map { x =>
123123
getOnnxTensor(x._1, x._2._3.toSeq.toArray, env)
@@ -155,7 +155,15 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
155155
} yield res(
156156
opToModelProto(
157157
opName,
158-
(t.map(_.getInfo.onnxType.value) zip t.map(_.getInfo.getShape.map(_.toInt))),
158+
(t.map(_.getInfo.onnxType.value) zip { t.map(_.getInfo.getShape.map(_.toInt) match {
159+
//ORT shape inference diverges from the ONNX spec in requiring a scalar here instead of a tensor with shape,
160+
//causing a crash without this fix
161+
case Array(1) => if(opName.equals("Dropout")) Array[Int]() else Array(1)
162+
case y: Array[Int] => y
163+
}
164+
)
165+
}
166+
),
159167
attrs
160168
).toByteArray,
161169
tens
@@ -185,6 +193,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
185193
opName,
186194
attrs
187195
)
196+
//TODO: now that this is otherwise working, try memoizing here
188197
result.flatMap(IO.println("Real call opName => " + opName).as(_))
189198
}
190199

0 commit comments

Comments
 (0)