@@ -175,10 +175,12 @@ package object onnx {
175
175
176
176
// Missing in NDScala - P2
177
177
// Last version supported in ONNX.js
178
- // TODO P2: Contrained to 2d image, means 4d tensor. Expand to the more general case
179
- // Consider enforcing denotations
178
+ // Contrained to 2d image, means 4d tensor.
179
+ // output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - kernel_spatial_shape[i]) / strides_spatial_shape[i] + 1)
180
+ // - pad_shape[i] is sum of pads along axis i
181
+ // ^ for default case ceil_mode = 0
180
182
trait AveragePoolV10 extends Operator {
181
- def AveragePoolV10 [@ sp T <: Float16 | Float | Double : Numeric , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Shape ](
183
+ def AveragePoolV10 [@ sp T <: Float16 | Float | Double : Numeric , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Dimension #: Dimension #: Dimension #: Dimension #: SNil , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Dimension #: Dimension #: Dimension #: Dimension #: SNil ](
182
184
name : String ,
183
185
auto_pad : String = " NOTSET" ,
184
186
ceil_mode : Int = 0 ,
@@ -200,9 +202,10 @@ package object onnx {
200
202
(callOp(name, " AveragePool" , allInputs, map))
201
203
}
202
204
}
203
- // Missing in NDScala - P3
205
+
206
+ // Missing optional outputs, only needed for training mode
204
207
trait BatchNormalizationV9 extends Operator {
205
- def BatchNormalizationV9 [@ sp T <: Float16 | Float | Double : Numeric , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Shape , Tt2 <: TensorTypeDenotation , Td2 <: TensorShapeDenotation ](
208
+ def BatchNormalizationV9 [@ sp T <: Float16 | Float | Double : Numeric , N <: Dimension , C <: Dimension , H <: Dimension , W <: Dimension , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: N #: C #: H #: W #: SNil , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: C #: SNil , Tt2 <: TensorTypeDenotation , Td2 <: TensorShapeDenotation ](
206
209
name : String ,
207
210
epsilon : Float = 1e-05 ,
208
211
momentum : Float = 0.9 ,
@@ -282,9 +285,8 @@ package object onnx {
282
285
}
283
286
}
284
287
// Missing in NDScala - P1 - needs constraints
285
- // TODO P1: Constraints - Could restrict this to 2d image case, so 4d input
286
288
trait ConvV11 extends Operator {
287
- def ConvV11 [@ sp T <: Float16 | Float | Double : Numeric , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Shape , Tt2 <: TensorTypeDenotation , Td2 <: TensorShapeDenotation , S2 <: Shape , Tt3 <: TensorTypeDenotation , Td3 <: TensorShapeDenotation , S3 <: Shape ](
289
+ def ConvV11 [@ sp T <: Float16 | Float | Double : Numeric , N <: Dimension , C <: Dimension , H <: Dimension , W <: Dimension , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: N #: C #: H #: W #: SNil , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Dimension #: Dimension #: Dimension #: Dimension #: SNil , Tt2 <: TensorTypeDenotation , Td2 <: TensorShapeDenotation , S2 <: Dimension #: SNil , Tt3 <: TensorTypeDenotation , Td3 <: TensorShapeDenotation , S3 <: Dimension #: Dimension #: Dimension #: Dimension #: SNil ](
288
290
name : String ,
289
291
auto_pad : String = " NOTSET" ,
290
292
dilations : Option [(Array [Int ])] = None ,
@@ -472,7 +474,7 @@ package object onnx {
472
474
}
473
475
// Missing in NDScala - P2
474
476
trait GlobalAveragePoolV1 extends Operator {
475
- def GlobalAveragePoolV1 [@ sp T <: Float16 | Float | Double : Numeric , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Shape ](
477
+ def GlobalAveragePoolV1 [@ sp T <: Float16 | Float | Double : Numeric , N <: Dimension , C <: Dimension , H <: Dimension , W <: Dimension , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: N #: C #: H #: W #: SNil , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: N #: C #: 1 #: 1 #: SNil ](
476
478
name : String ,
477
479
X : Tensor [T , Tuple3 [Tt ,Td ,S ]]
478
480
)(using tt : ValueOf [Tt1 ], td : TensorShapeDenotationOf [Td1 ], s : ShapeOf [S1 ]): Tensor [T , Tuple3 [Tt1 ,Td1 ,S1 ]] = {
@@ -483,7 +485,7 @@ package object onnx {
483
485
}
484
486
// Missing in NDScala - P2
485
487
trait GlobalMaxPoolV1 extends Operator {
486
- def GlobalMaxPoolV1 [@ sp T <: Float16 | Float | Double : Numeric , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Shape ](
488
+ def GlobalMaxPoolV1 [@ sp T <: Float16 | Float | Double : Numeric , N <: Dimension , C <: Dimension , H <: Dimension , W <: Dimension , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: N #: C #: H #: W #: SNil , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: N #: C #: 1 #: 1 #: SNil ](
487
489
name : String ,
488
490
X : Tensor [T , Tuple3 [Tt ,Td ,S ]]
489
491
)(using tt : ValueOf [Tt1 ], td : TensorShapeDenotationOf [Td1 ], s : ShapeOf [S1 ]): Tensor [T , Tuple3 [Tt1 ,Td1 ,S1 ]] = {
@@ -493,7 +495,7 @@ package object onnx {
493
495
}
494
496
}
495
497
496
- // TODO P2: Contrained to 2d image, means 4d tensor. Expand to the more general case
498
+ // TODO P2: Contrained to 2d image, means 4d tensor.
497
499
// Consider enforcing denotations - NCHW
498
500
trait InstanceNormalizationV6 extends Operator {
499
501
def InstanceNormalizationV6 [@ sp T <: Float16 | Float | Double : Numeric , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Dimension #: Dimension #: Dimension #: Dimension #: SNil , Tt1 <: TensorTypeDenotation , Td1 <: TensorShapeDenotation , S1 <: Dimension #: SNil , Tt2 <: TensorTypeDenotation ](
@@ -520,7 +522,7 @@ package object onnx {
520
522
}
521
523
}
522
524
523
- // TODO P2: Contrained to 2d image, means 4d tensor. Expand to the more general case
525
+ // TODO P2: Contrained to 2d image, means 4d tensor.
524
526
// Consider enforcing denotations - NCHW
525
527
trait LRNV1 extends Operator {
526
528
def LRNV1 [@ sp T <: Float16 | Float | Double : Numeric , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Dimension #: Dimension #: Dimension #: Dimension #: SNil , Tt1 <: TensorTypeDenotation ](
@@ -588,9 +590,11 @@ package object onnx {
588
590
589
591
// Missing in NDScala - P2
590
592
// ONNX.js only supports up to V9, may work
591
- // TODO P2: Contrained to 2d image, means 4d tensor. Expand to the more general case
593
+ // TODO P2: Contrained to 2d image, means 4d tensor.
592
594
// Consider enforcing denotations
593
- // TODO: output shape constraint
595
+ // output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
596
+ // pad_shape[i] is sum of pads along axis i
597
+ // ^ for default case of ceil_mode = 0
594
598
trait MaxPoolV10 extends Operator {
595
599
def MaxPoolV10 [
596
600
@ sp T <: Float16 | Float | Double | Byte | UByte : Numeric ,
@@ -677,7 +681,7 @@ package object onnx {
677
681
(callOp(name, " PRelu" , allInputs, map))
678
682
}
679
683
}
680
- // Missing in NDScala - ready - P1
684
+
681
685
trait PadV11 extends Operator {
682
686
def PadV11 [
683
687
@ sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double : Numeric
0 commit comments