Skip to content

Commit 9acb5f3

Browse files
committed
Update constraints, add comments
1 parent 9917873 commit 9acb5f3

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

core/src/main/scala/ONNX.scala

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,12 @@ package object onnx {
175175

176176
//Missing in NDScala - P2
177177
//Last version supported in ONNX.js
178-
//TODO P2: Contrained to 2d image, means 4d tensor. Expand to the more general case
179-
//Consider enforcing denotations
178+
// Contrained to 2d image, means 4d tensor.
179+
//output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - kernel_spatial_shape[i]) / strides_spatial_shape[i] + 1)
180+
// - pad_shape[i] is sum of pads along axis i
181+
// ^ for default case ceil_mode = 0
180182
trait AveragePoolV10 extends Operator {
181-
def AveragePoolV10[@sp T <: Float16 | Float | Double: Numeric, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape](
183+
def AveragePoolV10[@sp T <: Float16 | Float | Double: Numeric, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Dimension #: Dimension #: Dimension #: Dimension #: SNil, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Dimension #: Dimension #: Dimension #: Dimension #: SNil](
182184
name: String,
183185
auto_pad: String = "NOTSET",
184186
ceil_mode: Int = 0,
@@ -200,9 +202,10 @@ package object onnx {
200202
(callOp(name, "AveragePool", allInputs, map))
201203
}
202204
}
203-
//Missing in NDScala - P3
205+
206+
//Missing optional outputs, only needed for training mode
204207
trait BatchNormalizationV9 extends Operator {
205-
def BatchNormalizationV9[@sp T <: Float16 | Float | Double: Numeric, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation](
208+
def BatchNormalizationV9[@sp T <: Float16 | Float | Double: Numeric, N <: Dimension, C <: Dimension, H <: Dimension, W <: Dimension, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: N #: C #: H #: W #: SNil, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: C #: SNil, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation](
206209
name: String,
207210
epsilon: Float = 1e-05,
208211
momentum: Float = 0.9,
@@ -282,9 +285,8 @@ package object onnx {
282285
}
283286
}
284287
//Missing in NDScala - P1 - needs constraints
285-
//TODO P1: Constraints - Could restrict this to 2d image case, so 4d input
286288
trait ConvV11 extends Operator {
287-
def ConvV11[@sp T <: Float16 | Float | Double: Numeric, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, S2 <: Shape, Tt3 <: TensorTypeDenotation, Td3 <: TensorShapeDenotation, S3 <: Shape](
289+
def ConvV11[@sp T <: Float16 | Float | Double: Numeric, N <: Dimension, C <: Dimension, H <: Dimension, W <: Dimension, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: N #: C #: H #: W #: SNil, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Dimension #: Dimension #: Dimension #: Dimension #: SNil, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, S2 <: Dimension #: SNil, Tt3 <: TensorTypeDenotation, Td3 <: TensorShapeDenotation, S3 <: Dimension #: Dimension #: Dimension #: Dimension #: SNil](
288290
name: String,
289291
auto_pad: String = "NOTSET",
290292
dilations: Option[(Array[Int])] = None,
@@ -472,7 +474,7 @@ package object onnx {
472474
}
473475
//Missing in NDScala - P2
474476
trait GlobalAveragePoolV1 extends Operator {
475-
def GlobalAveragePoolV1[@sp T <: Float16 | Float | Double: Numeric, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape](
477+
def GlobalAveragePoolV1[@sp T <: Float16 | Float | Double: Numeric, N <: Dimension, C <: Dimension, H <: Dimension, W <: Dimension, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: N #: C #: H #: W #: SNil, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: N #: C #: 1 #: 1 #: SNil](
476478
name: String,
477479
X: Tensor[T, Tuple3[Tt,Td,S]]
478480
)(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td1], s: ShapeOf[S1]): Tensor[T, Tuple3[Tt1,Td1,S1]] = {
@@ -483,7 +485,7 @@ package object onnx {
483485
}
484486
//Missing in NDScala - P2
485487
trait GlobalMaxPoolV1 extends Operator {
486-
def GlobalMaxPoolV1[@sp T <: Float16 | Float | Double: Numeric, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape](
488+
def GlobalMaxPoolV1[@sp T <: Float16 | Float | Double: Numeric, N <: Dimension, C <: Dimension, H <: Dimension, W <: Dimension, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: N #: C #: H #: W #: SNil, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: N #: C #: 1 #: 1 #: SNil](
487489
name: String,
488490
X: Tensor[T, Tuple3[Tt,Td,S]]
489491
)(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[Td1], s: ShapeOf[S1]): Tensor[T, Tuple3[Tt1,Td1,S1]] = {
@@ -493,7 +495,7 @@ package object onnx {
493495
}
494496
}
495497

496-
//TODO P2: Contrained to 2d image, means 4d tensor. Expand to the more general case
498+
//TODO P2: Contrained to 2d image, means 4d tensor.
497499
//Consider enforcing denotations - NCHW
498500
trait InstanceNormalizationV6 extends Operator {
499501
def InstanceNormalizationV6[@sp T <: Float16 | Float | Double: Numeric, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Dimension #: Dimension #: Dimension #: Dimension #: SNil, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Dimension #: SNil, Tt2 <: TensorTypeDenotation](
@@ -520,7 +522,7 @@ package object onnx {
520522
}
521523
}
522524

523-
//TODO P2: Contrained to 2d image, means 4d tensor. Expand to the more general case
525+
//TODO P2: Contrained to 2d image, means 4d tensor.
524526
//Consider enforcing denotations - NCHW
525527
trait LRNV1 extends Operator {
526528
def LRNV1[@sp T <: Float16 | Float | Double: Numeric, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Dimension #: Dimension #: Dimension #: Dimension #: SNil, Tt1 <: TensorTypeDenotation](
@@ -588,9 +590,11 @@ package object onnx {
588590

589591
//Missing in NDScala - P2
590592
//ONNX.js only supports up to V9, may work
591-
//TODO P2: Contrained to 2d image, means 4d tensor. Expand to the more general case
593+
//TODO P2: Contrained to 2d image, means 4d tensor.
592594
//Consider enforcing denotations
593-
//TODO: output shape constraint
595+
//output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
596+
//pad_shape[i] is sum of pads along axis i
597+
//^ for default case of ceil_mode = 0
594598
trait MaxPoolV10 extends Operator {
595599
def MaxPoolV10[
596600
@sp T <: Float16 | Float | Double | Byte | UByte: Numeric,
@@ -677,7 +681,7 @@ package object onnx {
677681
(callOp(name, "PRelu", allInputs, map))
678682
}
679683
}
680-
//Missing in NDScala - ready - P1
684+
681685
trait PadV11 extends Operator {
682686
def PadV11[
683687
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double: Numeric

0 commit comments

Comments
 (0)