Skip to content

Commit cfbaab2

Browse files
committed
Fix Conv op
1 parent ce3d032 commit cfbaab2

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

core/src/main/scala/ONNX.scala

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -260,26 +260,39 @@ package object onnx {
260260
(callOp(name, "Concat", allInputs, map))
261261
}
262262
}
263-
//Missing in NDScala - P1 - needs constraints
263+
264+
//TODO: remove the need to pass kernel_shape, it can be inferred
265+
//Limited to 1 feature map, 1 group, stride 1
264266
trait ConvV11 extends Operator {
265-
def ConvV11[@sp T <: Float16 | Float | Double: Numeric, N <: Dimension, C <: Dimension, H <: Dimension, W <: Dimension, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: N #: C #: H #: W #: SNil, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Dimension #: Dimension #: Dimension #: Dimension #: SNil, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, S2 <: Dimension #: SNil, Tt3 <: TensorTypeDenotation, Td3 <: TensorShapeDenotation, S3 <: Dimension #: Dimension #: Dimension #: Dimension #: SNil](
267+
def ConvV11[@sp T <: Float16 | Float | Double: Numeric, N <: Dimension, C <: Dimension, H <: Dimension, W <: Dimension, KH <: Dimension, KW <: Dimension, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: N #: C #: H #: W #: SNil, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: 1 #: C #: KH #: KW #: SNil, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, S2 <: 1 #: SNil, Tt3 <: TensorTypeDenotation, Td3 <: TensorShapeDenotation, S3 <: KH #: KW #: SNil, PadsBefore <: None.type | Dimension #: Dimension #: SNil, PadsAfter <: None.type | Dimension #: Dimension #: SNil](
266268
name: String,
267269
auto_pad: String = "NOTSET",
268270
dilations: Option[(Array[Int])] = None,
269271
group: Int = 1,
270-
kernel_shape: Option[(Array[Int])] = None,
271-
pads: Option[(Array[Int])] = None,
272+
kernel_shape: S3,
273+
padsBefore: PadsBefore = None,
274+
padsAfter: PadsAfter = None,
272275
strides: Option[(Array[Int])] = None,
273276
X: Tensor[T, Tuple3[Tt,Td,S]],
274277
W: Tensor[T, Tuple3[Tt1,Td1,S1]],
275278
B: Option[Tensor[T, Tuple3[Tt2, Td2, S2]]] = None
276-
)(using tt: ValueOf[Tt3], td: TensorShapeDenotationOf[Td3], s: ShapeOf[S3]): Tensor[T, Tuple3[Tt3, Td3, S3]] = {
279+
)(using tt: ValueOf[Tt3], td: TensorShapeDenotationOf[Td3], s: ShapeOf[PaddedShape[PoolShape[S,S3], PadsBefore, PadsAfter]], s3: ShapeOf[S3]): Tensor[T, Tuple3[Tt3, Td3, PaddedShape[PoolShape[S,S3], PadsBefore, PadsAfter]]] = {
280+
val padsB: Array[Int] = padsBefore match {
281+
case x: Shape => x.toSeq.toArray
282+
case None => Array.fill(shapeOf[S3].toSeq.size)(0)
283+
}
284+
285+
val padsA: Array[Int] = padsAfter match {
286+
case x: Shape => x.toSeq.toArray
287+
case None => Array.fill(shapeOf[S3].toSeq.size)(0)
288+
}
289+
277290
val map: Map[String, Any] = Map(
278291
"auto_pad" -> auto_pad,
279292
"dilations" -> dilations,
280293
"group" -> group,
281-
"kernel_shape" -> kernel_shape,
282-
"pads" -> pads,
294+
"kernel_shape" -> kernel_shape.toSeq.toArray,
295+
"pads" -> (padsB ++ padsA),
283296
"strides" -> strides
284297
)
285298
val allInputs = Tuple3(X, W, B)

0 commit comments

Comments
 (0)