@@ -260,26 +260,39 @@ package object onnx {
260
260
(callOp(name, " Concat" , allInputs, map))
261
261
}
262
262
}
263
- // Missing in NDScala - P1 - needs constraints
263
+
264
+ // TODO: remove the need to pass kernel_shape, it can be inferred
265
+ // Limited to 1 feature map, 1 group, stride 1
264
266
trait ConvV11 extends Operator {
265
- def ConvV11 [@ sp T <: Float16 | Float | Double : Numeric , N <: Dimension , C <: Dimension , H <: Dimension , W <: Dimension , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: N #: C #: H #: W #: SNil , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Dimension #: Dimension #: Dimension #: Dimension #: SNil , Tt2 <: TensorTypeDenotation , Td2 <: TensorShapeDenotation , S2 <: Dimension #: SNil , Tt3 <: TensorTypeDenotation , Td3 <: TensorShapeDenotation , S3 <: Dimension #: Dimension #: Dimension #: Dimension #: SNil ](
267
+ def ConvV11 [@ sp T <: Float16 | Float | Double : Numeric , N <: Dimension , C <: Dimension , H <: Dimension , W <: Dimension , KH <: Dimension , KW <: Dimension , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: N #: C #: H #: W #: SNil , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: 1 #: C #: KH #: KW #: SNil , Tt2 <: TensorTypeDenotation , Td2 <: TensorShapeDenotation , S2 <: 1 #: SNil , Tt3 <: TensorTypeDenotation , Td3 <: TensorShapeDenotation , S3 <: KH #: KW #: SNil , PadsBefore <: None . type | Dimension #: Dimension #: SNil , PadsAfter <: None . type | Dimension #: Dimension #: SNil ](
266
268
name : String ,
267
269
auto_pad : String = " NOTSET" ,
268
270
dilations : Option [(Array [Int ])] = None ,
269
271
group : Int = 1 ,
270
- kernel_shape : Option [(Array [Int ])] = None ,
271
- pads : Option [(Array [Int ])] = None ,
272
+ kernel_shape : S3 ,
273
+ padsBefore : PadsBefore = None ,
274
+ padsAfter : PadsAfter = None ,
272
275
strides : Option [(Array [Int ])] = None ,
273
276
X : Tensor [T , Tuple3 [Tt ,Td ,S ]],
274
277
W : Tensor [T , Tuple3 [Tt1 ,Td1 ,S1 ]],
275
278
B : Option [Tensor [T , Tuple3 [Tt2 , Td2 , S2 ]]] = None
276
- )(using tt : ValueOf [Tt3 ], td : TensorShapeDenotationOf [Td3 ], s : ShapeOf [S3 ]): Tensor [T , Tuple3 [Tt3 , Td3 , S3 ]] = {
279
+ )(using tt : ValueOf [Tt3 ], td : TensorShapeDenotationOf [Td3 ], s : ShapeOf [PaddedShape [PoolShape [S ,S3 ], PadsBefore , PadsAfter ]], s3 : ShapeOf [S3 ]): Tensor [T , Tuple3 [Tt3 , Td3 , PaddedShape [PoolShape [S ,S3 ], PadsBefore , PadsAfter ]]] = {
280
+ val padsB : Array [Int ] = padsBefore match {
281
+ case x : Shape => x.toSeq.toArray
282
+ case None => Array .fill(shapeOf[S3 ].toSeq.size)(0 )
283
+ }
284
+
285
+ val padsA : Array [Int ] = padsAfter match {
286
+ case x : Shape => x.toSeq.toArray
287
+ case None => Array .fill(shapeOf[S3 ].toSeq.size)(0 )
288
+ }
289
+
277
290
val map : Map [String , Any ] = Map (
278
291
" auto_pad" -> auto_pad,
279
292
" dilations" -> dilations,
280
293
" group" -> group,
281
- " kernel_shape" -> kernel_shape,
282
- " pads" -> pads ,
294
+ " kernel_shape" -> kernel_shape.toSeq.toArray ,
295
+ " pads" -> (padsB ++ padsA) ,
283
296
" strides" -> strides
284
297
)
285
298
val allInputs = Tuple3 (X , W , B )
0 commit comments