Skip to content

Commit 689ba2a

Browse files
committed
Fix Expand, Unsqueeze ops
1 parent 2aee1c1 commit 689ba2a

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

core/src/main/scala/ONNX.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -388,16 +388,16 @@ package object onnx {
388388
}
389389
}
390390

391-
//Missing in NDScala - P2 - needs expand match type - WIP
392-
//Explicit broadcasting
391+
//Missing constraint - need an equivalent of the size equality constraint on Squeeze, but that asserts the shapes are broadcastable
392+
//Explicit broadcasting - can fail
393393
trait ExpandV13 extends Operator {
394394
def ExpandV13[
395395
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | BFloat16 | Float16 | Float | Double | String | Boolean | Complex[
396396
Float
397-
] | Complex[Double]: Numeric
398-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, S2 <: Shape](name: String, input: Tensor[T, Tuple3[Tt,Td,S]], shapeInput: S2)(using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s: ShapeOf[S2]): Tensor[T, Tuple3[Tt2,Td2,S2]] = {
397+
] | Complex[Double]
398+
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, S2 <: Shape](name: String, input: Tensor[T, Tuple3[Tt,Td,S]], shapeInput: Tensor[Long, Tuple3[Tt1,Td1,S1]])(using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s: ShapeOf[S2]): Tensor[T, Tuple3[Tt2,Td2,S2]] = {
399399
val map: Map[String, Any] = Map()
400-
val allInputs = Tuple2(input, shapeInput.toSeq.toArray)
400+
val allInputs = Tuple2(input, shapeInput)
401401
(callOp(name, "Expand", allInputs, map))
402402
}
403403
}
@@ -1147,14 +1147,13 @@ package object onnx {
11471147
(callOp(name, "Transpose", allInputs, map))
11481148
}
11491149
}
1150-
//Missing in NDScala - P2 - Needs expand match type for output - WIP
11511150
//Missing in ONNX.js
11521151
trait UnsqueezeV13 extends Operator {
11531152
def UnsqueezeV13[
11541153
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | BFloat16 | Float16 | Float | Double | String | Boolean | Complex[
11551154
Float
1156-
] | Complex[Double]: Numeric
1157-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axes <: Indices](name: String, axes: Axes, data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td], s: ShapeOf[UnsqueezeShape[S,Axes]], i: IndicesOf[Axes]): Tensor[T, Tuple3[Tt1,Td,UnsqueezeShape[S,Axes]]] = {
1155+
] | Complex[Double]
1156+
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axes <: Indices](name: String, axes: Option[Axes] = None, data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td], s: ShapeOf[UnsqueezeShape[S,Axes]], i: IndicesOf[Axes]): Tensor[T, Tuple3[Tt1,Td,UnsqueezeShape[S,Axes]]] = {
11581157
val axes = indicesOf[Axes].indices.toArray
11591158
val map: Map[String, Any] = Map()
11601159
val allInputs = Tuple2(data, Tensor(axes.map(_.toLong), Shape.fromSeq(ArraySeq.unsafeWrapArray(Array(axes.size)))))

0 commit comments

Comments
 (0)