@@ -21,6 +21,8 @@ import org.emergentorder.compiletime.TensorShapeDenotation.Reverse
21
21
import org .emergentorder .compiletime .TensorShapeDenotation .Concat
22
22
import scala .collection .immutable .ArraySeq
23
23
24
+ // TODO: replace use of Option/None with default params
25
+
24
26
// ONNX domain: ai.onnx(default)
25
27
// Only the ops which are supported in both ONNX Runtime and ONNX.js
26
28
// See: https://github.com/onnx/onnx/blob/v1.8.1/docs/Operators.md#aionnx-default
@@ -346,6 +348,7 @@ package object onnx {
346
348
}
347
349
348
350
// Missing optional second output
351
+ // Training mode not exposed
349
352
trait DropoutV12 extends Operator {
350
353
def DropoutV12 [
351
354
@ sp T <: Float16 | Float | Double : Numeric ,
@@ -355,11 +358,11 @@ package object onnx {
355
358
name : String ,
356
359
seed : Int = 42 ,
357
360
data : Tensor [T , Tuple3 [Tt ,Td ,S ]],
358
- ratio : Option [ Tensor [T1 ,Tuple3 [Tt1 ,Td1 ,S1 ]]] = None ,
361
+ ratio : Tensor [T1 ,Tuple3 [Tt1 ,Td1 ,S1 ]] = Tensor ( Array ( 0.5f ), SNil ) ,
359
362
training_mode : Option [Tensor [T2 , Tuple3 [Tt2 ,Td2 ,S2 ]]] = None
360
363
)(using tt : ValueOf [Tt ], td : TensorShapeDenotationOf [Td ], s : ShapeOf [S ]): Tensor [T , Tuple3 [Tt ,Td ,S ]] = {
361
364
val map : Map [String , Any ] = Map (" seed" -> seed)
362
- val allInputs = Tuple3 (data, ratio, training_mode)
365
+ val allInputs = Tuple2 (data, ratio) // , training_mode)
363
366
(callOp(name, " Dropout" , allInputs, map))
364
367
}
365
368
}
0 commit comments