Skip to content

Commit 2aee1c1

Browse files
committed
WIP expand and unsqueeze ops
1 parent c756c96 commit 2aee1c1

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

core/src/main/scala/ONNX.scala

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -212,18 +212,19 @@ package object onnx {
212212
(callOp(name, "BatchNormalization", allInputs, map))
213213
}
214214
}
215+
/* - Not supported - cast on the JVM side
215216
//Missing in NDScala P2 - needs match type from data type to int
216217
trait CastV13 extends Operator {
217218
def CastV9[
218219
@sp T1 <: BFloat16 | Float16 | Float | Double | Byte | Short | Int | Long | UByte | UShort | UInt | ULong | Boolean | String: Numeric,
219220
@sp T2 <: BFloat16 | Float16 | Float | Double | Byte | Short | Int | Long | UByte | UShort | UInt | ULong | Boolean | String: Numeric
220-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape](name: String, to: (Int), input: Tensor[T1,Tuple3[Tt, Td, S]])(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[S]): Tensor[T2, Tuple3[Tt, Td, S]] = {
221+
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape](name: String, input: Tensor[T1,Tuple3[Tt, Td, S]])(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[S]): Tensor[T2, Tuple3[Tt, Td, S]] = {
221222
val map: Map[String, Any] = Map("to" -> to)
222223
val allInputs = Tuple1(input)
223224
(callOp(name, "Cast", allInputs, map))
224225
}
225226
}
226-
227+
*/
227228
trait CeilV13 extends Operator {
228229
def CeilV13[@sp T <: BFloat16 | Float16 | Float | Double: Numeric, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape](
229230
name: String,
@@ -387,16 +388,16 @@ package object onnx {
387388
}
388389
}
389390

390-
//Missing in NDScala - P2 - needs expand match type
391+
//Missing in NDScala - P2 - needs expand match type - WIP
391392
//Explicit broadcasting
392393
trait ExpandV13 extends Operator {
393394
def ExpandV13[
394395
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | BFloat16 | Float16 | Float | Double | String | Boolean | Complex[
395396
Float
396397
] | Complex[Double]: Numeric
397-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, S2 <: Shape](name: String, input: Tensor[T, Tuple3[Tt,Td,S]], shapeInput: Tensor[Long, Tuple3[Tt1,Td1,S1]])(using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s: ShapeOf[S2]): Tensor[T, Tuple3[Tt2,Td2,S2]] = {
398+
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, S2 <: Shape](name: String, input: Tensor[T, Tuple3[Tt,Td,S]], shapeInput: S2)(using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s: ShapeOf[S2]): Tensor[T, Tuple3[Tt2,Td2,S2]] = {
398399
val map: Map[String, Any] = Map()
399-
val allInputs = Tuple2(input, shapeInput)
400+
val allInputs = Tuple2(input, shapeInput.toSeq.toArray)
400401
(callOp(name, "Expand", allInputs, map))
401402
}
402403
}
@@ -1146,14 +1147,14 @@ package object onnx {
11461147
(callOp(name, "Transpose", allInputs, map))
11471148
}
11481149
}
1149-
//Missing in NDScala - P2 - Needs expand match type for output
1150+
//Missing in NDScala - P2 - Needs expand match type for output - WIP
11501151
//Missing in ONNX.js
11511152
trait UnsqueezeV13 extends Operator {
11521153
def UnsqueezeV13[
11531154
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | BFloat16 | Float16 | Float | Double | String | Boolean | Complex[
11541155
Float
11551156
] | Complex[Double]: Numeric
1156-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axes <: Indices](name: String, axes: Axes, data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axes,false]], s: ShapeOf[KeepOrReduceDims[S,Axes,false]], i: IndicesOf[Axes]): Tensor[T, Tuple3[Tt1,KeepOrReduceDimDenotations[Td,Axes,false],KeepOrReduceDims[S,Axes,false]]] = {
1157+
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axes <: Indices](name: String, axes: Axes, data: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td], s: ShapeOf[UnsqueezeShape[S,Axes]], i: IndicesOf[Axes]): Tensor[T, Tuple3[Tt1,Td,UnsqueezeShape[S,Axes]]] = {
11571158
val axes = indicesOf[Axes].indices.toArray
11581159
val map: Map[String, Any] = Map()
11591160
val allInputs = Tuple2(data, Tensor(axes.map(_.toLong), Shape.fromSeq(ArraySeq.unsafeWrapArray(Array(axes.size)))))

core/src/main/scala/Tensors.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,22 @@ object Tensors{
8282
}
8383
}
8484

85+
type UnsqueezeShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match {
86+
case None.type => SNil
87+
case Indices => UnsqueezeShapeLoop[S, AxisIndex, 0]
88+
}
89+
90+
protected type UnsqueezeShapeLoop[ToUnsqueeze <: Shape, AxisIndex <: Indices, I <: Index] <: Shape = ToUnsqueeze match {
91+
case head #: tail => Indices.Contains[AxisIndex, I] match {
92+
case true => 1 #: head #: UnsqueezeShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I]]
93+
case false => head #: UnsqueezeShapeLoop[tail, AxisIndex, S[I]]
94+
}
95+
case SNil => AxisIndex match {
96+
case INil => SNil
97+
}
98+
}
99+
100+
85101
type FlattenedShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match {
86102
case None.type => SNil
87103
case Indices => FlattenedShapeLoop[S, AxisIndex, 0, 1]

0 commit comments

Comments
 (0)