@@ -382,7 +382,7 @@ def __init__(
382382 matmul_precision = self .config .matmul_precision ,
383383 rngs = self .rngs ,
384384 )
385- self .Dropout_0 = nnx . Dropout (rate = self .config .dropout_rate )
385+ self .Dropout_0 = Dropout (rate = self .config .dropout_rate , rngs = self . rngs )
386386 self .Dense_1 = DenseGeneral (
387387 in_features_shape = self .config .intermediate_size_for_vit ,
388388 out_features_shape = self .config .hidden_size_for_vit ,
@@ -452,12 +452,12 @@ def __init__(
452452 config = self .config ,
453453 rngs = self .rngs ,
454454 )
455- self .Dropout_0 = nnx . Dropout (self .config .dropout_rate , rngs = self .rngs )
455+ self .Dropout_0 = Dropout (rate = self .config .dropout_rate , rngs = self .rngs )
456456
457457 def __call__ (self , x : jax .Array , deterministic : bool = False ) -> jax .Array :
458458 y = self .LayerNorm_0 (x )
459459
460- y = self .MultiHeadDotProductAttention_0 (inputs_q = y , inputs_kv = y , deterministic = deterministic )
460+ y , _ = self .MultiHeadDotProductAttention_0 (inputs_q = y , inputs_kv = y , deterministic = deterministic )
461461 y = self .Dropout_0 (y , deterministic = deterministic )
462462 x = x + y
463463
@@ -634,7 +634,7 @@ def __init__(
634634 width = self .config .hidden_size_for_vit ,
635635 dtype = self .config .dtype_mm ,
636636 )
637- self .Dropout_0 = nnx . Dropout (self .config .dropout_rate , rngs = self .rngs )
637+ self .Dropout_0 = Dropout (rate = self .config .dropout_rate , rngs = self .rngs )
638638 self .Transformer = Encoder (
639639 config = self .config ,
640640 mesh = self .mesh ,
0 commit comments