@@ -429,20 +429,21 @@ package object onnx {
429
429
}
430
430
}
431
431
// Missing in NDScala - P3
432
+ // need a match type
432
433
trait GatherV13 extends Operator {
433
434
def GatherV13 [
434
435
@ sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | BFloat16 | Float16 | Float | Double | String | Boolean | Complex [
435
436
Float
436
- ] | Complex [Double ]: Numeric ,
437
- @ sp Tind <: Int | Long : Numeric
438
- , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Shape , Tt2 <: TensorTypeDenotation , Td2 <: TensorShapeDenotation , S2 <: Shape ](
437
+ ] | Complex [Double ],
438
+ @ sp Tind <: Int : Numeric , // Spec also supports long
439
+ Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Shape , Tt2 <: TensorTypeDenotation , Td2 <: TensorShapeDenotation , AxisIndex <: Index ::: INil , AxisIndices <: Indices ](
439
440
name : String ,
440
- axis : Int = 0 ,
441
+ axis : AxisIndex = 0 ::: INil ,
441
442
data : Tensor [T , Tuple3 [Tt ,Td ,S ]],
442
- indices : Tensor [ Tind , Tuple3 [ Tt1 , Td1 , S1 ]]
443
- )(using tt : ValueOf [Tt2 ], td : TensorShapeDenotationOf [Td2 ], s : ShapeOf [S2 ] ): Tensor [T , Tuple3 [Tt2 ,Td2 ,S2 ]] = {
444
- val map : Map [String , Any ] = Map (" axis" -> axis )
445
- val allInputs = Tuple2 (data, indices)
443
+ indices : AxisIndices
444
+ )(using tt : ValueOf [Tt2 ], td : TensorShapeDenotationOf [Td2 ], s : ShapeOf [GatheredShape [ S , AxisIndex , AxisIndices ]], i : IndicesOf [ AxisIndex ], i2 : IndicesOf [ AxisIndices ] ): Tensor [T , Tuple3 [Tt2 ,Td2 ,GatheredShape [ S , AxisIndex , AxisIndices ] ]] = {
445
+ val map : Map [String , Any ] = Map (" axis" -> indicesOf[ AxisIndex ].indices.toArray.head )
446
+ val allInputs = Tuple2 (data, Tensor (indicesOf[ AxisIndices ]. indices.toArray, indicesOf[ AxisIndices ].indices.toArray.size. asInstanceOf [io.kjaer.compiletime. Dimension ] #: SNil ) )
446
447
(callOp(name, " Gather" , allInputs, map))
447
448
}
448
449
}
0 commit comments