Skip to content

Commit 3ff58b3

Browse files
committed
Fix Gather op
1 parent 689ba2a commit 3ff58b3

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

core/src/main/scala/ONNX.scala

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -429,20 +429,21 @@ package object onnx {
429429
}
430430
}
431431
//Missing in NDScala - P3
432+
//need a match type
432433
trait GatherV13 extends Operator {
433434
def GatherV13[
434435
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | BFloat16 | Float16 | Float | Double | String | Boolean | Complex[
435436
Float
436-
] | Complex[Double]: Numeric,
437-
@sp Tind <: Int | Long: Numeric
438-
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, S2 <: Shape](
437+
] | Complex[Double],
438+
@sp Tind <: Int : Numeric, //Spec also supports long
439+
Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, AxisIndex <: Index ::: INil, AxisIndices <: Indices](
439440
name: String,
440-
axis: Int = 0,
441+
axis: AxisIndex = 0 ::: INil,
441442
data: Tensor[T, Tuple3[Tt,Td,S]],
442-
indices: Tensor[Tind, Tuple3[Tt1,Td1,S1]]
443-
)(using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s: ShapeOf[S2]): Tensor[T, Tuple3[Tt2,Td2,S2]] = {
444-
val map: Map[String, Any] = Map("axis" -> axis)
445-
val allInputs = Tuple2(data, indices)
443+
indices: AxisIndices
444+
)(using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s: ShapeOf[GatheredShape[S, AxisIndex, AxisIndices]], i: IndicesOf[AxisIndex], i2: IndicesOf[AxisIndices]): Tensor[T, Tuple3[Tt2,Td2,GatheredShape[S, AxisIndex, AxisIndices]]] = {
445+
val map: Map[String, Any] = Map("axis" -> indicesOf[AxisIndex].indices.toArray.head)
446+
val allInputs = Tuple2(data, Tensor(indicesOf[AxisIndices].indices.toArray, indicesOf[AxisIndices].indices.toArray.size.asInstanceOf[io.kjaer.compiletime.Dimension] #: SNil))
446447
(callOp(name, "Gather", allInputs, map))
447448
}
448449
}

core/src/main/scala/Tensors.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,27 @@ object Tensors{
9797
}
9898
}
9999

100+
type GatheredShape[S <: Shape, AxisIndex <: None.type | Indices, AxisIndices <: Indices] <: Shape = AxisIndex match {
101+
case None.type => SNil
102+
case Indices => GatheredShapeLoop[S, AxisIndex, 0, AxisIndices]
103+
}
104+
105+
protected type GatheredShapeLoop[ToGather <: Shape, AxisIndex <: Indices, I <: Index, AxisIndices <: Indices] <: Shape = ToGather match {
106+
case head #: tail => Indices.Contains[AxisIndex, I] match {
107+
case true => IndicesSize[AxisIndices] #: GatheredShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], AxisIndices]
108+
case false => head #: GatheredShapeLoop[tail, AxisIndex, S[I], AxisIndices]
109+
}
110+
case SNil => AxisIndex match {
111+
case INil => SNil
112+
}
113+
}
114+
115+
type IndicesSize[AxisIndices <: Indices] = IndicesSizeLoop[AxisIndices, 0]
116+
117+
type IndicesSizeLoop[AxisIndices <: Indices, Acc <: Dimension] = AxisIndices match {
118+
case head ::: tail => IndicesSizeLoop[tail, S[Acc]]
119+
case INil => Acc
120+
}
100121

101122
type FlattenedShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match {
102123
case None.type => SNil

0 commit comments

Comments
 (0)