Skip to content

Commit 1cbcbca

Browse files
committed
Fix shape constraints on Squeeze and Unsqueeze ops; Remove runtime require that tensors be rank <=4
1 parent b557d50 commit 1cbcbca

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

core/src/main/scala/ONNX.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@ import spire.math.Numeric
1212
import spire.implicits._
1313
import spire.algebra.Field
1414
import org.emergentorder.onnx.Tensors._
15+
//import scala.compiletime.ops.int //For RC2
1516
import scala.compiletime.ops.int._
1617
import io.kjaer.compiletime._
1718
import io.kjaer.compiletime.Shape.NumElements
1819
import org.emergentorder.compiletime._
1920
import org.emergentorder.compiletime.TensorShapeDenotation.Reverse
2021
package object onnx {
2122

23+
//TODO: Symbolic shape values
2224
//TODO: Support bfloat16 type (new in ONNX 1.8.0)
2325
//TODO: Fix propagation behavavior for TensorShapeDenotation
2426
//TODO: Encode node names as types
@@ -3234,9 +3236,6 @@ package object onnx {
32343236
}
32353237
}
32363238

3237-
//TODO: Constraint for shape denotation
3238-
//
3239-
//+ match types on axes
32403239
trait ReshapeV5 extends Operator {
32413240
def ReshapeV5[
32423241
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
@@ -4041,13 +4040,14 @@ package object onnx {
40414040
}
40424041

40434042
//Missing V13
4044-
//TODO: Constraint
4043+
//TODO: Constraint on range, and for V1, disallow negative indexing
40454044
trait SqueezeV11 extends Operator {
40464045
def SqueezeV11[
40474046
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
40484047
Float
40494048
] | Complex[Double]
4050-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape](name: String, axes: Option[(Array[Int])] = None, data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td1], s: ShapeOf[S1]): Tensor[T, Tuple3[Tt1,Td1,S1]] = {
4049+
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axis <: Indices](name: String, axes: Option[(Axis)] = None, data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,false]], s: ShapeOf[KeepOrReduceDims[S,Axis,false]], i: IndicesOf[Axis]): Tensor[T, Tuple3[Tt1,KeepOrReduceDimDenotations[Td,Axis,false],KeepOrReduceDims[S,Axis,false]]] = {
4050+
val axes = indicesOf[Axis].indices.toArray
40514051
val map: Map[String, Any] = Map("axes" -> axes)
40524052
val allInputs = Tuple1(data)
40534053
(callOp(name, "Squeeze", allInputs, map))
@@ -4059,7 +4059,8 @@ package object onnx {
40594059
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
40604060
Float
40614061
] | Complex[Double]: Numeric
4062-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape](name: String, axes: Option[(Array[Int])] = None, data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td1], s: ShapeOf[S1]): Tensor[T, Tuple3[Tt1,Td1,S1]] = {
4062+
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axis <: Indices](name: String, axes: Option[(Axis)] = None, data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,false]], s: ShapeOf[KeepOrReduceDims[S,Axis,false]], i: IndicesOf[Axis]): Tensor[T, Tuple3[Tt1,KeepOrReduceDimDenotations[Td,Axis,false],KeepOrReduceDims[S,Axis,false]]] = {
4063+
val axes = indicesOf[Axis].indices.toArray
40634064
val map: Map[String, Any] = Map("axes" -> axes)
40644065
val allInputs = Tuple1(data)
40654066
(callOp(name, "Squeeze", allInputs, map))
@@ -4399,13 +4400,14 @@ package object onnx {
43994400
}
44004401
*/
44014402
//Missing V13
4402-
//TODO: Constraint
4403+
//TODO: Constraint on range, and for V1, disallowing negative indexes
44034404
trait UnsqueezeV11 extends Operator {
44044405
def UnsqueezeV11[
44054406
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
44064407
Float
44074408
] | Complex[Double]: Numeric
4408-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape](name: String, axes: (Array[Int]), data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td1], s: ShapeOf[S1]): Tensor[T, Tuple3[Tt1,Td1,S1]] = {
4409+
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axis <: Indices](name: String, axes: Option[(Axis)] = None, data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,false]], s: ShapeOf[KeepOrReduceDims[S,Axis,false]], i: IndicesOf[Axis]): Tensor[T, Tuple3[Tt1,KeepOrReduceDimDenotations[Td,Axis,false],KeepOrReduceDims[S,Axis,false]]] = {
4410+
val axes = indicesOf[Axis].indices.toArray
44094411
val map: Map[String, Any] = Map("axes" -> axes)
44104412
val allInputs = Tuple1(data)
44114413
(callOp(name, "Unsqueeze", allInputs, map))
@@ -4417,7 +4419,8 @@ package object onnx {
44174419
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double | String | Boolean | Complex[
44184420
Float
44194421
] | Complex[Double]: Numeric
4420-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape](name: String, axes: (Array[Int]), data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td1], s: ShapeOf[S1]): Tensor[T, Tuple3[Tt1,Td1,S1]] = {
4422+
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axis <: Indices](name: String, axes: Option[(Axis)] = None, data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,false]], s: ShapeOf[KeepOrReduceDims[S,Axis,false]], i: IndicesOf[Axis]): Tensor[T, Tuple3[Tt1,KeepOrReduceDimDenotations[Td,Axis,false],KeepOrReduceDims[S,Axis,false]]] = {
4423+
val axes = indicesOf[Axis].indices.toArray
44214424
val map: Map[String, Any] = Map("axes" -> axes)
44224425
val allInputs = Tuple1(data)
44234426
(callOp(name, "Unsqueeze", allInputs, map))

core/src/main/scala/Tensors.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ object Tensors{
8282

8383
def tensorRequires[T <: Supported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape](tens: Tensor[T,Tuple3[Tt,Td,S]]): Tensor[T,Tuple3[Tt, Td, S]] = {
8484
//require(tens._2._2.toSeq.size == tens.shape.size) //We allow empty denotations
85-
require(tens.shape.size <= 4)
86-
require(tens.data.size == tens.shape.foldLeft(1)(_ * _))
85+
require(tens.data.size == tens.shape.foldLeft(1)(_ * _)) //This shouldn't fail at runtime, if so shape constraints need fixing
8786
tens
8887
}
8988
def apply[T <: Supported : scala.reflect.ClassTag, Tt <: TensorTypeDenotation, TD <: TensorShapeDenotation](element: T, tt: Tt, td: TD): Tensor[T, Tuple3[Tt, TD, 1 #: SNil]] = tensorRequires((Array[T](element), (tt, td, 1 #: SNil)))

0 commit comments

Comments
 (0)