Skip to content

Commit ade3d45

Browse files
committed
Add Flatten match type
1 parent e05d3ad commit ade3d45

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

core/src/main/scala/ONNX.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -400,14 +400,18 @@ package object onnx {
400400
(callOp(name, "Expand", allInputs, map))
401401
}
402402
}
403-
//Missing in NDScala - P2 - needs flatten match type
403+
404404
trait FlattenV13 extends Operator {
405405
def FlattenV13[
406406
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | BFloat16 | Float16 | Float | Double | String | Boolean | Complex[
407407
Float
408408
] | Complex[Double]: Numeric
409-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Dimension #: Dimension #: SNil](name: String, axis: Int = 1, input: Tensor[T, Tuple3[Tt,Td,S]])(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td1], s: ShapeOf[S1]): Tensor[T, Tuple3[Tt1,Td1,S1]] = {
410-
val map: Map[String, Any] = Map("axis" -> axis)
409+
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axis <: Index ::: INil](
410+
name: String,
411+
axis: Axis,
412+
input: Tensor[T, Tuple3[Tt,Td,S]])
413+
(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td], s: ShapeOf[FlattenedShape[S, Axis]]): Tensor[T, Tuple3[Tt1,Td,FlattenedShape[S, Axis]]] = {
414+
val map: Map[String, Any] = Map("axis" -> axis.indices.toArray.head)
411415
val allInputs = Tuple1(input)
412416
(callOp(name, "Flatten", allInputs, map))
413417
}

core/src/main/scala/Tensors.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,21 @@ object Tensors{
8282
}
8383
}
8484

85+
type FlattenedShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match {
86+
case None.type => SNil
87+
case Indices => FlattenedShapeLoop[S, AxisIndex, 0, 1]
88+
}
89+
90+
protected type FlattenedShapeLoop[ToFlatten <: Shape, AxisIndex <: Indices, I <: Index, Acc <: Index] <: Shape = ToFlatten match {
91+
case head #: tail => Indices.Contains[AxisIndex, I] match {
92+
case true => Acc #: FlattenedShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], head]
93+
case false => FlattenedShapeLoop[tail, AxisIndex, S[I], head * Acc]
94+
}
95+
case SNil => AxisIndex match {
96+
case INil => Acc #: SNil
97+
}
98+
}
99+
85100
type SlicedShape[AxisIndicesStarts <: None.type | Indices, AxisIndicesEnds <: None.type | Indices] <: Shape = AxisIndicesStarts match {
86101
case None.type => SNil
87102
case Indices => AxisIndicesEnds match {

0 commit comments

Comments
 (0)