Skip to content

Commit fa75ed5

Browse files
committed
Fix argmin/max axis attribute, clean up, fix constant op in fuseOps
1 parent add6492 commit fa75ed5

File tree

3 files changed

+5
-5489
lines changed

3 files changed

+5
-5489
lines changed

backends/.jvm/src/main/scala/ORTOperatorBackend.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ trait ORTOperatorBackend
123123
if(cacheValues.size == 0) return ModelProto()
124124
val nodes = cacheValues.map(_.getGraph.node).fold(Seq[NodeProto]())( (x,y) => x ++ y )
125125
val nodeOutputs = cacheValues.map(_.getGraph.output).fold(Seq[ValueInfoProto]())( (x,y) => x ++ y ).map(_.getName)
126-
val inputs = cacheValues.map(_.getGraph.input).fold(Seq[ValueInfoProto]())( (x,y) => x ++ y ).filter(z => ! nodeOutputs.contains(z.getName)).distinct
126+
val inputs = cacheValues.map(_.getGraph).filter(! _.node(0).opType.equals(Some("Constant"))).map(_.input).fold(Seq[ValueInfoProto]())( (x,y) => x ++ y ).filter(z => ! nodeOutputs.contains(z.getName)).distinct
127127
val outputs = cacheValues(cacheValues.size - 1).getGraph.output
128128
val modelProto = (cacheValues.head.clearGraph).withGraph((new GraphProto).withNode(nodes).withInput(inputs).withOutput(outputs))
129129
sessionCache.clear

core/src/main/scala/ONNX.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ package object onnx {
123123
keepdims: KeepDims = true,
124124
data: Tensor[T, Tuple3[Tt, Td, S]]
125125
)(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,KeepDims]], s: ShapeOf[KeepOrReduceDims[S,Axis,KeepDims]], i: IndicesOf[Axis], k: ValueOf[KeepDims]): Tensor[Long, Tuple3[Tt1, KeepOrReduceDimDenotations[Td,Axis,KeepDims], KeepOrReduceDims[S,Axis,KeepDims]]] = {
126-
val map: Map[String, Any] = Map("axis" -> indicesOf[Axis].indices.toArray, "keepdims" -> (if(valueOf[KeepDims]) 1 else 0))
126+
val map: Map[String, Any] = Map("axis" -> indicesOf[Axis].indices.toArray.head, "keepdims" -> (if(valueOf[KeepDims]) 1 else 0))
127127
val allInputs = Tuple1(data)
128128
(callOp(name, "ArgMax", allInputs, map))
129129
}
@@ -1217,7 +1217,7 @@ package object onnxruntime {
12171217
selectLastIndex: Int = 0,
12181218
data: Tensor[T, Tuple3[Tt, Td, S]]
12191219
)(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,KeepDims]], s: ShapeOf[KeepOrReduceDims[S,Axis,KeepDims]], i: IndicesOf[Axis], k: ValueOf[KeepDims]): Tensor[Long, Tuple3[Tt1, KeepOrReduceDimDenotations[Td,Axis,KeepDims], KeepOrReduceDims[S,Axis,KeepDims]]] = {
1220-
val map: Map[String, Any] = Map("axis" -> indicesOf[Axis].indices.toArray, "select_last_index" -> selectLastIndex, "keepdims" -> (if(valueOf[KeepDims]) 1 else 0))
1220+
val map: Map[String, Any] = Map("axis" -> indicesOf[Axis].indices.toArray.head, "select_last_index" -> selectLastIndex, "keepdims" -> (if(valueOf[KeepDims]) 1 else 0))
12211221
val allInputs = Tuple1(data)
12221222
(callOp(name, "ArgMax", allInputs, map))
12231223
}
@@ -1234,7 +1234,7 @@ package object onnxruntime {
12341234
selectLastIndex: Int = 0,
12351235
data: Tensor[T, Tuple3[Tt, Td, S]]
12361236
)(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,KeepDims]], s: ShapeOf[KeepOrReduceDims[S,Axis,KeepDims]], i: IndicesOf[Axis], k: ValueOf[KeepDims]): Tensor[Long, Tuple3[Tt1, KeepOrReduceDimDenotations[Td,Axis,KeepDims], KeepOrReduceDims[S,Axis,KeepDims]]] = {
1237-
val map: Map[String, Any] = Map("axis" -> indicesOf[Axis].indices.toArray, "select_last_index" -> selectLastIndex, "keepdims" -> (if(valueOf[KeepDims]) 1 else 0))
1237+
val map: Map[String, Any] = Map("axis" -> indicesOf[Axis].indices.toArray.head, "select_last_index" -> selectLastIndex, "keepdims" -> (if(valueOf[KeepDims]) 1 else 0))
12381238
val allInputs = Tuple1(data)
12391239
(callOp(name, "ArgMin", allInputs, map))
12401240
}
@@ -1249,7 +1249,7 @@ package object onnxruntime {
12491249
keepdims: KeepDims = true,
12501250
data: Tensor[T, Tuple3[Tt, Td, S]]
12511251
)(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,KeepDims]], s: ShapeOf[KeepOrReduceDims[S,Axis,KeepDims]], i: IndicesOf[Axis], k: ValueOf[KeepDims]): Tensor[Long, Tuple3[Tt1, KeepOrReduceDimDenotations[Td,Axis,KeepDims], KeepOrReduceDims[S,Axis,KeepDims]]] = {
1252-
val map: Map[String, Any] = Map("axis" -> indicesOf[Axis].indices.toArray, "keepdims" -> (if(valueOf[KeepDims]) 1 else 0))
1252+
val map: Map[String, Any] = Map("axis" -> indicesOf[Axis].indices.toArray.head, "keepdims" -> (if(valueOf[KeepDims]) 1 else 0))
12531253
val allInputs = Tuple1(data)
12541254
(callOp(name, "ArgMin", allInputs, map))
12551255
}

0 commit comments

Comments
 (0)