Skip to content

Commit 4fb3068

Browse files
committed
Disable dropout training mode
1 parent 8422b4d commit 4fb3068

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

core/src/main/scala/ONNX.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import org.emergentorder.compiletime.TensorShapeDenotation.Reverse
2121
import org.emergentorder.compiletime.TensorShapeDenotation.Concat
2222
import scala.collection.immutable.ArraySeq
2323

24+
//TODO: replace use of Option/None with default params
25+
2426
//ONNX domain: ai.onnx(default)
2527
//Only the ops which are supported in both ONNX Runtime and ONNX.js
2628
//See: https://github.com/onnx/onnx/blob/v1.8.1/docs/Operators.md#aionnx-default
@@ -346,6 +348,7 @@ package object onnx {
346348
}
347349

348350
//Missing optional second output
351+
//Training mode not exposed
349352
trait DropoutV12 extends Operator {
350353
def DropoutV12[
351354
@sp T <: Float16 | Float | Double: Numeric,
@@ -355,11 +358,11 @@ package object onnx {
355358
name: String,
356359
seed: Int = 42,
357360
data: Tensor[T, Tuple3[Tt,Td,S]],
358-
ratio: Option[Tensor[T1,Tuple3[Tt1,Td1,S1]]] = None,
361+
ratio: Tensor[T1,Tuple3[Tt1,Td1,S1]] = Tensor(Array(0.5f), SNil),
359362
training_mode: Option[Tensor[T2, Tuple3[Tt2,Td2,S2]]] = None
360363
)(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[S]): Tensor[T, Tuple3[Tt,Td,S]] = {
361364
val map: Map[String, Any] = Map("seed" -> seed)
362-
val allInputs = Tuple3(data, ratio, training_mode)
365+
val allInputs = Tuple2(data, ratio) //, training_mode)
363366
(callOp(name, "Dropout", allInputs, map))
364367
}
365368
}

0 commit comments

Comments
 (0)