@@ -388,16 +388,16 @@ package object onnx {
388
388
}
389
389
}
390
390
391
- // Missing in NDScala - P2 - needs expand match type - WIP
392
- // Explicit broadcasting
391
+ // Missing constraint - need an equivalent of the size equality constraint on Squeeze, but that asserts the shapes are broadcastable
392
+ // Explicit broadcasting - can fail
393
393
trait ExpandV13 extends Operator {
394
394
def ExpandV13 [
395
395
@ sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | BFloat16 | Float16 | Float | Double | String | Boolean | Complex [
396
396
Float
397
- ] | Complex [Double ]: Numeric
398
- , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Shape , Tt2 <: TensorTypeDenotation , Td2 <: TensorShapeDenotation , S2 <: Shape ](name : String , input : Tensor [T , Tuple3 [Tt ,Td ,S ]], shapeInput : S2 )(using tt : ValueOf [Tt2 ], td : TensorShapeDenotationOf [Td2 ], s : ShapeOf [S2 ]): Tensor [T , Tuple3 [Tt2 ,Td2 ,S2 ]] = {
397
+ ] | Complex [Double ]
398
+ , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Shape , Tt2 <: TensorTypeDenotation , Td2 <: TensorShapeDenotation , S2 <: Shape ](name : String , input : Tensor [T , Tuple3 [Tt ,Td ,S ]], shapeInput : Tensor [ Long , Tuple3 [ Tt1 , Td1 , S1 ]] )(using tt : ValueOf [Tt2 ], td : TensorShapeDenotationOf [Td2 ], s : ShapeOf [S2 ]): Tensor [T , Tuple3 [Tt2 ,Td2 ,S2 ]] = {
399
399
val map : Map [String , Any ] = Map ()
400
- val allInputs = Tuple2 (input, shapeInput.toSeq.toArray )
400
+ val allInputs = Tuple2 (input, shapeInput)
401
401
(callOp(name, " Expand" , allInputs, map))
402
402
}
403
403
}
@@ -1147,14 +1147,13 @@ package object onnx {
1147
1147
(callOp(name, " Transpose" , allInputs, map))
1148
1148
}
1149
1149
}
1150
- // Missing in NDScala - P2 - Needs expand match type for output - WIP
1151
1150
// Missing in ONNX.js
1152
1151
trait UnsqueezeV13 extends Operator {
1153
1152
def UnsqueezeV13 [
1154
1153
@ sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | BFloat16 | Float16 | Float | Double | String | Boolean | Complex [
1155
1154
Float
1156
- ] | Complex [Double ]: Numeric
1157
- , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Axes <: Indices ](name : String , axes : Axes , data : Tensor [T , Tuple3 [Tt ,Td ,S ]])(using tt : ValueOf [Tt1 ], td : TensorShapeDenotationOf [Td ], s : ShapeOf [UnsqueezeShape [S ,Axes ]], i : IndicesOf [Axes ]): Tensor [T , Tuple3 [Tt1 ,Td ,UnsqueezeShape [S ,Axes ]]] = {
1155
+ ] | Complex [Double ]
1156
+ , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Axes <: Indices ](name : String , axes : Option [ Axes ] = None , data : Tensor [T , Tuple3 [Tt ,Td ,S ]])(using tt : ValueOf [Tt1 ], td : TensorShapeDenotationOf [Td ], s : ShapeOf [UnsqueezeShape [S ,Axes ]], i : IndicesOf [Axes ]): Tensor [T , Tuple3 [Tt1 ,Td ,UnsqueezeShape [S ,Axes ]]] = {
1158
1157
val axes = indicesOf[Axes ].indices.toArray
1159
1158
val map : Map [String , Any ] = Map ()
1160
1159
val allInputs = Tuple2 (data, Tensor (axes.map(_.toLong), Shape .fromSeq(ArraySeq .unsafeWrapArray(Array (axes.size)))))
0 commit comments